Compare commits

...

8 Commits

Author SHA1 Message Date
Lance Release
143184c0ae Bump version: 0.25.2 → 0.25.3-beta.0 2025-10-14 02:25:16 +00:00
Jack Ye
dadb042978 feat: bump lance to 0.38.3-beta.2 and rust to 1.90.0 (#2714) 2025-10-10 14:02:41 -07:00
Weston Pace
5a19cf15a6 feat: a utility for creating "permutation views" (#2552)
I'm working on a lancedb version of pytorch data loading (and hopefully
addressing https://github.com/lancedb/lance/issues/3727).

However, rather than rely on pytorch for everything, I'm moving some of
the things that pytorch does into rust. This gives us more control over
data loading (e.g. using shards or a hash-based split) and it allows
permutations to be persistent. In particular I hope to be able to:

* Create a persistent permutation
  * This permutation can handle splits, filtering, shuffling, and sharding
* Create a rust data loader that can read a permutation (one or more
splits), or a subset of a permutation (for DDP)
* Create a python data loader that delegates to the rust data loader

Eventually I hope to create integrations for other data loading libraries,
including rust & node
2025-10-09 18:07:31 -07:00
Will Jones
3dcec724b7 chore: loosen pin on chrono (#2710)
Fixes #2709
2025-10-09 14:23:56 -07:00
LuQQiu
86a6bb9fcb chore: supports limit push down through MetadataEraserExec (#2679)
For limit to successfully push down to FilteredReadExec
https://github.com/lancedb/lance/pull/4795/
2025-10-09 09:33:38 -07:00
BubbleCal
b59d1007d3 feat(index): add IVF_RQ index type (#2687)
This exposes the IVF_RQ (RabitQ quantization) index type to lancedb

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
2025-10-09 15:46:18 +08:00
Lance Release
56a16b1728 Bump version: 0.22.2-beta.3 → 0.22.2 2025-10-08 18:13:08 +00:00
Lance Release
b7afed9beb Bump version: 0.22.2-beta.2 → 0.22.2-beta.3 2025-10-08 18:12:23 +00:00
67 changed files with 4233 additions and 200 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.22.2-beta.2"
current_version = "0.22.2"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.

Cargo.lock generated
View File

@@ -72,12 +72,6 @@ version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
@@ -1474,17 +1468,16 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chrono"
version = "0.4.41"
version = "0.4.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d"
checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link 0.1.3",
"windows-link 0.2.1",
]
[[package]]
@@ -1573,7 +1566,7 @@ checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56"
dependencies = [
"crossterm 0.27.0",
"crossterm 0.28.1",
"strum 0.26.3",
"strum",
"strum_macros 0.26.4",
"unicode-width",
]
@@ -3052,8 +3045,7 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "480fc4f47567da549ab44bb2f37f6db1570c9eff7200e50334b69fa1daa74339"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow-array",
"rand 0.9.2",
@@ -3850,7 +3842,7 @@ dependencies = [
"js-sys",
"log",
"wasm-bindgen",
"windows-core 0.62.2",
"windows-core 0.61.2",
]
[[package]]
@@ -4238,8 +4230,7 @@ dependencies = [
[[package]]
name = "lance"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e2d2472f58d01894bc5f0a9f9d28dfca4649c9e28faf467c47e87f788ef322b"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow",
"arrow-arith",
@@ -4302,8 +4293,7 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2abba8770c4217fbdc8b517cdfb7183639b02dc5c2bcad1e7c69ffdcf4fbe1a"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4322,8 +4312,7 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "efb7af69bff8d8499999684f961b0a4dc6e159065c773041545d19bc158f0814"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrayref",
"paste",
@@ -4333,8 +4322,7 @@ dependencies = [
[[package]]
name = "lance-core"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "356a5df5f9cd7cb4aedaf78a4e346190ae50ba574b828316caed7d1df3b6dcd8"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4371,8 +4359,7 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8e8ec07021bdaba6a441563d8fbcb0431350aae6842910ae3622557765f218f"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow",
"arrow-array",
@@ -4401,8 +4388,7 @@ dependencies = [
[[package]]
name = "lance-datagen"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4fe98730cd5297dc68b22f6ad7e1e27cf34e2db05586b64d3540ca74a519a61"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow",
"arrow-array",
@@ -4420,8 +4406,7 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef073d419cc00ef41dd95cb25203b333118b224151ae397145530b1d559769c9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4449,7 +4434,7 @@ dependencies = [
"prost-types",
"rand 0.9.2",
"snafu",
"strum 0.25.0",
"strum",
"tokio",
"tracing",
"xxhash-rust",
@@ -4459,8 +4444,7 @@ dependencies = [
[[package]]
name = "lance-file"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e34aba3a41f119188da997730560e4a6915ee5a38b672bbf721fdc99121aa1e"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4494,8 +4478,7 @@ dependencies = [
[[package]]
name = "lance-index"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5f480f801c8efb41a6dedc48a5cacff6044a10f82c6f9764b8dac7194a7754e"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow",
"arrow-arith",
@@ -4558,8 +4541,7 @@ dependencies = [
[[package]]
name = "lance-io"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0708125c74965b2b7e5e0c4fe2d8e6bd8346a7031484f8844cf06c08bfa29a72"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow",
"arrow-arith",
@@ -4599,8 +4581,7 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da9d1c22deed92420a1869e4b89188ccecc7e1aee2ea4e5bca92eae861511d60"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4655,8 +4636,7 @@ dependencies = [
[[package]]
name = "lance-table"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "805e6c64efbb3295f74714668c9033121ffdfa6c868f067024e65ade700b8b8b"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow",
"arrow-array",
@@ -4695,8 +4675,7 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "0.38.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ac735b5eb153a6ac841ce0206e4c30df941610c812cc89c8ae20006f8d0b018"
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.2#73a2c7e1f52932f589ad0ac63eb41194b9f9421a"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -4707,8 +4686,9 @@ dependencies = [
[[package]]
name = "lancedb"
version = "0.22.2-beta.2"
version = "0.22.2"
dependencies = [
"ahash",
"anyhow",
"arrow",
"arrow-array",
@@ -4744,8 +4724,11 @@ dependencies = [
"http 1.3.1",
"http-body 1.0.1",
"lance",
"lance-core",
"lance-datafusion",
"lance-datagen",
"lance-encoding",
"lance-file",
"lance-index",
"lance-io",
"lance-linalg",
@@ -4771,6 +4754,7 @@ dependencies = [
"serde_with",
"snafu",
"tempfile",
"test-log",
"tokenizers",
"tokio",
"url",
@@ -4796,7 +4780,7 @@ dependencies = [
[[package]]
name = "lancedb-nodejs"
version = "0.22.2-beta.2"
version = "0.22.2"
dependencies = [
"arrow-array",
"arrow-ipc",
@@ -4816,7 +4800,7 @@ dependencies = [
[[package]]
name = "lancedb-python"
version = "0.25.2-beta.2"
version = "0.25.2"
dependencies = [
"arrow",
"async-trait",
@@ -7909,20 +7893,14 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
dependencies = [
"strum_macros 0.25.3",
]
[[package]]
name = "strum"
version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [
"strum_macros 0.26.4",
]
[[package]]
name = "strum_macros"
@@ -8244,6 +8222,28 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "test-log"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b"
dependencies = [
"env_logger",
"test-log-macros",
"tracing-subscriber",
]
[[package]]
name = "test-log-macros"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]]
name = "thiserror"
version = "1.0.69"
@@ -9061,21 +9061,8 @@ dependencies = [
"windows-implement",
"windows-interface",
"windows-link 0.1.3",
"windows-result 0.3.4",
"windows-strings 0.4.2",
]
[[package]]
name = "windows-core"
version = "0.62.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb"
dependencies = [
"windows-implement",
"windows-interface",
"windows-link 0.2.1",
"windows-result 0.4.1",
"windows-strings 0.5.1",
"windows-result",
"windows-strings",
]
[[package]]
@@ -9140,8 +9127,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e"
dependencies = [
"windows-link 0.1.3",
"windows-result 0.3.4",
"windows-strings 0.4.2",
"windows-result",
"windows-strings",
]
[[package]]
@@ -9153,15 +9140,6 @@ dependencies = [
"windows-link 0.1.3",
]
[[package]]
name = "windows-result"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5"
dependencies = [
"windows-link 0.2.1",
]
[[package]]
name = "windows-strings"
version = "0.4.2"
@@ -9171,15 +9149,6 @@ dependencies = [
"windows-link 0.1.3",
]
[[package]]
name = "windows-strings"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091"
dependencies = [
"windows-link 0.2.1",
]
[[package]]
name = "windows-sys"
version = "0.45.0"

View File

@@ -15,15 +15,19 @@ categories = ["database-implementations"]
rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.38.2", default-features = false, "features" = ["dynamodb"] }
lance-io = { "version" = "=0.38.2", default-features = false }
lance-index = "=0.38.2"
lance-linalg = "=0.38.2"
lance-table = "=0.38.2"
lance-testing = "=0.38.2"
lance-datafusion = "=0.38.2"
lance-encoding = "=0.38.2"
lance = { "version" = "=0.38.2", default-features = false, "features" = ["dynamodb"], "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-core = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datagen = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-file = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-io = { "version" = "=0.38.2", default-features = false, "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-index = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-linalg = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-table = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-testing = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datafusion = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-encoding = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-namespace = "0.0.18"
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "56.2", optional = false }
arrow-array = "56.2"
@@ -48,6 +52,7 @@ log = "0.4"
moka = { version = "0.12", features = ["future"] }
object_store = "0.12.0"
pin-project = "1.0.7"
rand = "0.9"
snafu = "0.8"
url = "2"
num-traits = "0.2"
@@ -55,20 +60,18 @@ regex = "1.10"
lazy_static = "1"
semver = "1.0.25"
crunchy = "0.2.4"
# Temporary pins to work around downstream issues
# https://github.com/apache/arrow-rs/commit/2fddf85afcd20110ce783ed5b4cdeb82293da30b
chrono = "=0.4.41"
chrono = "0.4"
# Workaround for: https://github.com/Lokathor/bytemuck/issues/306
bytemuck_derive = ">=1.8.1, <1.9.0"
# This is only needed when we reference preview releases of lance
# [patch.crates-io]
# # Force to use the same lance version as the rest of the project to avoid duplicate dependencies
# lance = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-io = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-index = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-linalg = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-table = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-testing = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-datafusion = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# lance-encoding = { "version" = "=0.38.0", "tag" = "v0.38.0", "git" = "https://github.com/lancedb/lance.git" }
# Force to use the same lance version as the rest of the project to avoid duplicate dependencies
[patch.crates-io]
lance = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-io = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-index = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-linalg = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-table = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-testing = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-datafusion = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }
lance-encoding = { "version" = "=0.38.2", "tag" = "v0.38.3-beta.2", "git" = "https://github.com/lancedb/lance.git" }

View File

@@ -194,6 +194,37 @@ currently is also a memory intensive operation.
***
### ivfRq()
```ts
static ivfRq(options?): Index
```
Create an IvfRq index
IVF-RQ (RabitQ quantization) compresses vectors by quantizing each dimension
into a small number of bits and organizes them into IVF partitions.
The `num_bits` parameter controls this compression, providing a tradeoff
between index size (and thus search speed) and index accuracy.
The partitioning process is called IVF, and the `num_partitions` parameter controls
how many groups to create.
Note that training an IVF RQ index on a large dataset is a slow operation and
currently is also a memory intensive operation.
#### Parameters
* **options?**: `Partial`&lt;[`IvfRqOptions`](../interfaces/IvfRqOptions.md)&gt;
#### Returns
[`Index`](Index.md)
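#### Example

A minimal sketch based on the test added in this change; the table, the column
name `vec`, and the option values are illustrative:

```ts
import { connect, Index } from "@lancedb/lancedb";

const db = await connect("/tmp/lancedb-demo");
const tbl = await db.openTable("my_vectors");

// numBits controls the RabitQ compression; numPartitions controls the IVF grouping.
await tbl.createIndex("vec", {
  config: Index.ivfRq({
    numPartitions: 10,
    numBits: 1,
  }),
});
```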
***
### labelList()
```ts

View File

@@ -0,0 +1,220 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / PermutationBuilder
# Class: PermutationBuilder
A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
offering methods to configure data splits, shuffling, and filtering before executing
the permutation to create a new table.
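A minimal end-to-end sketch, assembled from the tests added in this change
(the table names and values are illustrative; the tests indicate that the
resulting table records each row's split in a `split_id` column, and that a
builder can only be executed once):

```ts
import { connect, permutationBuilder } from "@lancedb/lancedb";

const db = await connect("/tmp/lancedb-demo");
const source = await db.openTable("events");

// Filter, split 80/20, and shuffle in a single chained configuration.
const builder = permutationBuilder(source, "events_permutation")
  .filter("value > 0")
  .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
  .shuffle({ seed: 123 });

const permuted = await builder.execute();

// Each row carries its split assignment in the `split_id` column.
const trainCount = await permuted.countRows("split_id = 0");
const testCount = await permuted.countRows("split_id = 1");
console.log(`train=${trainCount} test=${testCount}`);

// A second builder.execute() rejects with "Builder already consumed".
```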
## Methods
### execute()
```ts
execute(): Promise<Table>
```
Execute the permutation and create the destination table.
#### Returns
`Promise`&lt;[`Table`](Table.md)&gt;
A Promise that resolves to the new Table instance
#### Example
```ts
const permutationTable = await builder.execute();
console.log(`Created table: ${permutationTable.name}`);
```
***
### filter()
```ts
filter(filter): PermutationBuilder
```
Configure filtering for the permutation.
#### Parameters
* **filter**: `string`
SQL filter expression
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.filter("age > 18 AND status = 'active'");
```
***
### shuffle()
```ts
shuffle(options): PermutationBuilder
```
Configure shuffling for the permutation.
#### Parameters
* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md)
Configuration for shuffling
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Basic shuffle
builder.shuffle({ seed: 42 });
// Shuffle with clump size
builder.shuffle({ seed: 42, clumpSize: 10 });
```
***
### splitCalculated()
```ts
splitCalculated(calculation): PermutationBuilder
```
Configure calculated splits for the permutation.
#### Parameters
* **calculation**: `string`
SQL expression for calculating splits
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.splitCalculated("user_id % 3");
```
***
### splitHash()
```ts
splitHash(options): PermutationBuilder
```
Configure hash-based splits for the permutation.
#### Parameters
* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md)
Configuration for hash-based splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
builder.splitHash({
columns: ["user_id"],
splitWeights: [70, 30],
discardWeight: 0
});
```
***
### splitRandom()
```ts
splitRandom(options): PermutationBuilder
```
Configure random splits for the permutation.
#### Parameters
* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md)
Configuration for random splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Split by ratios
builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
// Split by counts
builder.splitRandom({ counts: [1000, 500], seed: 42 });
// Split with fixed size
builder.splitRandom({ fixed: 100, seed: 42 });
```
***
### splitSequential()
```ts
splitSequential(options): PermutationBuilder
```
Configure sequential splits for the permutation.
#### Parameters
* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md)
Configuration for sequential splitting
#### Returns
[`PermutationBuilder`](PermutationBuilder.md)
A new PermutationBuilder instance
#### Example
```ts
// Split by ratios
builder.splitSequential({ ratios: [0.8, 0.2] });
// Split by counts
builder.splitSequential({ counts: [800, 200] });
// Split with fixed size
builder.splitSequential({ fixed: 1000 });
```

View File

@@ -0,0 +1,37 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / permutationBuilder
# Function: permutationBuilder()
```ts
function permutationBuilder(table, destTableName): PermutationBuilder
```
Create a permutation builder for the given table.
## Parameters
* **table**: [`Table`](../classes/Table.md)
The source table to create a permutation from
* **destTableName**: `string`
The name for the destination permutation table
## Returns
[`PermutationBuilder`](../classes/PermutationBuilder.md)
A PermutationBuilder instance
## Example
```ts
const builder = permutationBuilder(sourceTable, "training_data")
.splitRandom({ ratios: [0.8, 0.2], seed: 42 })
.shuffle({ seed: 123 });
const trainingTable = await builder.execute();
```

View File

@@ -28,6 +28,7 @@
- [MultiMatchQuery](classes/MultiMatchQuery.md)
- [NativeJsHeaderProvider](classes/NativeJsHeaderProvider.md)
- [OAuthHeaderProvider](classes/OAuthHeaderProvider.md)
- [PermutationBuilder](classes/PermutationBuilder.md)
- [PhraseQuery](classes/PhraseQuery.md)
- [Query](classes/Query.md)
- [QueryBase](classes/QueryBase.md)
@@ -68,6 +69,7 @@
- [IndexStatistics](interfaces/IndexStatistics.md)
- [IvfFlatOptions](interfaces/IvfFlatOptions.md)
- [IvfPqOptions](interfaces/IvfPqOptions.md)
- [IvfRqOptions](interfaces/IvfRqOptions.md)
- [MergeResult](interfaces/MergeResult.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)
@@ -75,6 +77,10 @@
- [QueryExecutionOptions](interfaces/QueryExecutionOptions.md)
- [RemovalStats](interfaces/RemovalStats.md)
- [RetryConfig](interfaces/RetryConfig.md)
- [ShuffleOptions](interfaces/ShuffleOptions.md)
- [SplitHashOptions](interfaces/SplitHashOptions.md)
- [SplitRandomOptions](interfaces/SplitRandomOptions.md)
- [SplitSequentialOptions](interfaces/SplitSequentialOptions.md)
- [TableNamesOptions](interfaces/TableNamesOptions.md)
- [TableStatistics](interfaces/TableStatistics.md)
- [TimeoutConfig](interfaces/TimeoutConfig.md)
@@ -102,3 +108,4 @@
- [connect](functions/connect.md)
- [makeArrowTable](functions/makeArrowTable.md)
- [packBits](functions/packBits.md)
- [permutationBuilder](functions/permutationBuilder.md)

View File

@@ -0,0 +1,23 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / ShuffleOptions
# Interface: ShuffleOptions
## Properties
### clumpSize?
```ts
optional clumpSize: number;
```
***
### seed?
```ts
optional seed: number;
```
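A short usage sketch. `seed` makes the shuffle reproducible; judging from the
native `ShuffleStrategy::Random { seed, clump_size }` added in this change,
`clumpSize` appears to shuffle runs of adjacent rows as units (treat that
reading as an assumption):

```ts
// Reproducible shuffle.
builder.shuffle({ seed: 42 });

// Keep clumps of 10 adjacent rows together while shuffling (assumed semantics).
builder.shuffle({ seed: 42, clumpSize: 10 });
```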

View File

@@ -0,0 +1,31 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitHashOptions
# Interface: SplitHashOptions
## Properties
### columns
```ts
columns: string[];
```
***
### discardWeight?
```ts
optional discardWeight: number;
```
***
### splitWeights
```ts
splitWeights: number[];
```
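A short sketch of how the properties fit together. The weights read as relative
proportions rather than percentages, and hashing the listed columns should make
the assignment deterministic per value; both readings are assumptions based on
the native `SplitStrategy::Hash` added in this change:

```ts
// Rows are assigned by hashing `user_id`, so a given user always lands in the
// same split across rebuilds (assumed from the hash-based design).
builder.splitHash({
  columns: ["user_id"],
  splitWeights: [70, 30], // relative weights for split_id 0 and 1 (assumed)
  discardWeight: 0, // relative weight of rows dropped entirely (assumed)
});
```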

View File

@@ -0,0 +1,39 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitRandomOptions
# Interface: SplitRandomOptions
## Properties
### counts?
```ts
optional counts: number[];
```
***
### fixed?
```ts
optional fixed: number;
```
***
### ratios?
```ts
optional ratios: number[];
```
***
### seed?
```ts
optional seed: number;
```
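Exactly one of `ratios`, `counts`, or `fixed` must be set; the native binding
added in this change rejects anything else. A short sketch:

```ts
// Valid: exactly one sizing option.
builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });

// Both of these throw:
// "Exactly one of 'ratios', 'counts', or 'fixed' must be provided"
builder.splitRandom({});
builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7] });
```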

View File

@@ -0,0 +1,31 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / SplitSequentialOptions
# Interface: SplitSequentialOptions
## Properties
### counts?
```ts
optional counts: number[];
```
***
### fixed?
```ts
optional fixed: number;
```
***
### ratios?
```ts
optional ratios: number[];
```
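The same exactly-one-of rule applies here as for `SplitRandomOptions` (the
native binding performs the identical check). Sequential splits assign rows in
order rather than at random, so the resulting sizes are exact — the tests added
here expect exactly 5 and 5 on a 10-row table:

```ts
// First 5 rows -> split_id 0, remaining 5 -> split_id 1 (on a 10-row table).
builder.splitSequential({ ratios: [0.5, 0.5] });

// Throws: no sizing option provided.
builder.splitSequential({});
```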

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-beta.2</version>
<version>0.22.2-final.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -8,7 +8,7 @@
<parent>
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-beta.2</version>
<version>0.22.2-final.0</version>
<relativePath>../pom.xml</relativePath>
</parent>

View File

@@ -6,7 +6,7 @@
<groupId>com.lancedb</groupId>
<artifactId>lancedb-parent</artifactId>
<version>0.22.2-beta.2</version>
<version>0.22.2-final.0</version>
<packaging>pom</packaging>
<name>${project.artifactId}</name>
<description>LanceDB Java SDK Parent POM</description>

View File

@@ -1,7 +1,7 @@
[package]
name = "lancedb-nodejs"
edition.workspace = true
version = "0.22.2-beta.2"
version = "0.22.2"
license.workspace = true
description.workspace = true
repository.workspace = true

View File

@@ -0,0 +1,234 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import * as tmp from "tmp";
import { Table, connect, permutationBuilder } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";
describe("PermutationBuilder", () => {
let tmpDir: tmp.DirResult;
let table: Table;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const db = await connect(tmpDir.name);
// Create test data
const data = makeArrowTable(
[
{ id: 1, value: 10 },
{ id: 2, value: 20 },
{ id: 3, value: 30 },
{ id: 4, value: 40 },
{ id: 5, value: 50 },
{ id: 6, value: 60 },
{ id: 7, value: 70 },
{ id: 8, value: 80 },
{ id: 9, value: 90 },
{ id: 10, value: 100 },
],
{ vectorColumns: {} },
);
table = await db.createTable("test_table", data);
});
afterEach(() => {
tmpDir.removeCallback();
});
test("should create permutation builder", () => {
const builder = permutationBuilder(table, "permutation_table");
expect(builder).toBeDefined();
});
test("should execute basic permutation", async () => {
const builder = permutationBuilder(table, "permutation_table");
const permutationTable = await builder.execute();
expect(permutationTable).toBeDefined();
expect(permutationTable.name).toBe("permutation_table");
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with random splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
ratios: [1.0],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with percentage splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
ratios: [0.3, 0.7],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with count splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitRandom({
counts: [3, 7],
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(3);
expect(split1Count).toBe(7);
});
test("should create permutation with hash splits", async () => {
const builder = permutationBuilder(table, "permutation_table").splitHash({
columns: ["id"],
splitWeights: [50, 50],
discardWeight: 0,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check that splits exist
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with sequential splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitSequential({ ratios: [0.5, 0.5] });
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution - sequential should give exactly 5 and 5
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBe(5);
expect(split1Count).toBe(5);
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(
table,
"permutation_table",
).splitCalculated("id % 2");
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(10);
});
test("should create permutation with shuffle", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
seed: 42,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with shuffle and clump size", async () => {
const builder = permutationBuilder(table, "permutation_table").shuffle({
seed: 42,
clumpSize: 2,
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(10);
});
test("should create permutation with filter", async () => {
const builder = permutationBuilder(table, "permutation_table").filter(
"value > 50",
);
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(5); // Values 60, 70, 80, 90, 100
});
test("should chain multiple operations", async () => {
const builder = permutationBuilder(table, "permutation_table")
.filter("value <= 80")
.splitRandom({ ratios: [0.5, 0.5], seed: 42 })
.shuffle({ seed: 123 });
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
expect(rowCount).toBe(8); // Values 10, 20, 30, 40, 50, 60, 70, 80
// Check split distribution
const split0Count = await permutationTable.countRows("split_id = 0");
const split1Count = await permutationTable.countRows("split_id = 1");
expect(split0Count).toBeGreaterThan(0);
expect(split1Count).toBeGreaterThan(0);
expect(split0Count + split1Count).toBe(8);
});
test("should throw error for invalid split arguments", () => {
const builder = permutationBuilder(table, "permutation_table");
// Test no arguments provided
expect(() => builder.splitRandom({})).toThrow(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
);
// Test multiple arguments provided
expect(() =>
builder.splitRandom({ ratios: [0.5, 0.5], counts: [3, 7], seed: 42 }),
).toThrow("Exactly one of 'ratios', 'counts', or 'fixed' must be provided");
});
test("should throw error when builder is consumed", async () => {
const builder = permutationBuilder(table, "permutation_table");
// Execute once
await builder.execute();
// Should throw error on second execution
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
});
});

View File

@@ -861,6 +861,15 @@ describe("When creating an index", () => {
});
});
it("should be able to create IVF_RQ", async () => {
await tbl.createIndex("vec", {
config: Index.ivfRq({
numPartitions: 10,
numBits: 1,
}),
});
});
it("should allow me to replace (or not) an existing index", async () => {
await tbl.createIndex("id");
// Default is replace=true

View File

@@ -43,6 +43,10 @@ export {
DeleteResult,
DropColumnsResult,
UpdateResult,
SplitRandomOptions,
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
} from "./native.js";
export {
@@ -85,6 +89,7 @@ export {
Index,
IndexOptions,
IvfPqOptions,
IvfRqOptions,
IvfFlatOptions,
HnswPqOptions,
HnswSqOptions,
@@ -110,6 +115,7 @@ export {
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";
export { permutationBuilder, PermutationBuilder } from "./permutation";
export * as rerankers from "./rerankers";
export {
SchemaLike,

View File

@@ -112,6 +112,77 @@ export interface IvfPqOptions {
sampleRate?: number;
}
export interface IvfRqOptions {
/**
* The number of IVF partitions to create.
*
* This value should generally scale with the number of rows in the dataset.
* By default the number of partitions is the square root of the number of
* rows.
*
* If this value is too large then the first part of the search (picking the
* right partition) will be slow. If this value is too small then the second
* part of the search (searching within a partition) will be slow.
*/
numPartitions?: number;
/**
* Number of bits per dimension for residual quantization.
*
* This value controls how much each residual component is compressed. The more
* bits, the more accurate the index will be, but the slower the search. Typical
* values are small integers; the default is 1 bit per dimension.
*/
numBits?: number;
/**
* Distance type to use to build the index.
*
* Default value is "l2".
*
* This is used when training the index to calculate the IVF partitions
* (vectors are grouped in partitions with similar vectors according to this
* distance type) and during quantization.
*
* The distance type used to train an index MUST match the distance type used
* to search the index. Failure to do so will yield inaccurate results.
*
* The following distance types are available:
*
* "l2" - Euclidean distance.
* "cosine" - Cosine distance.
* "dot" - Dot product.
*/
distanceType?: "l2" | "cosine" | "dot";
/**
* Max iterations to train IVF kmeans.
*
* When training an IVF index we use kmeans to calculate the partitions. This parameter
* controls how many iterations of kmeans to run.
*
* The default value is 50.
*/
maxIterations?: number;
/**
* The number of vectors, per partition, to sample when training IVF kmeans.
*
* When an IVF index is trained, we need to calculate partitions. These are groups
* of vectors that are similar to each other. To do this we use an algorithm called kmeans.
*
* Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
* random sample of the data. This parameter controls the size of the sample. The total
* number of vectors used to train the index is `sample_rate * num_partitions`.
*
* Increasing this value might improve the quality of the index but in most cases the
* default should be sufficient.
*
* The default value is 256.
*/
sampleRate?: number;
}
/**
* Options to create an `HNSW_PQ` index
*/
@@ -523,6 +594,35 @@ export class Index {
options?.distanceType,
options?.numPartitions,
options?.numSubVectors,
options?.numBits,
options?.maxIterations,
options?.sampleRate,
),
);
}
/**
* Create an IvfRq index
*
* IVF-RQ (RabitQ quantization) compresses vectors by quantizing each dimension
* into a small number of bits and organizes them into IVF partitions.
*
* The `num_bits` parameter controls this compression, providing a tradeoff
* between index size (and thus search speed) and index accuracy.
*
* The partitioning process is called IVF, and the `num_partitions` parameter
* controls how many groups to create.
*
* Note that training an IVF RQ index on a large dataset is a slow operation and
* currently is also a memory intensive operation.
*/
static ivfRq(options?: Partial<IvfRqOptions>) {
return new Index(
LanceDbIndex.ivfRq(
options?.distanceType,
options?.numPartitions,
options?.numBits,
options?.maxIterations,
options?.sampleRate,
),

View File

@@ -0,0 +1,188 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import {
PermutationBuilder as NativePermutationBuilder,
Table as NativeTable,
ShuffleOptions,
SplitHashOptions,
SplitRandomOptions,
SplitSequentialOptions,
permutationBuilder as nativePermutationBuilder,
} from "./native.js";
import { LocalTable, Table } from "./table";
/**
* A PermutationBuilder for creating data permutations with splits, shuffling, and filtering.
*
* This class provides a TypeScript wrapper around the native Rust PermutationBuilder,
* offering methods to configure data splits, shuffling, and filtering before executing
* the permutation to create a new table.
*/
export class PermutationBuilder {
private inner: NativePermutationBuilder;
/**
* @hidden
*/
constructor(inner: NativePermutationBuilder) {
this.inner = inner;
}
/**
* Configure random splits for the permutation.
*
* @param options - Configuration for random splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Split by ratios
* builder.splitRandom({ ratios: [0.7, 0.3], seed: 42 });
*
* // Split by counts
* builder.splitRandom({ counts: [1000, 500], seed: 42 });
*
* // Split with fixed size
* builder.splitRandom({ fixed: 100, seed: 42 });
* ```
*/
splitRandom(options: SplitRandomOptions): PermutationBuilder {
const newInner = this.inner.splitRandom(options);
return new PermutationBuilder(newInner);
}
/**
* Configure hash-based splits for the permutation.
*
* @param options - Configuration for hash-based splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitHash({
* columns: ["user_id"],
* splitWeights: [70, 30],
* discardWeight: 0
* });
* ```
*/
splitHash(options: SplitHashOptions): PermutationBuilder {
const newInner = this.inner.splitHash(options);
return new PermutationBuilder(newInner);
}
/**
* Configure sequential splits for the permutation.
*
* @param options - Configuration for sequential splitting
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Split by ratios
* builder.splitSequential({ ratios: [0.8, 0.2] });
*
* // Split by counts
* builder.splitSequential({ counts: [800, 200] });
*
* // Split with fixed size
* builder.splitSequential({ fixed: 1000 });
* ```
*/
splitSequential(options: SplitSequentialOptions): PermutationBuilder {
const newInner = this.inner.splitSequential(options);
return new PermutationBuilder(newInner);
}
/**
* Configure calculated splits for the permutation.
*
* @param calculation - SQL expression for calculating splits
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.splitCalculated("user_id % 3");
* ```
*/
splitCalculated(calculation: string): PermutationBuilder {
const newInner = this.inner.splitCalculated(calculation);
return new PermutationBuilder(newInner);
}
/**
* Configure shuffling for the permutation.
*
* @param options - Configuration for shuffling
* @returns A new PermutationBuilder instance
* @example
* ```ts
* // Basic shuffle
* builder.shuffle({ seed: 42 });
*
* // Shuffle with clump size
* builder.shuffle({ seed: 42, clumpSize: 10 });
* ```
*/
shuffle(options: ShuffleOptions): PermutationBuilder {
const newInner = this.inner.shuffle(options);
return new PermutationBuilder(newInner);
}
/**
* Configure filtering for the permutation.
*
* @param filter - SQL filter expression
* @returns A new PermutationBuilder instance
* @example
* ```ts
* builder.filter("age > 18 AND status = 'active'");
* ```
*/
filter(filter: string): PermutationBuilder {
const newInner = this.inner.filter(filter);
return new PermutationBuilder(newInner);
}
/**
* Execute the permutation and create the destination table.
*
* @returns A Promise that resolves to the new Table instance
* @example
* ```ts
* const permutationTable = await builder.execute();
* console.log(`Created table: ${permutationTable.name}`);
* ```
*/
async execute(): Promise<Table> {
const nativeTable: NativeTable = await this.inner.execute();
return new LocalTable(nativeTable);
}
}
/**
* Create a permutation builder for the given table.
*
* @param table - The source table to create a permutation from
* @param destTableName - The name for the destination permutation table
* @returns A PermutationBuilder instance
* @example
* ```ts
* const builder = permutationBuilder(sourceTable, "training_data")
* .splitRandom({ ratios: [0.8, 0.2], seed: 42 })
* .shuffle({ seed: 123 });
*
* const trainingTable = await builder.execute();
* ```
*/
export function permutationBuilder(
table: Table,
destTableName: string,
): PermutationBuilder {
// Extract the inner native table from the TypeScript wrapper
const localTable = table as LocalTable;
// Access inner through type assertion since it's private
const nativeBuilder = nativePermutationBuilder(
// biome-ignore lint/suspicious/noExplicitAny: need access to private variable
(localTable as any).inner,
destTableName,
);
return new PermutationBuilder(nativeBuilder);
}

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-arm64",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"os": ["darwin"],
"cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-darwin-x64",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"os": ["darwin"],
"cpu": ["x64"],
"main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-arm64-musl",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"os": ["linux"],
"cpu": ["arm64"],
"main": "lancedb.linux-arm64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-linux-x64-musl",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"os": ["linux"],
"cpu": ["x64"],
"main": "lancedb.linux-x64-musl.node",

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-arm64-msvc",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"os": [
"win32"
],

View File

@@ -1,6 +1,6 @@
{
"name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"os": ["win32"],
"cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.22.2-beta.2",
"version": "0.22.2",
"cpu": [
"x64",
"arm64"

View File

@@ -11,7 +11,7 @@
"ann"
],
"private": false,
"version": "0.22.2-beta.2",
"version": "0.22.2",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",

View File

@@ -6,6 +6,7 @@ use std::sync::Mutex;
use lancedb::index::scalar::{BTreeIndexBuilder, FtsIndexBuilder};
use lancedb::index::vector::{
IvfFlatIndexBuilder, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder,
IvfRqIndexBuilder,
};
use lancedb::index::Index as LanceDbIndex;
use napi_derive::napi;
@@ -65,6 +66,36 @@ impl Index {
})
}
#[napi(factory)]
pub fn ivf_rq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
) -> napi::Result<Self> {
let mut ivf_rq_builder = IvfRqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = parse_distance_type(distance_type)?;
ivf_rq_builder = ivf_rq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
ivf_rq_builder = ivf_rq_builder.num_partitions(num_partitions);
}
if let Some(num_bits) = num_bits {
ivf_rq_builder = ivf_rq_builder.num_bits(num_bits);
}
if let Some(max_iterations) = max_iterations {
ivf_rq_builder = ivf_rq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
ivf_rq_builder = ivf_rq_builder.sample_rate(sample_rate);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfRq(ivf_rq_builder))),
})
}
#[napi(factory)]
pub fn ivf_flat(
distance_type: Option<String>,

View File

@@ -12,6 +12,7 @@ mod header;
mod index;
mod iterator;
pub mod merge;
pub mod permutation;
mod query;
pub mod remote;
mod rerankers;

nodejs/src/permutation.rs Normal file
View File

@@ -0,0 +1,222 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::NapiErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
};
use napi_derive::napi;
#[napi(object)]
pub struct SplitRandomOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
pub seed: Option<i64>,
}
#[napi(object)]
pub struct SplitHashOptions {
pub columns: Vec<String>,
pub split_weights: Vec<i64>,
pub discard_weight: Option<i64>,
}
#[napi(object)]
pub struct SplitSequentialOptions {
pub ratios: Option<Vec<f64>>,
pub counts: Option<Vec<i64>>,
pub fixed: Option<i64>,
}
#[napi(object)]
pub struct ShuffleOptions {
pub seed: Option<i64>,
pub clump_size: Option<i64>,
}
pub struct PermutationBuilderState {
pub builder: Option<LancePermutationBuilder>,
pub dest_table_name: String,
}
#[napi]
pub struct PermutationBuilder {
state: Arc<Mutex<PermutationBuilderState>>,
}
impl PermutationBuilder {
pub fn new(builder: LancePermutationBuilder, dest_table_name: String) -> Self {
Self {
state: Arc::new(Mutex::new(PermutationBuilderState {
builder: Some(builder),
dest_table_name,
})),
}
}
}
impl PermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> napi::Result<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[napi]
impl PermutationBuilder {
/// Configure random splits
#[napi]
pub fn split_random(&self, options: SplitRandomOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
let seed = options.seed.map(|s| s as u64);
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
/// Configure hash-based splits
#[napi]
pub fn split_hash(&self, options: SplitHashOptions) -> napi::Result<Self> {
let split_weights = options
.split_weights
.into_iter()
.map(|w| w as u64)
.collect();
let discard_weight = options.discard_weight.unwrap_or(0) as u64;
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns: options.columns,
split_weights,
discard_weight,
})
})
}
/// Configure sequential splits
#[napi]
pub fn split_sequential(&self, options: SplitSequentialOptions) -> napi::Result<Self> {
// Check that exactly one split type is provided
let split_args_count = [
options.ratios.is_some(),
options.counts.is_some(),
options.fixed.is_some(),
]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(napi::Error::from_reason(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = options.ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = options.counts {
SplitSizes::Counts(counts.into_iter().map(|c| c as u64).collect())
} else if let Some(fixed) = options.fixed {
SplitSizes::Fixed(fixed as u64)
} else {
unreachable!("One of the split arguments must be provided");
};
self.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
/// Configure calculated splits
#[napi]
pub fn split_calculated(&self, calculation: String) -> napi::Result<Self> {
self.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Calculated { calculation })
})
}
/// Configure shuffling
#[napi]
pub fn shuffle(&self, options: ShuffleOptions) -> napi::Result<Self> {
let seed = options.seed.map(|s| s as u64);
let clump_size = options.clump_size.map(|c| c as u64);
self.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
/// Configure filtering
#[napi]
pub fn filter(&self, filter: String) -> napi::Result<Self> {
self.modify(|builder| builder.with_filter(filter))
}
/// Execute the permutation builder and create the table
#[napi]
pub async fn execute(&self) -> napi::Result<Table> {
let (builder, dest_table_name) = {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| napi::Error::from_reason("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
(builder, dest_table_name)
};
let table = builder.build(&dest_table_name).await.default_error()?;
Ok(Table::new(table))
}
}
/// Create a permutation builder for the given table
#[napi]
pub fn permutation_builder(
table: &crate::table::Table,
dest_table_name: String,
) -> napi::Result<PermutationBuilder> {
use lancedb::dataloader::permutation::PermutationBuilder as LancePermutationBuilder;
let inner_table = table.inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PermutationBuilder::new(inner_builder, dest_table_name))
}

View File

@@ -26,7 +26,7 @@ pub struct Table {
}
impl Table {
fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
pub(crate) fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))

View File

@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "0.25.2"
current_version = "0.25.3-beta.0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
@@ -24,6 +24,19 @@ commit = true
message = "Bump version: {current_version} → {new_version}"
commit_args = ""
# Update Cargo.lock after version bump
pre_commit_hooks = [
"""
cd python && cargo update -p lancedb-python
if git diff --quiet ../Cargo.lock; then
echo "Cargo.lock unchanged"
else
git add ../Cargo.lock
echo "Updated and staged Cargo.lock"
fi
""",
]
[tool.bumpversion.parts.pre_l]
values = ["beta", "final"]
optional_value = "final"

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb-python"
version = "0.25.2"
version = "0.25.3-beta.0"
edition.workspace = true
description = "Python bindings for LanceDB"
license.workspace = true

View File

@@ -296,3 +296,34 @@ class AlterColumnsResult:
class DropColumnsResult:
version: int
class AsyncPermutationBuilder:
def select(self, projections: Dict[str, str]) -> "AsyncPermutationBuilder": ...
def split_random(
self,
*,
ratios: Optional[List[float]] = None,
counts: Optional[List[int]] = None,
fixed: Optional[int] = None,
seed: Optional[int] = None,
) -> "AsyncPermutationBuilder": ...
def split_hash(
self, columns: List[str], split_weights: List[int], *, discard_weight: int = 0
) -> "AsyncPermutationBuilder": ...
def split_sequential(
self,
*,
ratios: Optional[List[float]] = None,
counts: Optional[List[int]] = None,
fixed: Optional[int] = None,
) -> "AsyncPermutationBuilder": ...
def split_calculated(self, calculation: str) -> "AsyncPermutationBuilder": ...
def shuffle(
self, seed: Optional[int], clump_size: Optional[int]
) -> "AsyncPermutationBuilder": ...
def filter(self, filter: str) -> "AsyncPermutationBuilder": ...
async def execute(self) -> Table: ...
def async_permutation_builder(
table: Table, dest_table_name: str
) -> AsyncPermutationBuilder: ...

View File

@@ -5,6 +5,7 @@
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
from pathlib import Path
import sys
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
@@ -40,7 +41,6 @@ import deprecation
if TYPE_CHECKING:
import pyarrow as pa
from .pydantic import LanceModel
from datetime import timedelta
from ._lancedb import Connection as LanceDbConnection
from .common import DATA, URI
@@ -452,7 +452,12 @@ class LanceDBConnection(DBConnection):
read_consistency_interval: Optional[timedelta] = None,
storage_options: Optional[Dict[str, str]] = None,
session: Optional[Session] = None,
_inner: Optional[LanceDbConnection] = None,
):
if _inner is not None:
self._conn = _inner
return
if not isinstance(uri, Path):
scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file"
@@ -461,11 +466,6 @@ class LanceDBConnection(DBConnection):
uri = Path(uri)
uri = uri.expanduser().absolute()
Path(uri).mkdir(parents=True, exist_ok=True)
self._uri = str(uri)
self._entered = False
self.read_consistency_interval = read_consistency_interval
self.storage_options = storage_options
self.session = session
if read_consistency_interval is not None:
read_consistency_interval_secs = read_consistency_interval.total_seconds()
@@ -484,10 +484,32 @@ class LanceDBConnection(DBConnection):
session,
)
# TODO: It would be nice if we didn't store self.storage_options but it is
# currently used by the LanceTable.to_lance method. This doesn't _really_
# work because some paths like LanceDBConnection.from_inner will lose the
# storage_options. Also, this class really shouldn't be holding any state
# beyond _conn.
self.storage_options = storage_options
self._conn = AsyncConnection(LOOP.run(do_connect()))
@property
def read_consistency_interval(self) -> Optional[timedelta]:
return LOOP.run(self._conn.get_read_consistency_interval())
@property
def session(self) -> Optional[Session]:
return self._conn.session
@property
def uri(self) -> str:
return self._conn.uri
@classmethod
def from_inner(cls, inner: LanceDbConnection):
return cls(None, _inner=inner)
def __repr__(self) -> str:
val = f"{self.__class__.__name__}(uri={self._uri!r}"
val = f"{self.__class__.__name__}(uri={self._conn.uri!r}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
@@ -497,6 +519,10 @@ class LanceDBConnection(DBConnection):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)
@property
def _inner(self) -> LanceDbConnection:
return self._conn._inner
@override
def list_namespaces(
self,
@@ -856,6 +882,13 @@ class AsyncConnection(object):
def uri(self) -> str:
return self._inner.uri
async def get_read_consistency_interval(self) -> Optional[timedelta]:
interval_secs = await self._inner.get_read_consistency_interval()
if interval_secs is not None:
return timedelta(seconds=interval_secs)
else:
return None
async def list_namespaces(
self,
namespace: List[str] = [],

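A minimal sketch of the new async consistency accessor above (the database path and connect options are illustrative assumptions):

import asyncio
from datetime import timedelta

import lancedb

async def main():
    conn = await lancedb.connect_async(
        "/tmp/demo-db", read_consistency_interval=timedelta(seconds=5)
    )
    # Returns a timedelta, or None when consistency is manual
    print(await conn.get_read_consistency_interval())  # 0:00:05

asyncio.run(main())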
View File

@@ -605,9 +605,53 @@ class IvfPq:
target_partition_size: Optional[int] = None
@dataclass
class IvfRq:
"""Describes an IVF RQ Index
IVF-RQ (RabitQ quantization) stores a compressed copy of each vector using
RabitQ quantization and organizes the vectors into IVF partitions. Parameters
largely mirror IVF-PQ for consistency.
Attributes
----------
distance_type: str, default "l2"
Distance metric used to train the index and for quantization.
The following distance types are available:
"l2" - Euclidean distance.
"cosine" - Cosine distance.
"dot" - Dot product.
num_partitions: int, default sqrt(num_rows)
Number of IVF partitions to create.
num_bits: int, default 1
Number of bits to encode each dimension.
max_iterations: int, default 50
Max iterations to train kmeans when computing IVF partitions.
sample_rate: int, default 256
Controls the number of training vectors: sample_rate * num_partitions.
target_partition_size: int, default 8192
Target size of each partition.
"""
distance_type: Literal["l2", "cosine", "dot"] = "l2"
num_partitions: Optional[int] = None
num_bits: int = 1
max_iterations: int = 50
sample_rate: int = 256
target_partition_size: Optional[int] = None
__all__ = [
"BTree",
"IvfPq",
"IvfRq",
"IvfFlat",
"HnswPq",
"HnswSq",

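A short sketch of configuring the new IVF_RQ index type documented above (the values are illustrative, not recommendations):

from lancedb.index import IvfRq

config = IvfRq(
    distance_type="cosine",  # "l2" (default), "cosine", or "dot"
    num_bits=1,              # bits per dimension; RQ defaults to 1
    sample_rate=256,         # training vectors = sample_rate * num_partitions
)
# `config` is then passed to Table.create_index / AsyncTable.create_index.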
View File

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from ._lancedb import async_permutation_builder
from .table import LanceTable
from .background_loop import LOOP
from typing import Optional
class PermutationBuilder:
def __init__(self, table: LanceTable, dest_table_name: str):
self._async = async_permutation_builder(table, dest_table_name)
def select(self, projections: dict[str, str]) -> "PermutationBuilder":
self._async.select(projections)
return self
def split_random(
self,
*,
ratios: Optional[list[float]] = None,
counts: Optional[list[int]] = None,
fixed: Optional[int] = None,
seed: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_random(ratios=ratios, counts=counts, fixed=fixed, seed=seed)
return self
def split_hash(
self,
columns: list[str],
split_weights: list[int],
*,
discard_weight: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_hash(columns, split_weights, discard_weight=discard_weight)
return self
def split_sequential(
self,
*,
ratios: Optional[list[float]] = None,
counts: Optional[list[int]] = None,
fixed: Optional[int] = None,
) -> "PermutationBuilder":
self._async.split_sequential(ratios=ratios, counts=counts, fixed=fixed)
return self
def split_calculated(self, calculation: str) -> "PermutationBuilder":
self._async.split_calculated(calculation)
return self
def shuffle(
self, *, seed: Optional[int] = None, clump_size: Optional[int] = None
) -> "PermutationBuilder":
self._async.shuffle(seed=seed, clump_size=clump_size)
return self
def filter(self, filter: str) -> "PermutationBuilder":
self._async.filter(filter)
return self
def execute(self) -> LanceTable:
async def do_execute():
inner_tbl = await self._async.execute()
return LanceTable.from_inner(inner_tbl)
return LOOP.run(do_execute())
def permutation_builder(table: LanceTable, dest_table_name: str) -> PermutationBuilder:
return PermutationBuilder(table, dest_table_name)
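An end-to-end sketch of the synchronous builder above (the database path, table name, and data are illustrative assumptions):

import lancedb
import pyarrow as pa
from lancedb.permutation import permutation_builder

db = lancedb.connect("/tmp/demo-db")
tbl = db.create_table("items", pa.table({"id": range(100), "value": range(100)}))

perm = (
    permutation_builder(tbl, "items_perm")
    .filter("id < 80")                          # keep a subset of the base table
    .split_random(ratios=[0.8, 0.2], seed=42)   # reproducible train/test split
    .shuffle(seed=7)                            # deterministic shuffle
    .execute()
)
print(perm.count_rows())  # 80; only row ids and split ids are stored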

View File

@@ -44,7 +44,7 @@ import numpy as np
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .index import BTree, IvfFlat, IvfPq, Bitmap, IvfRq, LabelList, HnswPq, HnswSq, FTS
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import (
@@ -74,6 +74,7 @@ from .index import lang_mapping
if TYPE_CHECKING:
from .db import LanceDBConnection
from ._lancedb import (
Table as LanceDBTable,
OptimizeStats,
@@ -88,7 +89,6 @@ if TYPE_CHECKING:
MergeResult,
UpdateResult,
)
from .db import LanceDBConnection
from .index import IndexConfig
import pandas
import PIL
@@ -1707,22 +1707,38 @@ class LanceTable(Table):
namespace: List[str] = [],
storage_options: Optional[Dict[str, str]] = None,
index_cache_size: Optional[int] = None,
_async: Optional[AsyncTable] = None,
):
self._conn = connection
self._namespace = namespace
self._table = LOOP.run(
connection._conn.open_table(
name,
namespace=namespace,
storage_options=storage_options,
index_cache_size=index_cache_size,
if _async is not None:
self._table = _async
else:
self._table = LOOP.run(
connection._conn.open_table(
name,
namespace=namespace,
storage_options=storage_options,
index_cache_size=index_cache_size,
)
)
)
@property
def name(self) -> str:
return self._table.name
@classmethod
def from_inner(cls, tbl: LanceDBTable):
from .db import LanceDBConnection
async_tbl = AsyncTable(tbl)
conn = LanceDBConnection.from_inner(tbl.database())
return cls(
conn,
async_tbl.name,
_async=async_tbl,
)
@classmethod
def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
tbl = cls(db, name, namespace=namespace, **kwargs)
@@ -1991,7 +2007,7 @@ class LanceTable(Table):
index_cache_size: Optional[int] = None,
num_bits: int = 8,
index_type: Literal[
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
"IVF_FLAT", "IVF_PQ", "IVF_RQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = "IVF_PQ",
max_iterations: int = 50,
sample_rate: int = 256,
@@ -2039,6 +2055,15 @@ class LanceTable(Table):
sample_rate=sample_rate,
target_partition_size=target_partition_size,
)
elif index_type == "IVF_RQ":
config = IvfRq(
distance_type=metric,
num_partitions=num_partitions,
num_bits=num_bits,
max_iterations=max_iterations,
sample_rate=sample_rate,
target_partition_size=target_partition_size,
)
elif index_type == "IVF_HNSW_PQ":
config = HnswPq(
distance_type=metric,
@@ -2747,6 +2772,10 @@ class LanceTable(Table):
self._table._do_merge(merge, new_data, on_bad_vectors, fill_value)
)
@property
def _inner(self) -> LanceDBTable:
return self._table._inner
@deprecation.deprecated(
deprecated_in="0.21.0",
current_version=__version__,
@@ -3330,7 +3359,7 @@ class AsyncTable:
*,
replace: Optional[bool] = None,
config: Optional[
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
Union[IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None,
wait_timeout: Optional[timedelta] = None,
name: Optional[str] = None,
@@ -3369,11 +3398,12 @@ class AsyncTable:
"""
if config is not None:
if not isinstance(
config, (IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
config,
(IvfFlat, IvfPq, IvfRq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS),
):
raise TypeError(
"config must be an instance of IvfPq, HnswPq, HnswSq, BTree,"
" Bitmap, LabelList, or FTS"
"config must be an instance of IvfPq, IvfRq, HnswPq, HnswSq, BTree,"
" Bitmap, LabelList, or FTS, but got " + str(type(config))
)
try:
await self._inner.create_index(

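A hypothetical synchronous call using the new index type (assumes a table tbl with a "vector" column; vector_column_name is the existing parameter not shown in this hunk):

tbl.create_index(
    metric="l2",
    index_type="IVF_RQ",
    num_bits=1,  # RQ encodes each dimension with 1 bit by default
    vector_column_name="vector",
)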
View File

@@ -18,10 +18,17 @@ AddMode = Literal["append", "overwrite"]
CreateMode = Literal["create", "overwrite"]
# Index type literals
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"]
VectorIndexType = Literal["IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ", "IVF_RQ"]
ScalarIndexType = Literal["BTREE", "BITMAP", "LABEL_LIST"]
IndexType = Literal[
"IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ", "FTS", "BTREE", "BITMAP", "LABEL_LIST"
"IVF_PQ",
"IVF_HNSW_PQ",
"IVF_HNSW_SQ",
"FTS",
"BTREE",
"BITMAP",
"LABEL_LIST",
"IVF_RQ",
]
# Tokenizer literals

View File

@@ -8,7 +8,17 @@ import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from lancedb.index import (
BTree,
IvfFlat,
IvfPq,
IvfRq,
Bitmap,
LabelList,
HnswPq,
HnswSq,
FTS,
)
@pytest_asyncio.fixture
@@ -195,6 +205,16 @@ async def test_create_4bit_ivfpq_index(some_table: AsyncTable):
assert stats.loss >= 0.0
@pytest.mark.asyncio
async def test_create_ivfrq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=IvfRq(num_bits=1))
indices = await some_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "IvfRq"
assert indices[0].columns == ["vector"]
assert indices[0].name == "vector_idx"
@pytest.mark.asyncio
async def test_create_hnswpq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswPq(num_partitions=10))

View File

@@ -0,0 +1,496 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pyarrow as pa
import pytest
from lancedb.permutation import permutation_builder
def test_split_random_ratios(mem_db):
"""Test random splitting with ratios."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_random(ratios=[0.3, 0.7])
.execute()
)
# Check that the table was created and has data
assert permutation_tbl.count_rows() == 100
# Check that split_id column exists and has correct values
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1}
# Check approximate split sizes (allowing for rounding)
split_0_count = split_ids.count(0)
split_1_count = split_ids.count(1)
assert 25 <= split_0_count <= 35 # ~30% ± tolerance
assert 65 <= split_1_count <= 75 # ~70% ± tolerance
def test_split_random_counts(mem_db):
"""Test random splitting with absolute counts."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_random(counts=[20, 30])
.execute()
)
# Check that we have exactly the requested counts
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert split_ids.count(0) == 20
assert split_ids.count(1) == 30
def test_split_random_fixed(mem_db):
"""Test random splitting with fixed number of splits."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation").split_random(fixed=4).execute()
)
# Check that we have 4 splits with 25 rows each
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1, 2, 3}
for split_id in range(4):
assert split_ids.count(split_id) == 25
def test_split_random_with_seed(mem_db):
"""Test that seeded random splits are reproducible."""
tbl = mem_db.create_table("test_table", pa.table({"x": range(50), "y": range(50)}))
# Create two identical permutations with same seed
perm1 = (
permutation_builder(tbl, "perm1")
.split_random(ratios=[0.6, 0.4], seed=42)
.execute()
)
perm2 = (
permutation_builder(tbl, "perm2")
.split_random(ratios=[0.6, 0.4], seed=42)
.execute()
)
# Results should be identical
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
assert data1["row_id"] == data2["row_id"]
assert data1["split_id"] == data2["split_id"]
def test_split_hash(mem_db):
"""Test hash-based splitting."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100], # Repeating pattern
"value": range(100),
}
),
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
# Should have all 100 rows (no discard)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
split_ids = data["split_id"]
assert set(split_ids) == {0, 1}
# Verify that each split has roughly 50 rows (allowing for hash variance)
split_0_count = split_ids.count(0)
split_1_count = split_ids.count(1)
assert 30 <= split_0_count <= 70 # ~50 ± 20 tolerance for hash distribution
assert 30 <= split_1_count <= 70 # ~50 ± 20 tolerance for hash distribution
# Hash splits should be deterministic - same category should go to same split
# Let's verify by creating another permutation and checking consistency
perm2 = (
permutation_builder(tbl, "test_permutation2")
.split_hash(["category"], [1, 1], discard_weight=0)
.execute()
)
data2 = perm2.search(None).to_arrow().to_pydict()
assert data["split_id"] == data2["split_id"] # Should be identical
def test_split_hash_with_discard(mem_db):
"""Test hash-based splitting with discard weight."""
tbl = mem_db.create_table(
"test_table",
pa.table({"id": range(100), "category": ["A", "B"] * 50, "value": range(100)}),
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_hash(["category"], [1, 1], discard_weight=2) # Should discard ~50%
.execute()
)
# Should have fewer than 100 rows due to discard
row_count = permutation_tbl.count_rows()
assert row_count < 100
assert row_count > 0 # But not empty
def test_split_sequential(mem_db):
"""Test sequential splitting."""
tbl = mem_db.create_table(
"test_table", pa.table({"x": range(100), "y": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_sequential(counts=[30, 40])
.execute()
)
assert permutation_tbl.count_rows() == 70
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
split_ids = data["split_id"]
# Sequential should maintain order
assert row_ids == sorted(row_ids)
# First 30 should be split 0, next 40 should be split 1
assert split_ids[:30] == [0] * 30
assert split_ids[30:] == [1] * 40
def test_split_calculated(mem_db):
"""Test calculated splitting."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_calculated("id % 3") # Split based on id modulo 3
.execute()
)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
split_ids = data["split_id"]
# Verify the calculation: each row's split_id should equal row_id % 3
for row_id, split_id in zip(row_ids, split_ids):
assert split_id == row_id % 3
def test_split_error_cases(mem_db):
"""Test error handling for invalid split parameters."""
tbl = mem_db.create_table("test_table", pa.table({"x": range(10), "y": range(10)}))
# Test split_random with no parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error1").split_random().execute()
# Test split_random with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error2").split_random(
ratios=[0.5, 0.5], counts=[5, 5]
).execute()
# Test split_sequential with no parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error3").split_sequential().execute()
# Test split_sequential with multiple parameters
with pytest.raises(Exception):
permutation_builder(tbl, "error4").split_sequential(
ratios=[0.5, 0.5], fixed=2
).execute()
def test_shuffle_no_seed(mem_db):
"""Test shuffling without a seed."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
# Create a permutation with shuffling (no seed)
permutation_tbl = permutation_builder(tbl, "test_permutation").shuffle().execute()
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# Row IDs should not be in sequential order due to shuffling
# This is probabilistic but with 100 rows, it's extremely unlikely they'd stay
# in order
assert row_ids != list(range(100))
def test_shuffle_with_seed(mem_db):
"""Test that shuffling with a seed is reproducible."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create two identical permutations with same shuffle seed
perm1 = permutation_builder(tbl, "perm1").shuffle(seed=42).execute()
perm2 = permutation_builder(tbl, "perm2").shuffle(seed=42).execute()
# Results should be identical due to same seed
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
assert data1["row_id"] == data2["row_id"]
assert data1["split_id"] == data2["split_id"]
def test_shuffle_with_clump_size(mem_db):
"""Test shuffling with clump size."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100)})
)
# Create a permutation with shuffling using clumps
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.shuffle(clump_size=10) # 10-row clumps
.execute()
)
assert permutation_tbl.count_rows() == 100
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
for i in range(10):
start = row_ids[i * 10]
assert row_ids[i * 10 : (i + 1) * 10] == list(range(start, start + 10))
def test_shuffle_different_seeds(mem_db):
"""Test that different seeds produce different shuffle orders."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create two permutations with different shuffle seeds
perm1 = (
permutation_builder(tbl, "perm1")
.split_random(fixed=2)
.shuffle(seed=42)
.execute()
)
perm2 = (
permutation_builder(tbl, "perm2")
.split_random(fixed=2)
.shuffle(seed=123)
.execute()
)
# Results should be different due to different seeds
data1 = perm1.search(None).to_arrow().to_pydict()
data2 = perm2.search(None).to_arrow().to_pydict()
# Row order should be different
assert data1["row_id"] != data2["row_id"]
def test_shuffle_combined_with_splits(mem_db):
"""Test shuffling combined with different split strategies."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100],
"value": range(100),
}
),
)
# Test shuffle with random splits
perm_random = (
permutation_builder(tbl, "perm_random")
.split_random(ratios=[0.6, 0.4], seed=42)
.shuffle(seed=123, clump_size=None)
.execute()
)
# Test shuffle with hash splits
perm_hash = (
permutation_builder(tbl, "perm_hash")
.split_hash(["category"], [1, 1], discard_weight=0)
.shuffle(seed=456, clump_size=5)
.execute()
)
# Test shuffle with sequential splits
perm_sequential = (
permutation_builder(tbl, "perm_sequential")
.split_sequential(counts=[40, 35])
.shuffle(seed=789, clump_size=None)
.execute()
)
# Verify all permutations work and have expected properties
assert perm_random.count_rows() == 100
assert perm_hash.count_rows() == 100
assert perm_sequential.count_rows() == 75
# Verify shuffle affected the order
data_random = perm_random.search(None).to_arrow().to_pydict()
data_sequential = perm_sequential.search(None).to_arrow().to_pydict()
assert data_random["row_id"] != list(range(100))
assert data_sequential["row_id"] != list(range(75))
def test_no_shuffle_maintains_order(mem_db):
"""Test that not calling shuffle maintains the original order."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(50), "value": range(50)})
)
# Create permutation without shuffle (should maintain some order)
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.split_sequential(counts=[25, 25]) # Sequential maintains order
.execute()
)
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# With sequential splits and no shuffle, should maintain order
assert row_ids == list(range(50))
def test_filter_basic(mem_db):
"""Test basic filtering functionality."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(100), "value": range(100, 200)})
)
# Filter to only include rows where id < 50
permutation_tbl = (
permutation_builder(tbl, "test_permutation").filter("id < 50").execute()
)
assert permutation_tbl.count_rows() == 50
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
# All row_ids should be less than 50
assert all(row_id < 50 for row_id in row_ids)
def test_filter_with_splits(mem_db):
"""Test filtering combined with split strategies."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C"] * 34)[:100],
"value": range(100),
}
),
)
# Filter to only category A and B, then split
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.filter("category IN ('A', 'B')")
.split_random(ratios=[0.5, 0.5])
.execute()
)
# Should have fewer than 100 rows due to filtering
row_count = permutation_tbl.count_rows()
assert row_count == 67
data = permutation_tbl.search(None).to_arrow().to_pydict()
categories = data["category"]
# All categories should be A or B
assert all(cat in ["A", "B"] for cat in categories)
def test_filter_with_shuffle(mem_db):
"""Test filtering combined with shuffling."""
tbl = mem_db.create_table(
"test_table",
pa.table(
{
"id": range(100),
"category": (["A", "B", "C", "D"] * 25)[:100],
"value": range(100),
}
),
)
# Filter and shuffle
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.filter("category IN ('A', 'C')")
.shuffle(seed=42)
.execute()
)
row_count = permutation_tbl.count_rows()
assert row_count == 50 # Should have 50 rows (A and C categories)
data = permutation_tbl.search(None).to_arrow().to_pydict()
row_ids = data["row_id"]
assert row_ids != sorted(row_ids)
def test_filter_empty_result(mem_db):
"""Test filtering that results in empty set."""
tbl = mem_db.create_table(
"test_table", pa.table({"id": range(10), "value": range(10)})
)
# Filter that matches nothing
permutation_tbl = (
permutation_builder(tbl, "test_permutation")
.filter("value > 100") # No values > 100 in our data
.execute()
)
assert permutation_tbl.count_rows() == 0

View File

@@ -4,7 +4,10 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
use lancedb::{
connection::Connection as LanceConnection,
database::{CreateTableMode, ReadConsistency},
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
@@ -23,7 +26,7 @@ impl Connection {
Self { inner: Some(inner) }
}
fn get_inner(&self) -> PyResult<&LanceConnection> {
pub(crate) fn get_inner(&self) -> PyResult<&LanceConnection> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err("Connection is closed"))
@@ -63,6 +66,18 @@ impl Connection {
self.get_inner().map(|inner| inner.uri().to_string())
}
#[pyo3(signature = ())]
pub fn get_read_consistency_interval(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.get_inner()?.clone();
future_into_py(self_.py(), async move {
Ok(match inner.read_consistency().await.infer_error()? {
ReadConsistency::Manual => None,
ReadConsistency::Eventual(duration) => Some(duration.as_secs_f64()),
ReadConsistency::Strong => Some(0.0_f64),
})
})
}
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
pub fn table_names(
self_: PyRef<'_, Self>,

View File

@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use lancedb::index::vector::IvfFlatIndexBuilder;
use lancedb::index::vector::{IvfFlatIndexBuilder, IvfRqIndexBuilder};
use lancedb::index::{
scalar::{BTreeIndexBuilder, FtsIndexBuilder},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
@@ -87,6 +87,22 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
}
Ok(LanceDbIndex::IvfPq(ivf_pq_builder))
},
"IvfRq" => {
let params = source.extract::<IvfRqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut ivf_rq_builder = IvfRqIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate)
.num_bits(params.num_bits);
if let Some(num_partitions) = params.num_partitions {
ivf_rq_builder = ivf_rq_builder.num_partitions(num_partitions);
}
if let Some(target_partition_size) = params.target_partition_size {
ivf_rq_builder = ivf_rq_builder.target_partition_size(target_partition_size);
}
Ok(LanceDbIndex::IvfRq(ivf_rq_builder))
},
"HnswPq" => {
let params = source.extract::<IvfHnswPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
@@ -170,6 +186,16 @@ struct IvfPqParams {
target_partition_size: Option<u32>,
}
#[derive(FromPyObject)]
struct IvfRqParams {
distance_type: String,
num_partitions: Option<u32>,
num_bits: u32,
max_iterations: u32,
sample_rate: u32,
target_partition_size: Option<u32>,
}
#[derive(FromPyObject)]
struct IvfHnswPqParams {
distance_type: String,

View File

@@ -5,6 +5,7 @@ use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::IndexConfig;
use permutation::PyAsyncPermutationBuilder;
use pyo3::{
pymodule,
types::{PyModule, PyModuleMethods},
@@ -22,6 +23,7 @@ pub mod connection;
pub mod error;
pub mod header;
pub mod index;
pub mod permutation;
pub mod query;
pub mod session;
pub mod table;
@@ -49,7 +51,9 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<DeleteResult>()?;
m.add_class::<DropColumnsResult>()?;
m.add_class::<UpdateResult>()?;
m.add_class::<PyAsyncPermutationBuilder>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

python/src/permutation.rs Normal file
View File

@@ -0,0 +1,177 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use crate::{error::PythonErrorExt, table::Table};
use lancedb::dataloader::{
permutation::{PermutationBuilder as LancePermutationBuilder, ShuffleStrategy},
split::{SplitSizes, SplitStrategy},
};
use pyo3::{
exceptions::PyRuntimeError, pyclass, pymethods, types::PyAnyMethods, Bound, PyAny, PyRefMut,
PyResult,
};
use pyo3_async_runtimes::tokio::future_into_py;
/// Create a permutation builder for the given table
#[pyo3::pyfunction]
pub fn async_permutation_builder(
table: Bound<'_, PyAny>,
dest_table_name: String,
) -> PyResult<PyAsyncPermutationBuilder> {
let table = table.getattr("_inner")?.downcast_into::<Table>()?;
let inner_table = table.borrow().inner_ref()?.clone();
let inner_builder = LancePermutationBuilder::new(inner_table);
Ok(PyAsyncPermutationBuilder {
state: Arc::new(Mutex::new(PyAsyncPermutationBuilderState {
builder: Some(inner_builder),
dest_table_name,
})),
})
}
struct PyAsyncPermutationBuilderState {
builder: Option<LancePermutationBuilder>,
dest_table_name: String,
}
#[pyclass(name = "AsyncPermutationBuilder")]
pub struct PyAsyncPermutationBuilder {
state: Arc<Mutex<PyAsyncPermutationBuilderState>>,
}
impl PyAsyncPermutationBuilder {
fn modify(
&self,
func: impl FnOnce(LancePermutationBuilder) -> LancePermutationBuilder,
) -> PyResult<Self> {
let mut state = self.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
state.builder = Some(func(builder));
Ok(Self {
state: self.state.clone(),
})
}
}
#[pymethods]
impl PyAsyncPermutationBuilder {
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None, seed=None))]
pub fn split_random(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
seed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Random { seed, sizes }))
}
#[pyo3(signature = (columns, split_weights, *, discard_weight=0))]
pub fn split_hash(
slf: PyRefMut<'_, Self>,
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_split_strategy(SplitStrategy::Hash {
columns,
split_weights,
discard_weight,
})
})
}
#[pyo3(signature = (*, ratios=None, counts=None, fixed=None))]
pub fn split_sequential(
slf: PyRefMut<'_, Self>,
ratios: Option<Vec<f64>>,
counts: Option<Vec<u64>>,
fixed: Option<u64>,
) -> PyResult<Self> {
// Check that exactly one split type is provided
let split_args_count = [ratios.is_some(), counts.is_some(), fixed.is_some()]
.iter()
.filter(|&&x| x)
.count();
if split_args_count != 1 {
return Err(pyo3::exceptions::PyValueError::new_err(
"Exactly one of 'ratios', 'counts', or 'fixed' must be provided",
));
}
let sizes = if let Some(ratios) = ratios {
SplitSizes::Percentages(ratios)
} else if let Some(counts) = counts {
SplitSizes::Counts(counts)
} else if let Some(fixed) = fixed {
SplitSizes::Fixed(fixed)
} else {
unreachable!("One of the split arguments must be provided");
};
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Sequential { sizes }))
}
pub fn split_calculated(slf: PyRefMut<'_, Self>, calculation: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_split_strategy(SplitStrategy::Calculated { calculation }))
}
pub fn shuffle(
slf: PyRefMut<'_, Self>,
seed: Option<u64>,
clump_size: Option<u64>,
) -> PyResult<Self> {
slf.modify(|builder| {
builder.with_shuffle_strategy(ShuffleStrategy::Random { seed, clump_size })
})
}
pub fn filter(slf: PyRefMut<'_, Self>, filter: String) -> PyResult<Self> {
slf.modify(|builder| builder.with_filter(filter))
}
pub fn execute(slf: PyRefMut<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let mut state = slf.state.lock().unwrap();
let builder = state
.builder
.take()
.ok_or_else(|| PyRuntimeError::new_err("Builder already consumed"))?;
let dest_table_name = std::mem::take(&mut state.dest_table_name);
future_into_py(slf.py(), async move {
let table = builder.build(&dest_table_name).await.infer_error()?;
Ok(Table::new(table))
})
}
}
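One Python-visible consequence of the take()-based builder state above, sketched under the assumption that tbl is an open table: each builder can be executed at most once.

b = permutation_builder(tbl, "once")
b.execute()
# b.execute()  # raises RuntimeError: Builder already consumed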

View File

@@ -3,6 +3,7 @@
use std::{collections::HashMap, sync::Arc};
use crate::{
connection::Connection,
error::PythonErrorExt,
index::{extract_index_params, IndexConfig},
query::{Query, TakeQuery},
@@ -249,7 +250,7 @@ impl Table {
}
impl Table {
fn inner_ref(&self) -> PyResult<&LanceDbTable> {
pub(crate) fn inner_ref(&self) -> PyResult<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| PyRuntimeError::new_err(format!("Table {} is closed", self.name)))
@@ -272,6 +273,13 @@ impl Table {
self.inner.take();
}
pub fn database(&self) -> PyResult<Connection> {
let inner = self.inner_ref()?.clone();
let inner_connection =
lancedb::Connection::new(inner.database().clone(), inner.embedding_registry().clone());
Ok(Connection::new(inner_connection))
}
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {

View File

@@ -1,2 +1,2 @@
[toolchain]
channel = "1.86.0"
channel = "1.90.0"

View File

@@ -1,6 +1,6 @@
[package]
name = "lancedb"
version = "0.22.2-beta.2"
version = "0.22.2"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true
@@ -11,6 +11,7 @@ rust-version.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-data = { workspace = true }
@@ -24,12 +25,16 @@ datafusion-common.workspace = true
datafusion-execution.workspace = true
datafusion-expr.workspace = true
datafusion-physical-plan.workspace = true
datafusion.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lazy_static.workspace = true
lance = { workspace = true }
lance-core = { workspace = true }
lance-datafusion.workspace = true
lance-datagen = { workspace = true }
lance-file = { workspace = true }
lance-io = { workspace = true }
lance-index = { workspace = true }
lance-table = { workspace = true }
@@ -46,11 +51,13 @@ bytes = "1"
futures.workspace = true
num-traits.workspace = true
url.workspace = true
rand.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" }
tempfile = "3.5.0"
aws-sdk-bedrockruntime = { version = "1.27.0", optional = true }
# For remote feature
reqwest = { version = "0.12.0", default-features = false, features = [
@@ -61,9 +68,8 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"macos-system-configuration",
"stream",
], optional = true }
rand = { version = "0.9", features = ["small_rng"], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
uuid = { version = "1.7.0", features = ["v4"], optional = true }
uuid = { version = "1.7.0", features = ["v4"] }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
hf-hub = { version = "0.4.1", optional = true, default-features = false, features = [
@@ -84,7 +90,6 @@ bytemuck_derive.workspace = true
[dev-dependencies]
anyhow = "1"
tempfile = "3.5.0"
rand = { version = "0.9", features = ["small_rng"] }
random_word = { version = "0.4.3", features = ["en"] }
uuid = { version = "1.7.0", features = ["v4"] }
walkdir = "2"
@@ -96,6 +101,7 @@ aws-smithy-runtime = { version = "1.9.1" }
datafusion.workspace = true
http-body = "1" # Matching reqwest
rstest = "0.23.0"
test-log = "0.2"
[features]
@@ -105,7 +111,7 @@ oss = ["lance/oss", "lance-io/oss"]
gcs = ["lance/gcp", "lance-io/gcp"]
azure = ["lance/azure", "lance-io/azure"]
dynamodb = ["lance/dynamodb", "aws"]
remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
remote = ["dep:reqwest", "dep:http"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]

View File

@@ -7,6 +7,7 @@ pub use arrow_schema;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use futures::{Stream, StreamExt, TryStreamExt};
use lance_datagen::{BatchCount, BatchGeneratorBuilder, RowCount};
#[cfg(feature = "polars")]
use {crate::polars_arrow_convertors, polars::frame::ArrowChunk, polars::prelude::DataFrame};
@@ -161,6 +162,26 @@ impl IntoArrowStream for datafusion_physical_plan::SendableRecordBatchStream {
}
}
pub trait LanceDbDatagenExt {
fn into_ldb_stream(
self,
batch_size: RowCount,
num_batches: BatchCount,
) -> SendableRecordBatchStream;
}
impl LanceDbDatagenExt for BatchGeneratorBuilder {
fn into_ldb_stream(
self,
batch_size: RowCount,
num_batches: BatchCount,
) -> SendableRecordBatchStream {
let (stream, schema) = self.into_reader_stream(batch_size, num_batches);
let stream = stream.map_err(|err| Error::Arrow { source: err });
Box::pin(SimpleRecordBatchStream::new(stream, schema))
}
}
#[cfg(feature = "polars")]
/// An iterator of record batches formed from a Polars DataFrame.
pub struct PolarsDataFrameRecordBatchReader {

View File

@@ -19,7 +19,7 @@ use crate::database::listing::{
use crate::database::{
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
OpenTableRequest, TableNamesRequest,
OpenTableRequest, ReadConsistency, TableNamesRequest,
};
use crate::embeddings::{
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
@@ -152,6 +152,7 @@ impl CreateTableBuilder<true> {
let request = self.into_request()?;
Ok(Table::new_with_embedding_registry(
parent.create_table(request).await?,
parent,
embedding_registry,
))
}
@@ -211,9 +212,9 @@ impl CreateTableBuilder<false> {
/// Execute the create table operation
pub async fn execute(self) -> Result<Table> {
Ok(Table::new(
self.parent.clone().create_table(self.request).await?,
))
let parent = self.parent.clone();
let table = parent.create_table(self.request).await?;
Ok(Table::new(table, parent))
}
}
@@ -462,8 +463,10 @@ impl OpenTableBuilder {
/// Open the table
pub async fn execute(self) -> Result<Table> {
let table = self.parent.open_table(self.request).await?;
Ok(Table::new_with_embedding_registry(
self.parent.clone().open_table(self.request).await?,
table,
self.parent,
self.embedding_registry,
))
}
@@ -519,16 +522,15 @@ impl CloneTableBuilder {
/// Execute the clone operation
pub async fn execute(self) -> Result<Table> {
Ok(Table::new(
self.parent.clone().clone_table(self.request).await?,
))
let parent = self.parent.clone();
let table = parent.clone_table(self.request).await?;
Ok(Table::new(table, parent))
}
}
/// A connection to LanceDB
#[derive(Clone)]
pub struct Connection {
uri: String,
internal: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -540,9 +542,19 @@ impl std::fmt::Display for Connection {
}
impl Connection {
pub fn new(
internal: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
internal,
embedding_registry,
}
}
/// Get the URI of the connection
pub fn uri(&self) -> &str {
self.uri.as_str()
self.internal.uri()
}
/// Get access to the underlying database
@@ -675,6 +687,11 @@ impl Connection {
.await
}
/// Get the read consistency of the connection
pub async fn read_consistency(&self) -> Result<ReadConsistency> {
self.internal.read_consistency().await
}
/// Drop a table in the database.
///
/// # Arguments
@@ -973,7 +990,6 @@ impl ConnectBuilder {
)?);
Ok(Connection {
internal,
uri: self.request.uri,
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -996,7 +1012,6 @@ impl ConnectBuilder {
let internal = Arc::new(ListingDatabase::connect_with_options(&self.request).await?);
Ok(Connection {
internal,
uri: self.request.uri,
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1104,7 +1119,6 @@ impl ConnectNamespaceBuilder {
Ok(Connection {
internal,
uri: format!("namespace://{}", self.ns_impl),
embedding_registry: self
.embedding_registry
.unwrap_or_else(|| Arc::new(MemoryRegistry::new())),
@@ -1139,7 +1153,6 @@ mod test_utils {
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
internal,
uri: "db://test".to_string(),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -1156,7 +1169,6 @@ mod test_utils {
));
Self {
internal,
uri: "db://test".to_string(),
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -1187,7 +1199,7 @@ mod tests {
#[tokio::test]
async fn test_connect() {
let tc = new_test_connection().await.unwrap();
assert_eq!(tc.connection.uri, tc.uri);
assert_eq!(tc.connection.uri(), tc.uri);
}
#[cfg(not(windows))]
@@ -1208,7 +1220,7 @@ mod tests {
.await
.unwrap();
assert_eq!(db.uri, relative_uri.to_str().unwrap().to_string());
assert_eq!(db.uri(), relative_uri.to_str().unwrap().to_string());
}
#[tokio::test]

View File

@@ -16,6 +16,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
@@ -213,6 +214,20 @@ impl CloneTableRequest {
}
}
/// How long until a change is reflected from one Table instance to another
///
/// Tables are always internally consistent. If a write method is called on
/// a table instance it will be immediately visible in that same table instance.
pub enum ReadConsistency {
/// Changes will not be automatically propagated until the checkout_latest
/// method is called on the target table
Manual,
/// Changes will be propagated automatically within the given duration
Eventual(Duration),
/// Changes are immediately visible in target tables
Strong,
}
/// The `Database` trait defines the interface for database implementations.
///
/// A database is responsible for managing tables and their metadata.
@@ -220,6 +235,10 @@ impl CloneTableRequest {
pub trait Database:
Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
{
/// Get the URI of the database
fn uri(&self) -> &str;
/// Get the read consistency of the database
async fn read_consistency(&self) -> Result<ReadConsistency>;
/// List immediate child namespace names in the given namespace
async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>>;
/// Create a new namespace

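A hedged sketch of how the three ReadConsistency variants surface through the Python bindings shown earlier (paths are illustrative):

from datetime import timedelta

import lancedb

manual = lancedb.connect("/tmp/db-manual")  # no interval configured
strong = lancedb.connect("/tmp/db-strong", read_consistency_interval=timedelta(0))
eventual = lancedb.connect("/tmp/db-ev", read_consistency_interval=timedelta(seconds=5))

assert manual.read_consistency_interval is None                    # Manual
assert strong.read_consistency_interval == timedelta(0)            # Strong
assert eventual.read_consistency_interval == timedelta(seconds=5)  # Eventual(5s)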
View File

@@ -17,6 +17,7 @@ use object_store::local::LocalFileSystem;
use snafu::ResultExt;
use crate::connection::ConnectRequest;
use crate::database::ReadConsistency;
use crate::error::{CreateDirSnafu, Error, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::NativeTable;
@@ -598,6 +599,22 @@ impl Database for ListingDatabase {
Ok(Vec::new())
}
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
if let Some(read_consistency_interval) = self.read_consistency_interval {
if read_consistency_interval.is_zero() {
Ok(ReadConsistency::Strong)
} else {
Ok(ReadConsistency::Eventual(read_consistency_interval))
}
} else {
Ok(ReadConsistency::Manual)
}
}
async fn create_namespace(&self, _request: CreateNamespaceRequest) -> Result<()> {
Err(Error::NotSupported {
message: "Namespace operations are not supported for listing database".into(),
@@ -1249,7 +1266,8 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone());
let db = Arc::new(db);
let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch2)],
@@ -1320,7 +1338,8 @@ mod tests {
.unwrap();
// Create a tag for the current version
let source_table_obj = Table::new(source_table.clone());
let db = Arc::new(db);
let source_table_obj = Table::new(source_table.clone(), db.clone());
let mut tags = source_table_obj.tags().await.unwrap();
tags.create("v1.0", source_table.version().await.unwrap())
.await
@@ -1336,7 +1355,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone());
let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch2)],
@@ -1432,7 +1451,8 @@ mod tests {
)
.unwrap();
let cloned_table_obj = Table::new(cloned_table.clone());
let db = Arc::new(db);
let cloned_table_obj = Table::new(cloned_table.clone(), db.clone());
cloned_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch_clone)],
@@ -1452,7 +1472,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone());
let source_table_obj = Table::new(source_table.clone(), db);
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch_source)],
@@ -1495,6 +1515,7 @@ mod tests {
.unwrap();
// Add more data to create new versions
let db = Arc::new(db);
for i in 0..3 {
let batch = RecordBatch::try_new(
schema.clone(),
@@ -1502,7 +1523,7 @@ mod tests {
)
.unwrap();
let source_table_obj = Table::new(source_table.clone());
let source_table_obj = Table::new(source_table.clone(), db.clone());
source_table_obj
.add(Box::new(arrow_array::RecordBatchIterator::new(
vec![Ok(batch)],

View File

@@ -16,9 +16,9 @@ use lance_namespace::{
LanceNamespace,
};
use crate::connection::ConnectRequest;
use crate::database::listing::ListingDatabase;
use crate::error::{Error, Result};
use crate::{connection::ConnectRequest, database::ReadConsistency};
use super::{
BaseTable, CloneTableRequest, CreateNamespaceRequest as DbCreateNamespaceRequest,
@@ -36,6 +36,8 @@ pub struct LanceNamespaceDatabase {
read_consistency_interval: Option<std::time::Duration>,
// Optional session for object stores and caching
session: Option<Arc<lance::session::Session>>,
// database URI
uri: String,
}
impl LanceNamespaceDatabase {
@@ -57,6 +59,7 @@ impl LanceNamespaceDatabase {
storage_options,
read_consistency_interval,
session,
uri: format!("namespace://{}", ns_impl),
})
}
@@ -130,6 +133,22 @@ impl std::fmt::Display for LanceNamespaceDatabase {
#[async_trait]
impl Database for LanceNamespaceDatabase {
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
if let Some(read_consistency_interval) = self.read_consistency_interval {
if read_consistency_interval.is_zero() {
Ok(ReadConsistency::Strong)
} else {
Ok(ReadConsistency::Eventual(read_consistency_interval))
}
} else {
Ok(ReadConsistency::Manual)
}
}
async fn list_namespaces(&self, request: DbListNamespacesRequest) -> Result<Vec<String>> {
let ns_request = ListNamespacesRequest {
id: if request.namespace.is_empty() {

View File

@@ -0,0 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
pub mod permutation;
pub mod shuffle;
pub mod split;
pub mod util;

View File

@@ -0,0 +1,294 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Contains the [PermutationBuilder] to create a permutation "view" of an existing table.
//!
//! A permutation view can apply a filter, divide the data into splits, and shuffle the data.
//! The permutation table only stores the split ids and row ids. It is not a materialized copy of
//! the underlying data and can be very lightweight.
//!
//! Building a permutation table should be fairly quick and memory efficient, even for billions or
//! trillions of rows.
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_execution::{disk_manager::DiskManagerBuilder, runtime_env::RuntimeEnvBuilder};
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_datafusion::exec::SessionContextExt;
use crate::{
arrow::{SendableRecordBatchStream, SendableRecordBatchStreamExt, SimpleRecordBatchStream},
dataloader::{
shuffle::{Shuffler, ShufflerConfig},
split::{SplitStrategy, Splitter, SPLIT_ID_COLUMN},
util::{rename_column, TemporaryDirectory},
},
query::{ExecutableQuery, QueryBase},
Connection, Error, Result, Table,
};
/// Configuration for creating a permutation table
#[derive(Debug, Default)]
pub struct PermutationConfig {
/// Splitting configuration
pub split_strategy: SplitStrategy,
/// Shuffle strategy
pub shuffle_strategy: ShuffleStrategy,
/// Optional filter to apply to the base table
pub filter: Option<String>,
/// Directory to use for temporary files
pub temp_dir: TemporaryDirectory,
}
/// Strategy for shuffling the data.
#[derive(Debug, Clone)]
pub enum ShuffleStrategy {
/// The data is randomly shuffled
///
/// A seed can be provided to make the shuffle deterministic.
///
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
/// This decreases the overall randomization but can improve I/O performance when reading from
/// cloud storage.
///
/// For example, a clump size of 16 means we will shuffle blocks of 16 contiguous rows. This
/// means 16x fewer IOPS, but those 16 rows will always be close together, which can influence
/// the performance of the model. Note: shuffling within clumps can still be done at read time,
/// but this only provides a local shuffle, not a global one.
Random {
seed: Option<u64>,
clump_size: Option<u64>,
},
/// The data is not shuffled
///
/// This is useful for debugging and testing.
None,
}
impl Default for ShuffleStrategy {
fn default() -> Self {
Self::None
}
}
/// Builder for creating a permutation table.
///
/// A permutation table is a table that stores split assignments and a shuffled order of rows. It
/// can be used as a lightweight, persistent view for loading the base table's data in a
/// predetermined order.
pub struct PermutationBuilder {
config: PermutationConfig,
base_table: Table,
}
impl PermutationBuilder {
pub fn new(base_table: Table) -> Self {
Self {
config: PermutationConfig::default(),
base_table,
}
}
/// Configures the strategy for assigning rows to splits.
///
/// For example, it is common to create a test/train split of the data. Splits can also be used
/// to limit the number of rows. For example, to only use 10% of the data in a permutation you can
/// create a single split with 10% of the data.
///
/// Splits are _not_ required for parallel processing. A single split can be loaded in parallel across
/// multiple processes and multiple nodes.
///
/// The default is a single split that contains all rows.
pub fn with_split_strategy(mut self, split_strategy: SplitStrategy) -> Self {
self.config.split_strategy = split_strategy;
self
}
/// Configures the strategy for shuffling the data.
///
/// The default is `ShuffleStrategy::None`, i.e. the data is not shuffled.
pub fn with_shuffle_strategy(mut self, shuffle_strategy: ShuffleStrategy) -> Self {
self.config.shuffle_strategy = shuffle_strategy;
self
}
/// Configures a filter to apply to the base table.
///
/// Only rows matching the filter will be included in the permutation.
pub fn with_filter(mut self, filter: String) -> Self {
self.config.filter = Some(filter);
self
}
/// Configures the directory to use for temporary files.
///
/// The default is to use the operating system's default temporary directory.
pub fn with_temp_dir(mut self, temp_dir: TemporaryDirectory) -> Self {
self.config.temp_dir = temp_dir;
self
}
async fn sort_by_split_id(
&self,
data: SendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
let ctx = SessionContext::new_with_config_rt(
SessionConfig::default(),
RuntimeEnvBuilder::new()
.with_memory_limit(100 * 1024 * 1024, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default()
.with_mode(self.config.temp_dir.to_disk_manager_mode()),
)
.build_arc()
.unwrap(),
);
let df = ctx
.read_one_shot(data.into_df_stream())
.map_err(|e| Error::Other {
message: format!("Failed to setup sort by split id: {}", e),
source: Some(e.into()),
})?;
let df_stream = df
.sort_by(vec![col(SPLIT_ID_COLUMN)])
.map_err(|e| Error::Other {
message: format!("Failed to plan sort by split id: {}", e),
source: Some(e.into()),
})?
.execute_stream()
.await
.map_err(|e| Error::Other {
message: format!("Failed to sort by split id: {}", e),
source: Some(e.into()),
})?;
let schema = df_stream.schema();
let stream = df_stream.map_err(|e| Error::Other {
message: format!("Failed to execute sort by split id: {}", e),
source: Some(e.into()),
});
Ok(Box::pin(SimpleRecordBatchStream { schema, stream }))
}
/// Builds the permutation table and stores it in the given database.
pub async fn build(self, dest_table_name: &str) -> Result<Table> {
// First pass, apply filter and load row ids
let mut rows = self.base_table.query().with_row_id();
if let Some(filter) = &self.config.filter {
rows = rows.only_if(filter);
}
let splitter = Splitter::new(
self.config.temp_dir.clone(),
self.config.split_strategy.clone(),
);
let mut needs_sort = !splitter.orders_by_split_id();
// Might need to load additional columns to calculate splits (e.g. hash columns or calculated
// split id)
rows = splitter.project(rows);
let num_rows = self
.base_table
.count_rows(self.config.filter.clone())
.await? as u64;
// Apply splits
let rows = rows.execute().await?;
let split_data = splitter.apply(rows, num_rows).await?;
// Shuffle data if requested
let shuffled = match self.config.shuffle_strategy {
ShuffleStrategy::None => split_data,
ShuffleStrategy::Random { seed, clump_size } => {
let shuffler = Shuffler::new(ShufflerConfig {
seed,
clump_size,
temp_dir: self.config.temp_dir.clone(),
max_rows_per_file: 10 * 1024 * 1024,
});
shuffler.shuffle(split_data, num_rows).await?
}
};
// We want the final permutation to be sorted by the split id. If we shuffled or if
// the split was not assigned sequentially then we need to sort the data.
needs_sort |= !matches!(self.config.shuffle_strategy, ShuffleStrategy::None);
let sorted = if needs_sort {
self.sort_by_split_id(shuffled).await?
} else {
shuffled
};
// Rename _rowid to row_id
let renamed = rename_column(sorted, "_rowid", "row_id")?;
// Create permutation table
let conn = Connection::new(
self.base_table.database().clone(),
self.base_table.embedding_registry().clone(),
);
conn.create_table_streaming(dest_table_name, renamed)
.execute()
.await
}
}
#[cfg(test)]
mod tests {
use arrow::datatypes::Int32Type;
use lance_datagen::{BatchCount, RowCount};
use crate::{arrow::LanceDbDatagenExt, connect, dataloader::split::SplitSizes};
use super::*;
#[tokio::test]
async fn test_permutation_builder() {
let temp_dir = tempfile::tempdir().unwrap();
let db = connect(temp_dir.path().to_str().unwrap())
.execute()
.await
.unwrap();
let initial_data = lance_datagen::gen_batch()
.col("some_value", lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(100), BatchCount::from(10));
let data_table = db
.create_table_streaming("mytbl", initial_data)
.execute()
.await
.unwrap();
let permutation_table = PermutationBuilder::new(data_table)
.with_filter("some_value > 57".to_string())
.with_split_strategy(SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.05, 0.30]),
})
.build("permutation")
.await
.unwrap();
// Potentially brittle seed-dependent values below
assert_eq!(permutation_table.count_rows(None).await.unwrap(), 330);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 0".to_string()))
.await
.unwrap(),
47
);
assert_eq!(
permutation_table
.count_rows(Some("split_id = 1".to_string()))
.await
.unwrap(),
283
);
}
}
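The clump behavior described in ShuffleStrategy::Random above can also be observed from Python; a sketch mirroring the test suite (assumes tbl is a 100-row table as in the earlier sketch):

perm = permutation_builder(tbl, "clumped").shuffle(clump_size=10).execute()
ids = perm.search(None).to_arrow().to_pydict()["row_id"]
for i in range(10):
    block = ids[i * 10 : (i + 1) * 10]
    assert block == list(range(block[0], block[0] + 10))  # each clump stays contiguous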

View File

@@ -0,0 +1,475 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::sync::{Arc, Mutex};
use arrow::compute::concat_batches;
use arrow_array::{RecordBatch, UInt64Array};
use futures::{StreamExt, TryStreamExt};
use lance::io::ObjectStore;
use lance_core::{cache::LanceCache, utils::futures::FinallyStreamExt};
use lance_encoding::decoder::DecoderPlugins;
use lance_file::v2::{
reader::{FileReader, FileReaderOptions},
writer::{FileWriter, FileWriterOptions},
};
use lance_index::scalar::IndexReader;
use lance_io::{
scheduler::{ScanScheduler, SchedulerConfig},
utils::CachedFileSize,
};
use rand::{seq::SliceRandom, Rng, RngCore};
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::util::{non_crypto_rng, TemporaryDirectory},
Error, Result,
};
#[derive(Debug, Clone)]
pub struct ShufflerConfig {
/// An optional seed to make the shuffle deterministic
pub seed: Option<u64>,
/// The maximum number of rows to write to a single file
///
/// The shuffler will need to hold at least this many rows in memory. Setting this value
/// extremely large could cause the shuffler to use a lot of memory (depending on row size).
///
/// However, the shuffler will also need to hold total_num_rows / max_rows_per_file file
/// writers in memory. Each of these will consume some amount of data for column write buffers.
/// So setting this value too small could _also_ cause the shuffler to use a lot of memory and
/// open file handles.
pub max_rows_per_file: u64,
/// The temporary directory to use for writing files
pub temp_dir: TemporaryDirectory,
/// The size of the clumps to shuffle within
///
/// If a clump size is provided, then data will be shuffled in small blocks of contiguous rows.
/// This decreases the overall randomization but can improve I/O performance when reading from
/// cloud storage.
pub clump_size: Option<u64>,
}
impl Default for ShufflerConfig {
fn default() -> Self {
Self {
max_rows_per_file: 1024 * 1024,
seed: Option::default(),
temp_dir: TemporaryDirectory::default(),
clump_size: None,
}
}
}
/// A shuffler that can shuffle a stream of record batches
///
/// To do this the stream is consumed and written to temporary files. A new stream is returned
/// which returns the shuffled data from the temporary files.
///
/// If there are fewer than max_rows_per_file rows in the input stream, then the shuffler will not
/// write any files and will instead perform an in-memory shuffle.
///
/// The number of rows in the input stream must be known in advance.
pub struct Shuffler {
config: ShufflerConfig,
id: String,
}
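A toy Python sketch of the two-phase external shuffle described above (not the real implementation, which streams Arrow batches through Lance file writers; clumps are ignored here). It also illustrates the writer-count tradeoff: at the default max_rows_per_file of 1024 * 1024, a billion rows needs roughly 954 concurrent writers.

import random

def external_shuffle(rows, max_rows_per_file, seed=None):
    rng = random.Random(seed)
    num_files = -(-len(rows) // max_rows_per_file)  # ceiling division
    buckets = [[] for _ in range(num_files)]
    for row in rows:  # phase 1: scatter rows randomly across "files"
        buckets[rng.randrange(num_files)].append(row)
    shuffled = []
    for bucket in buckets:  # phase 2: shuffle each "file" in memory
        rng.shuffle(bucket)
        shuffled.extend(bucket)
    return shuffled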
impl Shuffler {
pub fn new(config: ShufflerConfig) -> Self {
let id = uuid::Uuid::new_v4().to_string();
Self { config, id }
}
/// Shuffles a single batch of data in memory
fn shuffle_batch(
batch: &RecordBatch,
rng: &mut dyn RngCore,
clump_size: u64,
) -> Result<RecordBatch> {
let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
let mut indices = (0..num_clumps).collect::<Vec<_>>();
indices.shuffle(rng);
let indices = if clump_size == 1 {
UInt64Array::from(indices)
} else {
UInt64Array::from_iter_values(indices.iter().flat_map(|&clump_index| {
if clump_index == num_clumps - 1 {
clump_index * clump_size..batch.num_rows() as u64
} else {
clump_index * clump_size..(clump_index + 1) * clump_size
}
}))
};
Ok(arrow::compute::take_record_batch(batch, &indices)?)
}
async fn in_memory_shuffle(
&self,
data: SendableRecordBatchStream,
mut rng: Box<dyn RngCore + Send>,
) -> Result<SendableRecordBatchStream> {
let schema = data.schema();
let batches = data.try_collect::<Vec<_>>().await?;
let batch = concat_batches(&schema, &batches)?;
let shuffled = Self::shuffle_batch(&batch, &mut rng, self.config.clump_size.unwrap_or(1))?;
log::debug!("Shuffle job {}: in-memory shuffle complete", self.id);
Ok(Box::pin(SimpleRecordBatchStream::new(
futures::stream::once(async move { Ok(shuffled) }),
schema,
)))
}
async fn do_shuffle(
&self,
mut data: SendableRecordBatchStream,
num_rows: u64,
mut rng: Box<dyn RngCore + Send>,
) -> Result<SendableRecordBatchStream> {
let num_files = num_rows.div_ceil(self.config.max_rows_per_file);
let temp_dir = self.config.temp_dir.create_temp_dir()?;
let tmp_dir = temp_dir.path().to_path_buf();
let clump_size = self.config.clump_size.unwrap_or(1);
if clump_size == 0 {
return Err(Error::InvalidInput {
message: "clump size must be greater than 0".to_string(),
});
}
let object_store = ObjectStore::local();
let arrow_schema = data.schema();
let schema = lance::datatypes::Schema::try_from(arrow_schema.as_ref())?;
// Create file writers
let mut file_writers = Vec::with_capacity(num_files as usize);
for file_index in 0..num_files {
let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", self.id));
let path =
object_store::path::Path::from_absolute_path(path).map_err(|err| Error::Other {
message: format!("Failed to create temporary file: {}", err),
source: None,
})?;
let object_writer = object_store.create(&path).await?;
let writer =
FileWriter::try_new(object_writer, schema.clone(), FileWriterOptions::default())?;
file_writers.push(writer);
}
let mut num_rows_seen = 0;
// Randomly distribute clumps to files
while let Some(batch) = data.try_next().await? {
num_rows_seen += batch.num_rows() as u64;
let is_last = num_rows_seen == num_rows;
if num_rows_seen > num_rows {
return Err(Error::Runtime {
message: format!("Expected {} rows but saw {} rows", num_rows, num_rows_seen),
});
}
// This is an annoying limitation, but if we allow runt clumps mid-stream then clump
// boundaries become misaligned and the in-memory shuffle step would break up clumps.
// If this becomes a problem we can probably figure out a better way to handle it.
if !is_last && batch.num_rows() as u64 % clump_size != 0 {
return Err(Error::Runtime {
message: format!(
"Expected batch size ({}) to be divisible by clump size ({})",
batch.num_rows(),
clump_size
),
});
}
let num_clumps = (batch.num_rows() as u64).div_ceil(clump_size);
let mut batch_offsets_for_files =
vec![Vec::<u64>::with_capacity(batch.num_rows()); num_files as usize];
// Partition the batch randomly and write to the appropriate accumulator
for clump_offset in 0..num_clumps {
let clump_start = clump_offset * clump_size;
let num_rows_in_clump = clump_size.min(batch.num_rows() as u64 - clump_start);
let clump_end = clump_start + num_rows_in_clump;
let file_index = rng.random_range(0..num_files);
batch_offsets_for_files[file_index as usize].extend(clump_start..clump_end);
}
for (file_index, batch_offsets) in batch_offsets_for_files.into_iter().enumerate() {
if batch_offsets.is_empty() {
continue;
}
let indices = UInt64Array::from(batch_offsets);
let partition = arrow::compute::take_record_batch(&batch, &indices)?;
file_writers[file_index].write_batch(&partition).await?;
}
}
// Finish writing files
for (file_idx, mut writer) in file_writers.into_iter().enumerate() {
let num_written = writer.finish().await?;
log::debug!(
"Shuffle job {}: wrote {} rows to file {}",
self.id,
num_written,
file_idx
);
}
let scheduler_config = SchedulerConfig::max_bandwidth(&object_store);
let scan_scheduler = ScanScheduler::new(Arc::new(object_store), scheduler_config);
let job_id = self.id.clone();
let rng = Arc::new(Mutex::new(rng));
// Second pass, read each file as a single batch and shuffle
let stream = futures::stream::iter(0..num_files)
.then(move |file_index| {
let scan_scheduler = scan_scheduler.clone();
let rng = rng.clone();
let tmp_dir = tmp_dir.clone();
let job_id = job_id.clone();
async move {
let path = tmp_dir.join(format!("shuffle_{}_{file_index}.lance", job_id));
let path = object_store::path::Path::from_absolute_path(path).unwrap();
let file_scheduler = scan_scheduler
.open_file(&path, &CachedFileSize::unknown())
.await?;
let reader = FileReader::try_open(
file_scheduler,
None,
Arc::<DecoderPlugins>::default(),
&LanceCache::no_cache(),
FileReaderOptions::default(),
)
.await?;
// Need to read the entire file in a single batch for in-memory shuffling
let batch = reader.read_record_batch(0, reader.num_rows()).await?;
let mut rng = rng.lock().unwrap();
Self::shuffle_batch(&batch, &mut rng, clump_size)
}
})
.finally(move || drop(temp_dir))
.boxed();
Ok(Box::pin(SimpleRecordBatchStream::new(stream, arrow_schema)))
}
pub async fn shuffle(
self,
data: SendableRecordBatchStream,
num_rows: u64,
) -> Result<SendableRecordBatchStream> {
log::debug!(
"Shuffle job {}: shuffling {} rows and {} columns",
self.id,
num_rows,
data.schema().fields.len()
);
let rng = non_crypto_rng(&self.config.seed);
if num_rows < self.config.max_rows_per_file {
return self.in_memory_shuffle(data, rng).await;
}
self.do_shuffle(data, num_rows, rng).await
}
}
#[cfg(test)]
mod tests {
use crate::arrow::LanceDbDatagenExt;
use super::*;
use arrow::{array::AsArray, datatypes::Int32Type};
use datafusion::prelude::SessionContext;
use datafusion_expr::col;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount, Seed};
use rand::{rngs::SmallRng, SeedableRng};
fn test_gen() -> BatchGeneratorBuilder {
lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col("id", lance_datagen::array::step::<Int32Type>())
.col(
"name",
lance_datagen::array::rand_utf8(ByteCount::from(10), false),
)
}
fn create_test_batch(size: RowCount) -> RecordBatch {
test_gen().into_batch_rows(size).unwrap()
}
fn create_test_stream(
num_batches: BatchCount,
batch_size: RowCount,
) -> SendableRecordBatchStream {
test_gen().into_ldb_stream(batch_size, num_batches)
}
#[test]
fn test_shuffle_batch_deterministic() {
let batch = create_test_batch(RowCount::from(10));
let mut rng1 = SmallRng::seed_from_u64(42);
let mut rng2 = SmallRng::seed_from_u64(42);
let shuffled1 = Shuffler::shuffle_batch(&batch, &mut rng1, 1).unwrap();
let shuffled2 = Shuffler::shuffle_batch(&batch, &mut rng2, 1).unwrap();
// Same seed should produce same shuffle
assert_eq!(shuffled1, shuffled2);
}
#[test]
fn test_shuffle_with_clumps() {
let batch = create_test_batch(RowCount::from(10));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 3).unwrap();
let values = shuffled.column(0).as_primitive::<Int32Type>();
let mut iter = values.into_iter().map(|o| o.unwrap());
let mut frag_seen = false;
let mut clumps_seen = 0;
while let Some(first) = iter.next() {
// 9 is the last value and not a full clump
if first != 9 {
// Otherwise we should have a full clump
let second = iter.next().unwrap();
let third = iter.next().unwrap();
assert_eq!(first + 1, second);
assert_eq!(first + 2, third);
clumps_seen += 1;
} else {
frag_seen = true;
}
}
assert_eq!(clumps_seen, 3);
assert!(frag_seen);
}
async fn sort_batch(batch: RecordBatch) -> RecordBatch {
let ctx = SessionContext::new();
let df = ctx.read_batch(batch).unwrap();
let sorted = df.sort_by(vec![col("id")]).unwrap();
let batches = sorted.collect().await.unwrap();
let schema = batches[0].schema();
concat_batches(&schema, &batches).unwrap()
}
#[tokio::test]
async fn test_shuffle_batch_preserves_data() {
let batch = create_test_batch(RowCount::from(100));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
assert_ne!(shuffled, batch);
let sorted = sort_batch(shuffled).await;
assert_eq!(sorted, batch);
}
#[test]
fn test_shuffle_batch_empty() {
let batch = create_test_batch(RowCount::from(0));
let mut rng = SmallRng::seed_from_u64(42);
let shuffled = Shuffler::shuffle_batch(&batch, &mut rng, 1).unwrap();
assert_eq!(shuffled.num_rows(), 0);
}
#[tokio::test]
async fn test_in_memory_shuffle() {
let config = ShufflerConfig {
temp_dir: TemporaryDirectory::None,
..Default::default()
};
let shuffler = Shuffler::new(config);
let stream = create_test_stream(BatchCount::from(5), RowCount::from(20));
let result_stream = shuffler.shuffle(stream, 100).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
assert_eq!(result_batches.len(), 1);
let result_batch = result_batches.into_iter().next().unwrap();
let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(20))
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = unshuffled_batches[0].schema();
let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
let sorted = sort_batch(result_batch).await;
assert_eq!(unshuffled_batch, sorted);
}
#[tokio::test]
async fn test_external_shuffle() {
let config = ShufflerConfig {
max_rows_per_file: 100,
..Default::default()
};
let shuffler = Shuffler::new(config);
let stream = create_test_stream(BatchCount::from(5), RowCount::from(1000));
let result_stream = shuffler.shuffle(stream, 5000).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
let unshuffled_batches = create_test_stream(BatchCount::from(5), RowCount::from(1000))
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = unshuffled_batches[0].schema();
let unshuffled_batch = concat_batches(&schema, &unshuffled_batches).unwrap();
assert_eq!(result_batches.len(), 50);
let result_batch = concat_batches(&schema, &result_batches).unwrap();
let sorted = sort_batch(result_batch).await;
assert_eq!(unshuffled_batch, sorted);
}
#[test_log::test(tokio::test)]
async fn test_external_clump_shuffle() {
let config = ShufflerConfig {
max_rows_per_file: 100,
clump_size: Some(30),
..Default::default()
};
let shuffler = Shuffler::new(config);
// Batch size (900) must be a multiple of clump size (30)
let stream = create_test_stream(BatchCount::from(5), RowCount::from(900));
let schema = stream.schema();
// Remove 10 rows from the last batch to simulate ending with a partial clump
let mut batches = stream.try_collect::<Vec<_>>().await.unwrap();
let last_index = batches.len() - 1;
let sliced_last = batches[last_index].slice(0, 890);
batches[last_index] = sliced_last;
let stream = Box::pin(SimpleRecordBatchStream::new(
futures::stream::iter(batches).map(Ok).boxed(),
schema.clone(),
));
let result_stream = shuffler.shuffle(stream, 4490).await.unwrap();
let result_batches: Vec<RecordBatch> = result_stream.try_collect().await.unwrap();
let result_batch = concat_batches(&schema, &result_batches).unwrap();
let ids = result_batch.column(0).as_primitive::<Int32Type>();
let mut iter = ids.into_iter().map(|o| o.unwrap());
while let Some(first) = iter.next() {
let rows_left_in_clump = if first == 4470 { 19 } else { 29 };
let mut expected_next = first + 1;
for _ in 0..rows_left_in_clump {
let next = iter.next().unwrap();
assert_eq!(next, expected_next);
expected_next += 1;
}
}
}
}

View File

@@ -0,0 +1,804 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
iter,
sync::{
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
Arc,
},
};
use arrow_array::{Array, BooleanArray, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::hash_utils::create_hashes;
use futures::{StreamExt, TryStreamExt};
use lance::arrow::SchemaExt;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
dataloader::{
shuffle::{Shuffler, ShufflerConfig},
util::TemporaryDirectory,
},
query::{Query, QueryBase, Select},
Error, Result,
};
pub const SPLIT_ID_COLUMN: &str = "split_id";
/// Strategy for assigning rows to splits
#[derive(Debug, Clone)]
pub enum SplitStrategy {
/// All rows will have split id 0
NoSplit,
/// Rows will be randomly assigned to splits
///
/// A seed can be provided to make the assignment deterministic.
Random {
seed: Option<u64>,
sizes: SplitSizes,
},
/// Rows will be assigned to splits based on the values in the specified columns.
///
/// This will ensure rows are always assigned to the same split if the given columns do not change.
///
/// The `split_weights` are used to determine the approximate number of rows in each split. This
/// controls how we divide up the u64 hash space. However, it does not guarantee any particular division
/// of rows. For example, if all rows have identical hash values then all rows will be assigned to the same split
/// regardless of the weights.
///
/// The `discard_weight` controls what percentage of rows should be thrown away. For example, if you want your
/// first split to have ~5% of your rows and the second split to have ~10% of your rows then you would set
/// `split_weights` to [1, 2] and `discard_weight` to 17 (or you could set `split_weights` to [5, 10] and
/// `discard_weight` to 85). If you set `discard_weight` to 0 then all rows will be assigned to a split.
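///
/// A sketch of an ~80/10/10 train/validation/test split keyed on a user id
/// (the column name is illustrative):
///
/// ```ignore
/// let strategy = SplitStrategy::Hash {
///     columns: vec!["user_id".to_string()],
///     split_weights: vec![8, 1, 1],
///     discard_weight: 0,
/// };
/// ```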
Hash {
columns: Vec<String>,
split_weights: Vec<u64>,
discard_weight: u64,
},
/// Rows will be assigned to splits sequentially.
///
/// The first N1 rows are assigned to split 1, the next N2 rows are assigned to split 2, etc.
///
/// This is mainly useful for debugging and testing.
Sequential { sizes: SplitSizes },
/// Rows will be assigned to splits based on a calculation of one or more columns.
///
/// This is useful when the splits already exist in the base table.
///
/// The provided `calculation` should be an SQL statement that returns an integer value between
/// 0 and the number of splits - 1 (the number of splits is defined by the `splits` configuration).
///
/// If this strategy is used then the counts/percentages in the SplitSizes are ignored.
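///
/// For example, `calculation: "split".to_string()` would reuse an existing integer
/// `split` column, and `calculation: "id % 2".to_string()` would assign rows by the
/// parity of an `id` column (both column names are illustrative).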
Calculated { calculation: String },
}
// The default is not to split the data
//
// All data will be assigned to a single split.
impl Default for SplitStrategy {
fn default() -> Self {
Self::NoSplit
}
}
impl SplitStrategy {
pub fn validate(&self, num_rows: u64) -> Result<()> {
match self {
Self::NoSplit => Ok(()),
Self::Random { sizes, .. } => sizes.validate(num_rows),
Self::Hash {
split_weights,
columns,
..
} => {
if columns.is_empty() {
return Err(Error::InvalidInput {
message: "Hash strategy requires at least one column".to_string(),
});
}
if split_weights.is_empty() {
return Err(Error::InvalidInput {
message: "Hash strategy requires at least one split weight".to_string(),
});
}
if split_weights.contains(&0) {
return Err(Error::InvalidInput {
message: "Split weights must be greater than 0".to_string(),
});
}
Ok(())
}
Self::Sequential { sizes } => sizes.validate(num_rows),
Self::Calculated { .. } => Ok(()),
}
}
}
pub struct Splitter {
temp_dir: TemporaryDirectory,
strategy: SplitStrategy,
}
impl Splitter {
pub fn new(temp_dir: TemporaryDirectory, strategy: SplitStrategy) -> Self {
Self { temp_dir, strategy }
}
fn sequential_split_id(
num_rows: u64,
split_sizes: &[u64],
split_index: &AtomicUsize,
counter_in_split: &AtomicU64,
exhausted: &AtomicBool,
) -> UInt64Array {
let mut split_ids = Vec::<u64>::with_capacity(num_rows as usize);
while split_ids.len() < num_rows as usize {
let split_id = split_index.load(Ordering::Relaxed);
let counter = counter_in_split.load(Ordering::Relaxed);
let split_size = split_sizes[split_id];
let remaining_in_split = split_size - counter;
let remaining_in_batch = num_rows - split_ids.len() as u64;
let mut done = false;
let rows_to_add = if remaining_in_batch < remaining_in_split {
counter_in_split.fetch_add(remaining_in_batch, Ordering::Relaxed);
remaining_in_batch
} else {
split_index.fetch_add(1, Ordering::Relaxed);
counter_in_split.store(0, Ordering::Relaxed);
if split_id == split_sizes.len() - 1 {
exhausted.store(true, Ordering::Relaxed);
done = true;
}
remaining_in_split
};
split_ids.extend(iter::repeat(split_id as u64).take(rows_to_add as usize));
if done {
// Quit early if we've run out of splits
break;
}
}
UInt64Array::from(split_ids)
}
async fn apply_sequential(
&self,
source: SendableRecordBatchStream,
num_rows: u64,
sizes: &SplitSizes,
) -> Result<SendableRecordBatchStream> {
let split_sizes = sizes.to_counts(num_rows);
let split_index = AtomicUsize::new(0);
let counter_in_split = AtomicU64::new(0);
let exhausted = AtomicBool::new(false);
let schema = source.schema();
let new_schema = Arc::new(schema.try_with_column(Field::new(
SPLIT_ID_COLUMN,
DataType::UInt64,
false,
))?);
let new_schema_clone = new_schema.clone();
let stream = source.filter_map(move |batch| {
let batch = match batch {
Ok(batch) => batch,
Err(e) => {
return std::future::ready(Some(Err(e)));
}
};
if exhausted.load(Ordering::Relaxed) {
return std::future::ready(None);
}
let split_ids = Self::sequential_split_id(
batch.num_rows() as u64,
&split_sizes,
&split_index,
&counter_in_split,
&exhausted,
);
let mut arrays = batch.columns().to_vec();
// This can happen if we exhaust all splits in the middle of a batch
if split_ids.len() < batch.num_rows() {
arrays = arrays
.iter()
.map(|arr| arr.slice(0, split_ids.len()))
.collect();
}
arrays.push(Arc::new(split_ids));
std::future::ready(Some(Ok(
RecordBatch::try_new(new_schema.clone(), arrays).unwrap()
)))
});
Ok(Box::pin(SimpleRecordBatchStream::new(
stream,
new_schema_clone,
)))
}
fn hash_split_id(batch: &RecordBatch, thresholds: &[u64], total_weight: u64) -> UInt64Array {
let arrays = batch
.columns()
.iter()
// Don't hash the last column which should always be the row id
.take(batch.columns().len() - 1)
.cloned()
.collect::<Vec<_>>();
let mut hashes = vec![0; batch.num_rows()];
let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
create_hashes(&arrays, &random_state, &mut hashes).unwrap();
// As an example, let's assume the weights are 1, 2. Our total weight is 3.
//
// Our thresholds are [1, 3]
// Our modulo output will be 0, 1, or 2.
//
// thresholds.binary_search(0) => Err(0) => 0
// thresholds.binary_search(1) => Ok(0) => 1
// thresholds.binary_search(2) => Err(1) => 1
let split_ids = hashes
.iter()
.map(|h| {
let h = h % total_weight;
let split_id = match thresholds.binary_search(&h) {
Ok(i) => (i + 1) as u64,
Err(i) => i as u64,
};
if split_id == thresholds.len() as u64 {
// If we're at the last threshold then we discard the row (indicated by setting
// the split_id to null)
None
} else {
Some(split_id)
}
})
.collect::<Vec<_>>();
UInt64Array::from(split_ids)
}
async fn apply_hash(
&self,
source: SendableRecordBatchStream,
weights: &[u64],
discard_weight: u64,
) -> Result<SendableRecordBatchStream> {
let row_id_index = source.schema().fields.len() - 1;
let new_schema = Arc::new(Schema::new(vec![
source.schema().field(row_id_index).clone(),
Field::new(SPLIT_ID_COLUMN, DataType::UInt64, false),
]));
let total_weight = weights.iter().sum::<u64>() + discard_weight;
// Thresholds are the cumulative sum of the weights
let mut offset = 0;
let thresholds = weights
.iter()
.map(|w| {
let value = offset + w;
offset = value;
value
})
.collect::<Vec<_>>();
let new_schema_clone = new_schema.clone();
let stream = source.map_ok(move |batch| {
let split_ids = Self::hash_split_id(&batch, &thresholds, total_weight);
if split_ids.null_count() > 0 {
let is_valid = split_ids.nulls().unwrap().inner();
let is_valid_mask = BooleanArray::new(is_valid.clone(), None);
let split_ids = arrow::compute::filter(&split_ids, &is_valid_mask).unwrap();
let row_ids = batch.column(row_id_index);
let row_ids = arrow::compute::filter(row_ids.as_ref(), &is_valid_mask).unwrap();
RecordBatch::try_new(new_schema.clone(), vec![row_ids, split_ids]).unwrap()
} else {
RecordBatch::try_new(
new_schema.clone(),
vec![batch.column(row_id_index).clone(), Arc::new(split_ids)],
)
.unwrap()
}
});
Ok(Box::pin(SimpleRecordBatchStream::new(
stream,
new_schema_clone,
)))
}
pub async fn apply(
&self,
source: SendableRecordBatchStream,
num_rows: u64,
) -> Result<SendableRecordBatchStream> {
self.strategy.validate(num_rows)?;
match &self.strategy {
// For consistency, even when not splitting, we still emit a split id column of all 0s
SplitStrategy::NoSplit => {
self.apply_sequential(source, num_rows, &SplitSizes::Counts(vec![num_rows]))
.await
}
SplitStrategy::Random { seed, sizes } => {
let shuffler = Shuffler::new(ShufflerConfig {
seed: *seed,
// In this case we are only shuffling row ids so we can use a large max_rows_per_file
max_rows_per_file: 10 * 1024 * 1024,
temp_dir: self.temp_dir.clone(),
clump_size: None,
});
let shuffled = shuffler.shuffle(source, num_rows).await?;
self.apply_sequential(shuffled, num_rows, sizes).await
}
SplitStrategy::Sequential { sizes } => {
self.apply_sequential(source, num_rows, sizes).await
}
// Nothing to do, split is calculated in projection
SplitStrategy::Calculated { .. } => Ok(source),
SplitStrategy::Hash {
split_weights,
discard_weight,
..
} => {
self.apply_hash(source, split_weights, *discard_weight)
.await
}
}
}
pub fn project(&self, query: Query) -> Query {
match &self.strategy {
SplitStrategy::Calculated { calculation } => query.select(Select::Dynamic(vec![(
SPLIT_ID_COLUMN.to_string(),
calculation.clone(),
)])),
SplitStrategy::Hash { columns, .. } => query.select(Select::Columns(columns.clone())),
_ => query,
}
}
pub fn orders_by_split_id(&self) -> bool {
match &self.strategy {
SplitStrategy::Hash { .. } | SplitStrategy::Calculated { .. } => true,
SplitStrategy::NoSplit
| SplitStrategy::Sequential { .. }
// It may seem strange, but for random we shuffle first and then assign splits, so the
// result is already sorted by split id
| SplitStrategy::Random { .. } => false,
}
}
}
/// Split configuration - either percentages or absolute counts
///
/// If the percentages do not sum to 1.0 (or the counts do not sum to the total number of rows)
/// the remaining rows will not be included in the permutation.
///
/// The default implementation assigns all rows to a single split.
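///
/// For example, `SplitSizes::Percentages(vec![0.8, 0.1])` over 1000 rows produces
/// splits of 800 and 100 rows and leaves the remaining 100 rows out of the permutation.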
#[derive(Debug, Clone)]
pub enum SplitSizes {
/// Percentage splits (must sum to <= 1.0)
///
/// The number of rows in each split is the nearest integer to the percentage multiplied by
/// the total number of rows.
Percentages(Vec<f64>),
/// Absolute row counts per split
///
/// If the dataset doesn't contain enough matching rows to fill all splits then an error
/// will be raised.
Counts(Vec<u64>),
/// Divides data into a fixed number of splits
///
/// Will divide the data evenly.
///
/// If the number of rows is not divisible by the number of splits then the rows per split
/// is rounded down.
Fixed(u64),
}
impl Default for SplitSizes {
fn default() -> Self {
Self::Percentages(vec![1.0])
}
}
impl SplitSizes {
pub fn validate(&self, num_rows: u64) -> Result<()> {
match self {
Self::Percentages(percentages) => {
for percentage in percentages {
if *percentage < 0.0 || *percentage > 1.0 {
return Err(Error::InvalidInput {
message: "Split percentages must be between 0.0 and 1.0".to_string(),
});
}
if percentage * (num_rows as f64) < 1.0 {
return Err(Error::InvalidInput {
message: format!(
"One of the splits has {}% of {} rows which rounds to 0 rows",
percentage * 100.0,
num_rows
),
});
}
}
if percentages.iter().sum::<f64>() > 1.0 {
return Err(Error::InvalidInput {
message: "Split percentages must sum to 1.0 or less".to_string(),
});
}
}
Self::Counts(counts) => {
if counts.iter().sum::<u64>() > num_rows {
return Err(Error::InvalidInput {
message: format!(
"Split counts specified {} rows but only {} are available",
counts.iter().sum::<u64>(),
num_rows
),
});
}
if counts.contains(&0) {
return Err(Error::InvalidInput {
message: "Split counts must be greater than 0".to_string(),
});
}
}
Self::Fixed(num_splits) => {
if *num_splits > num_rows {
return Err(Error::InvalidInput {
message: format!(
"Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
*num_splits, num_rows
),
});
}
if (num_rows / num_splits) == 0 {
return Err(Error::InvalidInput {
message: format!(
"Split fixed config specified {} splits but only {} rows are available. Must have at least 1 row per split.",
*num_splits, num_rows
),
});
}
}
}
Ok(())
}
pub fn to_counts(&self, num_rows: u64) -> Vec<u64> {
let sizes = match self {
Self::Percentages(percentages) => {
let mut percentage_sum = 0.0_f64;
let mut counts = percentages
.iter()
.map(|p| {
let count = (p * (num_rows as f64)).round() as u64;
percentage_sum += p;
count
})
.collect::<Vec<_>>();
let sum = counts.iter().sum::<u64>();
let is_basically_one =
(num_rows as f64 - percentage_sum * num_rows as f64).abs() < 0.5;
// If the sum of percentages is close to 1.0 then rounding errors can add up
// to more or less than num_rows
//
// Drop items from buckets until we have the correct number of rows
let mut excess = sum as i64 - num_rows as i64;
let mut drop_idx = 0;
while excess > 0 {
if counts[drop_idx] > 0 {
counts[drop_idx] -= 1;
excess -= 1;
}
drop_idx += 1;
if drop_idx == counts.len() {
drop_idx = 0;
}
}
// On the other hand, if the percentages sum to ~1.0 then we also shouldn't _lose_
// rows due to rounding errors
let mut add_idx = 0;
while is_basically_one && excess < 0 {
counts[add_idx] += 1;
add_idx += 1;
excess += 1;
if add_idx == counts.len() {
add_idx = 0;
}
}
counts
}
Self::Counts(counts) => counts.clone(),
Self::Fixed(num_splits) => {
let rows_per_split = num_rows / *num_splits;
vec![rows_per_split; *num_splits as usize]
}
};
assert!(sizes.iter().sum::<u64>() <= num_rows);
sizes
}
}
#[cfg(test)]
mod tests {
use crate::arrow::LanceDbDatagenExt;
use super::*;
use arrow::{
array::AsArray,
compute::concat_batches,
datatypes::{Int32Type, UInt64Type},
};
use arrow_array::Int32Array;
use futures::TryStreamExt;
use lance_datagen::{BatchCount, ByteCount, RowCount, Seed};
use std::sync::Arc;
const ID_COLUMN: &str = "id";
#[test]
fn test_split_sizes_percentages_validation() {
// Valid percentages
let sizes = SplitSizes::Percentages(vec![0.7, 0.3]);
assert!(sizes.validate(100).is_ok());
// Sum > 1.0
let sizes = SplitSizes::Percentages(vec![0.7, 0.4]);
assert!(sizes.validate(100).is_err());
// Negative percentage
let sizes = SplitSizes::Percentages(vec![-0.1, 0.5]);
assert!(sizes.validate(100).is_err());
// Percentage > 1.0
let sizes = SplitSizes::Percentages(vec![1.5]);
assert!(sizes.validate(100).is_err());
// Percentage rounds to 0 rows
let sizes = SplitSizes::Percentages(vec![0.001]);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_counts_validation() {
// Valid counts
let sizes = SplitSizes::Counts(vec![30, 70]);
assert!(sizes.validate(100).is_ok());
// Sum > num_rows
let sizes = SplitSizes::Counts(vec![60, 50]);
assert!(sizes.validate(100).is_err());
// Counts are 0
let sizes = SplitSizes::Counts(vec![0, 100]);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_fixed_validation() {
// Valid fixed splits
let sizes = SplitSizes::Fixed(5);
assert!(sizes.validate(100).is_ok());
// More splits than rows
let sizes = SplitSizes::Fixed(150);
assert!(sizes.validate(100).is_err());
}
#[test]
fn test_split_sizes_to_sizes_percentages() {
let sizes = SplitSizes::Percentages(vec![0.3, 0.7]);
let result = sizes.to_counts(100);
assert_eq!(result, vec![30, 70]);
// Test rounding
let sizes = SplitSizes::Percentages(vec![0.3, 0.41]);
let result = sizes.to_counts(70);
assert_eq!(result, vec![21, 29]);
}
#[test]
fn test_split_sizes_to_sizes_fixed() {
let sizes = SplitSizes::Fixed(4);
let result = sizes.to_counts(100);
assert_eq!(result, vec![25, 25, 25, 25]);
// Test with remainder
let sizes = SplitSizes::Fixed(3);
let result = sizes.to_counts(10);
assert_eq!(result, vec![3, 3, 3]);
}
fn test_data() -> SendableRecordBatchStream {
lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(10), BatchCount::from(5))
}
async fn verify_splitter(
splitter: Splitter,
data: SendableRecordBatchStream,
num_rows: u64,
expected_split_sizes: &[u64],
row_ids_in_order: bool,
) {
let split_batches = splitter
.apply(data, num_rows)
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = split_batches[0].schema();
let split_batch = concat_batches(&schema, &split_batches).unwrap();
let total_split_sizes = expected_split_sizes.iter().sum::<u64>();
assert_eq!(split_batch.num_rows(), total_split_sizes as usize);
let mut expected = Vec::with_capacity(total_split_sizes as usize);
for (i, size) in expected_split_sizes.iter().enumerate() {
expected.extend(iter::repeat(i as u64).take(*size as usize));
}
let expected = Arc::new(UInt64Array::from(expected)) as Arc<dyn Array>;
assert_eq!(&expected, split_batch.column(1));
let expected_row_ids =
Arc::new(Int32Array::from_iter_values(0..total_split_sizes as i32)) as Arc<dyn Array>;
if row_ids_in_order {
assert_eq!(&expected_row_ids, split_batch.column(0));
} else {
assert_ne!(&expected_row_ids, split_batch.column(0));
}
}
#[tokio::test]
async fn test_fixed_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Fixed(3),
},
);
verify_splitter(splitter, test_data(), 50, &[16, 16, 16], true).await;
}
#[tokio::test]
async fn test_fixed_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Fixed(3),
},
);
verify_splitter(splitter, test_data(), 50, &[16, 16, 16], false).await;
}
#[tokio::test]
async fn test_counts_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Counts(vec![5, 15, 10]),
},
);
verify_splitter(splitter, test_data(), 50, &[5, 15, 10], true).await;
}
#[tokio::test]
async fn test_counts_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Counts(vec![5, 15, 10]),
},
);
verify_splitter(splitter, test_data(), 50, &[5, 15, 10], false).await;
}
#[tokio::test]
async fn test_percentages_sequential_split() {
let splitter = Splitter::new(
// Sequential splitting doesn't need a temp dir
TemporaryDirectory::None,
SplitStrategy::Sequential {
sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
},
);
verify_splitter(splitter, test_data(), 50, &[11, 8, 9], true).await;
}
#[tokio::test]
async fn test_percentages_random_split() {
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Random {
seed: Some(42),
sizes: SplitSizes::Percentages(vec![0.217, 0.168, 0.17]),
},
);
verify_splitter(splitter, test_data(), 50, &[11, 8, 9], false).await;
}
#[tokio::test]
async fn test_hash_split() {
let data = lance_datagen::gen_batch()
.with_seed(Seed::from(42))
.col(
"hash1",
lance_datagen::array::rand_utf8(ByteCount::from(10), false),
)
.col("hash2", lance_datagen::array::step::<Int32Type>())
.col(ID_COLUMN, lance_datagen::array::step::<Int32Type>())
.into_ldb_stream(RowCount::from(10), BatchCount::from(5));
let splitter = Splitter::new(
TemporaryDirectory::None,
SplitStrategy::Hash {
columns: vec!["hash1".to_string(), "hash2".to_string()],
split_weights: vec![1, 2],
discard_weight: 1,
},
);
let split_batches = splitter
.apply(data, 10)
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let schema = split_batches[0].schema();
let split_batch = concat_batches(&schema, &split_batches).unwrap();
// These assertions are all based on the fixed seed in data generation but they match
// up roughly with what we expect (25% discarded, 25% in split 0, 50% in split 1)
// 14 rows (28%) are discarded because discard_weight is 1
assert_eq!(split_batch.num_rows(), 36);
assert_eq!(split_batch.num_columns(), 2);
let split_ids = split_batch.column(1).as_primitive::<UInt64Type>().values();
let num_in_split_0 = split_ids.iter().filter(|v| **v == 0).count();
let num_in_split_1 = split_ids.iter().filter(|v| **v == 1).count();
assert_eq!(num_in_split_0, 11); // 22%
assert_eq!(num_in_split_1, 25); // 50%
}
}

View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{path::PathBuf, sync::Arc};
use arrow_array::RecordBatch;
use arrow_schema::{Fields, Schema};
use datafusion_execution::disk_manager::DiskManagerMode;
use futures::TryStreamExt;
use rand::{rngs::SmallRng, RngCore, SeedableRng};
use tempfile::TempDir;
use crate::{
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
Error, Result,
};
/// Directory to use for temporary files
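///
/// For example, `TemporaryDirectory::Specific(PathBuf::from("/mnt/scratch"))` spills to
/// a dedicated scratch volume, while `TemporaryDirectory::None` turns any operation that
/// would spill into an error (the path is illustrative).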
#[derive(Debug, Clone, Default)]
pub enum TemporaryDirectory {
/// Use the operating system's default temporary directory (e.g. /tmp)
#[default]
OsDefault,
/// Use the specified directory (must be an absolute path)
Specific(PathBuf),
/// If spilling is required, then error out
None,
}
impl TemporaryDirectory {
pub fn create_temp_dir(&self) -> Result<TempDir> {
match self {
Self::OsDefault => tempfile::tempdir(),
Self::Specific(path) => tempfile::Builder::default().tempdir_in(path),
Self::None => {
return Err(Error::Runtime {
message: "No temporary directory was supplied and this operation requires spilling to disk".to_string(),
});
}
}
.map_err(|err| Error::Other {
message: "Failed to create temporary directory".to_string(),
source: Some(err.into()),
})
}
pub fn to_disk_manager_mode(&self) -> DiskManagerMode {
match self {
Self::OsDefault => DiskManagerMode::OsTmpDirectory,
Self::Specific(path) => DiskManagerMode::Directories(vec![path.clone()]),
Self::None => DiskManagerMode::Disabled,
}
}
}
pub fn non_crypto_rng(seed: &Option<u64>) -> Box<dyn RngCore + Send> {
Box::new(
seed.as_ref()
.map(|seed| SmallRng::seed_from_u64(*seed))
.unwrap_or_else(SmallRng::from_os_rng),
)
}
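/// Renames a single column in a stream's schema without copying the underlying arrays.
/// A usage sketch (the column names are illustrative):
///
/// ```ignore
/// let stream = rename_column(stream, "_rowid", "row_id")?;
/// ```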
pub fn rename_column(
stream: SendableRecordBatchStream,
old_name: &str,
new_name: &str,
) -> Result<SendableRecordBatchStream> {
let schema = stream.schema();
let field_index = schema.index_of(old_name)?;
let new_fields = schema
.fields
.iter()
.cloned()
.enumerate()
.map(|(idx, f)| {
if idx == field_index {
Arc::new(f.as_ref().clone().with_name(new_name))
} else {
f
}
})
.collect::<Fields>();
let new_schema = Arc::new(Schema::new(new_fields).with_metadata(schema.metadata().clone()));
let new_schema_clone = new_schema.clone();
let renamed_stream = stream.and_then(move |batch| {
let renamed_batch =
RecordBatch::try_new(new_schema.clone(), batch.columns().to_vec()).map_err(Error::from);
std::future::ready(renamed_batch)
});
Ok(Box::pin(SimpleRecordBatchStream::new(
renamed_stream,
new_schema_clone,
)))
}

View File

@@ -8,6 +8,7 @@ use std::sync::Arc;
use std::time::Duration;
use vector::IvfFlatIndexBuilder;
use crate::index::vector::IvfRqIndexBuilder;
use crate::{table::BaseTable, DistanceType, Error, Result};
use self::{
@@ -53,6 +54,9 @@ pub enum Index {
/// IVF index with Product Quantization
IvfPq(IvfPqIndexBuilder),
/// IVF index with RabitQ Quantization
IvfRq(IvfRqIndexBuilder),
/// IVF-HNSW index with Product Quantization
/// It is a variant of the HNSW algorithm that uses product quantization to compress the vectors.
IvfHnswPq(IvfHnswPqIndexBuilder),
@@ -275,6 +279,8 @@ pub enum IndexType {
IvfFlat,
#[serde(alias = "IVF_PQ")]
IvfPq,
#[serde(alias = "IVF_RQ")]
IvfRq,
#[serde(alias = "IVF_HNSW_PQ")]
IvfHnswPq,
#[serde(alias = "IVF_HNSW_SQ")]
@@ -296,6 +302,7 @@ impl std::fmt::Display for IndexType {
match self {
Self::IvfFlat => write!(f, "IVF_FLAT"),
Self::IvfPq => write!(f, "IVF_PQ"),
Self::IvfRq => write!(f, "IVF_RQ"),
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
Self::BTree => write!(f, "BTREE"),
@@ -317,6 +324,7 @@ impl std::str::FromStr for IndexType {
"FTS" | "INVERTED" => Ok(Self::FTS),
"IVF_FLAT" => Ok(Self::IvfFlat),
"IVF_PQ" => Ok(Self::IvfPq),
"IVF_RQ" => Ok(Self::IvfRq),
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
_ => Err(Error::InvalidInput {

View File

@@ -291,6 +291,52 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 {
}
}
/// Builder for an IVF RQ index.
///
/// This index stores a compressed (quantized) copy of every vector. Each dimension
/// is quantized into a small number of bits.
/// The `num_bits` parameter controls this process, providing a tradeoff
/// between index size (and thus search speed) and index accuracy.
///
/// The partitioning process is called IVF and the `num_partitions` parameter controls how
/// many groups to create.
///
/// Note that training an IVF RQ index on a large dataset is a slow operation and
/// is currently also a memory-intensive operation.
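///
/// A construction sketch (values illustrative; the `distance_type` and
/// `num_partitions` setter names are assumed to be generated by the macros used below):
///
/// ```ignore
/// let index = IvfRqIndexBuilder::default()
///     .distance_type(DistanceType::Cosine)
///     .num_partitions(256)
///     .num_bits(1);
/// ```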
#[derive(Debug, Clone)]
pub struct IvfRqIndexBuilder {
// IVF
pub(crate) distance_type: DistanceType,
pub(crate) num_partitions: Option<u32>,
pub(crate) num_bits: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
pub(crate) target_partition_size: Option<u32>,
}
impl Default for IvfRqIndexBuilder {
fn default() -> Self {
Self {
distance_type: DistanceType::L2,
num_partitions: None,
num_bits: None,
sample_rate: 256,
max_iterations: 50,
target_partition_size: None,
}
}
}
impl IvfRqIndexBuilder {
impl_distance_type_setter!();
impl_ivf_params_setter!();
pub fn num_bits(mut self, num_bits: u32) -> Self {
self.num_bits = Some(num_bits);
self
}
}
/// Builder for an IVF HNSW PQ index.
///
/// This index is a combination of IVF and HNSW.

View File

@@ -194,6 +194,7 @@ pub mod arrow;
pub mod connection;
pub mod data;
pub mod database;
pub mod dataloader;
pub mod embeddings;
pub mod error;
pub mod index;

View File

@@ -16,7 +16,7 @@ use tokio::task::spawn_blocking;
use crate::database::{
CloneTableRequest, CreateNamespaceRequest, CreateTableData, CreateTableMode,
CreateTableRequest, Database, DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest,
OpenTableRequest, TableNamesRequest,
OpenTableRequest, ReadConsistency, TableNamesRequest,
};
use crate::error::Result;
use crate::table::BaseTable;
@@ -189,6 +189,7 @@ struct ListTablesResponse {
pub struct RemoteDatabase<S: HttpSend = Sender> {
client: RestfulLanceDbClient<S>,
table_cache: Cache<String, Arc<RemoteTable<S>>>,
uri: String,
}
impl RemoteDatabase {
@@ -217,6 +218,7 @@ impl RemoteDatabase {
Ok(Self {
client,
table_cache,
uri: uri.to_owned(),
})
}
}
@@ -238,6 +240,7 @@ mod test_utils {
Self {
client,
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
}
}
@@ -250,6 +253,7 @@ mod test_utils {
Self {
client,
table_cache: Cache::new(0),
uri: "http://localhost".to_string(),
}
}
}
@@ -315,6 +319,17 @@ fn build_cache_key(name: &str, namespace: &[String]) -> String {
#[async_trait]
impl<S: HttpSend> Database for RemoteDatabase<S> {
fn uri(&self) -> &str {
&self.uri
}
async fn read_consistency(&self) -> Result<ReadConsistency> {
Err(Error::NotSupported {
message: "Getting the read consistency of a remote database is not yet supported"
.to_string(),
})
}
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
let mut req = if !request.namespace.is_empty() {
let namespace_id =

View File

@@ -50,6 +50,7 @@ use std::sync::Arc;
use crate::arrow::IntoArrow;
use crate::connection::NoData;
use crate::database::Database;
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
use crate::error::{Error, Result};
use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex};
@@ -611,9 +612,10 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
/// A Table is a collection of strongly typed Rows.
///
/// The type of each row is defined in Apache Arrow [Schema].
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Table {
inner: Arc<dyn BaseTable>,
database: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
}
@@ -631,11 +633,13 @@ mod test_utils {
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
handler,
handler.clone(),
None,
));
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -651,11 +655,13 @@ mod test_utils {
{
let inner = Arc::new(crate::remote::table::RemoteTable::new_mock(
name.into(),
handler,
handler.clone(),
Some(version),
));
let database = Arc::new(crate::remote::db::RemoteDatabase::new_mock(handler));
Self {
inner,
database,
// Registry is unused.
embedding_registry: Arc::new(MemoryRegistry::new()),
}
@@ -670,9 +676,10 @@ impl std::fmt::Display for Table {
}
impl Table {
pub fn new(inner: Arc<dyn BaseTable>) -> Self {
pub fn new(inner: Arc<dyn BaseTable>, database: Arc<dyn Database>) -> Self {
Self {
inner,
database,
embedding_registry: Arc::new(MemoryRegistry::new()),
}
}
@@ -681,12 +688,22 @@ impl Table {
&self.inner
}
pub fn database(&self) -> &Arc<dyn Database> {
&self.database
}
pub fn embedding_registry(&self) -> &Arc<dyn EmbeddingRegistry> {
&self.embedding_registry
}
pub(crate) fn new_with_embedding_registry(
inner: Arc<dyn BaseTable>,
database: Arc<dyn Database>,
embedding_registry: Arc<dyn EmbeddingRegistry>,
) -> Self {
Self {
inner,
database,
embedding_registry,
}
}
@@ -1416,12 +1433,6 @@ impl Tags for NativeTags {
}
}
impl From<NativeTable> for Table {
fn from(table: NativeTable) -> Self {
Self::new(Arc::new(table))
}
}
pub trait NativeTableExt {
/// Cast as [`NativeTable`], or return None if it is not a [`NativeTable`].
fn as_native(&self) -> Option<&NativeTable>;
@@ -1843,6 +1854,18 @@ impl NativeTable {
);
Ok(Box::new(lance_idx_params))
}
Index::IvfRq(index) => {
Self::validate_index_type(field, "IVF RQ", supported_vector_data_type)?;
let num_partitions = self
.get_num_partitions(index.num_partitions, false, None)
.await?;
let lance_idx_params = VectorIndexParams::ivf_rq(
num_partitions as usize,
index.num_bits.unwrap_or(1) as u8,
index.distance_type.into(),
);
Ok(Box::new(lance_idx_params))
}
Index::IvfHnswPq(index) => {
Self::validate_index_type(field, "IVF HNSW PQ", supported_vector_data_type)?;
let dim = Self::get_vector_dimension(field)?;
@@ -1912,9 +1935,11 @@ impl NativeTable {
Index::Bitmap(_) => IndexType::Bitmap,
Index::LabelList(_) => IndexType::LabelList,
Index::FTS(_) => IndexType::Inverted,
Index::IvfFlat(_) | Index::IvfPq(_) | Index::IvfHnswPq(_) | Index::IvfHnswSq(_) => {
IndexType::Vector
}
Index::IvfFlat(_)
| Index::IvfPq(_)
| Index::IvfRq(_)
| Index::IvfHnswPq(_)
| Index::IvfHnswSq(_) => IndexType::Vector,
}
}

View File

@@ -125,6 +125,10 @@ impl ExecutionPlan for MetadataEraserExec {
fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
self.input.partition_statistics(partition)
}
fn supports_limit_pushdown(&self) -> bool {
true
}
}
#[derive(Debug)]