mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-31 04:50:40 +00:00
Compare commits
12 Commits
python-v0.
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7f189f27b | ||
|
|
a0a2942ad5 | ||
|
|
e3d53dd185 | ||
|
|
66804e99fc | ||
|
|
9f85d4c639 | ||
|
|
1ba19d728e | ||
|
|
4c44587af0 | ||
|
|
1d1cafb59c | ||
|
|
4714598155 | ||
|
|
74f457a0f2 | ||
|
|
cca6a7c989 | ||
|
|
ad96489114 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.27.2-beta.0"
|
||||
current_version = "0.27.2-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -23,8 +23,10 @@ runs:
|
||||
steps:
|
||||
- name: CONFIRM ARM BUILD
|
||||
shell: bash
|
||||
env:
|
||||
ARM_BUILD: ${{ inputs.arm-build }}
|
||||
run: |
|
||||
echo "ARM BUILD: ${{ inputs.arm-build }}"
|
||||
echo "ARM BUILD: $ARM_BUILD"
|
||||
- name: Build x86_64 Manylinux wheel
|
||||
if: ${{ inputs.arm-build == 'false' }}
|
||||
uses: PyO3/maturin-action@v1
|
||||
|
||||
125
Cargo.lock
generated
125
Cargo.lock
generated
@@ -108,7 +108,7 @@ version = "1.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
|
||||
dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -119,7 +119,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"once_cell_polyfill",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2682,7 +2682,7 @@ dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2876,7 +2876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3072,8 +3072,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "fsst"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2195cc7f87e84bd695586137de99605e7e9579b26ec5e01b82960ddb4d0922f2"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"rand 0.9.2",
|
||||
@@ -3736,7 +3737,7 @@ dependencies = [
|
||||
"libc",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2 0.5.10",
|
||||
"socket2 0.6.3",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
@@ -4037,7 +4038,7 @@ dependencies = [
|
||||
"portable-atomic",
|
||||
"portable-atomic-util",
|
||||
"serde_core",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4123,8 +4124,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "efe6c3ddd79cdfd2b7e1c23cafae52806906bc40fbd97de9e8cf2f8c7a75fc04"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4190,8 +4192,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-arrow"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d9f5d95bdda2a2b790f1fb8028b5b6dcf661abeb3133a8bca0f3d24b054af87"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4211,8 +4214,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-bitpacking"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f827d6ab9f8f337a9509d5ad66a12f3314db8713868260521c344ef6135eb4e4"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"paste",
|
||||
@@ -4221,8 +4225,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-core"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f1e25df6a79bf72ee6bcde0851f19b1cd36c5848c1b7db83340882d3c9fdecb"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4259,8 +4264,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datafusion"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93146de8ae720cb90edef81c2f2d0a1b065fc2f23ecff2419546f389b0fa70a4"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4290,8 +4296,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datagen"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ccec8ce4d8e0a87a99c431dab2364398029f2ffb649c1a693c60c79e05ed30dd"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4309,8 +4316,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-encoding"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c1aec0bbbac6bce829bc10f1ba066258126100596c375fb71908ecf11c2c2a5"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4347,8 +4355,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-file"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14a8c548804f5b17486dc2d3282356ed1957095a852780283bc401fdd69e9075"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4380,8 +4389,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-index"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2da212f0090ea59f79ac3686660f596520c167fe1cb5f408900cf71d215f0e03"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4445,8 +4455,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-io"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41d958eb4b56f03bbe0f5f85eb2b4e9657882812297b6f711f201ffc995f259f"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4487,8 +4498,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-linalg"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0285b70da35def7ed95e150fae1d5308089554e1290470403ed3c50cb235bc5e"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4504,8 +4516,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f78e2a828b654e062a495462c6e3eb4fcf0e7e907d761b8f217fc09ccd3ceac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -4518,8 +4531,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace-impls"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2392314f3da38f00d166295e44244208a65ccfc256e274fa8631849fc3f4d94"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-ipc",
|
||||
@@ -4563,8 +4577,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-table"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3df9c4adca3eb2074b3850432a9fb34248a3d90c3d6427d158b13ff9355664ee"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4603,8 +4618,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-testing"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ed7119bdd6983718387b4ac44af873a165262ca94f181b104cd6f97912eb3bf"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -4615,7 +4631,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.27.2-beta.0"
|
||||
version = "0.27.2-beta.1"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
@@ -4697,9 +4713,10 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-nodejs"
|
||||
version = "0.27.2-beta.0"
|
||||
version = "0.27.2-beta.1"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
"arrow-ipc",
|
||||
"arrow-schema",
|
||||
"async-trait",
|
||||
@@ -4707,6 +4724,7 @@ dependencies = [
|
||||
"aws-lc-sys",
|
||||
"env_logger",
|
||||
"futures",
|
||||
"half",
|
||||
"lancedb",
|
||||
"log",
|
||||
"lzma-sys",
|
||||
@@ -4717,7 +4735,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-python"
|
||||
version = "0.30.2-beta.0"
|
||||
version = "0.30.2-beta.1"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -5305,7 +5323,7 @@ version = "0.50.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||
dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6284,7 +6302,7 @@ version = "0.14.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.14.0",
|
||||
"log",
|
||||
"multimap",
|
||||
@@ -6471,7 +6489,7 @@ dependencies = [
|
||||
"quinn-udp",
|
||||
"rustc-hash",
|
||||
"rustls 0.23.37",
|
||||
"socket2 0.5.10",
|
||||
"socket2 0.6.3",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -6508,9 +6526,9 @@ dependencies = [
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2 0.5.10",
|
||||
"socket2 0.6.3",
|
||||
"tracing",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7039,7 +7057,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.15",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7052,7 +7070,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.12.1",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7572,7 +7590,7 @@ version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -7584,7 +7602,7 @@ version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -7607,7 +7625,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7711,7 +7729,6 @@ dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"psm",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -8072,7 +8089,7 @@ dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"once_cell",
|
||||
"rustix 1.1.4",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8877,7 +8894,7 @@ version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -15,20 +15,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance = { version = "=4.0.0", default-features = false }
|
||||
lance-core = { version = "=4.0.0" }
|
||||
lance-datagen = { version = "=4.0.0" }
|
||||
lance-file = { version = "=4.0.0" }
|
||||
lance-io = { version = "=4.0.0", default-features = false }
|
||||
lance-index = { version = "=4.0.0" }
|
||||
lance-linalg = { version = "=4.0.0" }
|
||||
lance-namespace = { version = "=4.0.0" }
|
||||
lance-namespace-impls = { version = "=4.0.0", default-features = false }
|
||||
lance-table = { version = "=4.0.0" }
|
||||
lance-testing = { version = "=4.0.0" }
|
||||
lance-datafusion = { version = "=4.0.0" }
|
||||
lance-encoding = { version = "=4.0.0" }
|
||||
lance-arrow = { version = "=4.0.0" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "57.2", optional = false }
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.27.2-beta.0</version>
|
||||
<version>0.27.2-beta.1</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ new EmbeddingFunction<T, M>(): EmbeddingFunction<T, M>
|
||||
### computeQueryEmbeddings()
|
||||
|
||||
```ts
|
||||
computeQueryEmbeddings(data): Promise<number[] | Float32Array | Float64Array>
|
||||
computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Float64Array>
|
||||
```
|
||||
|
||||
Compute the embeddings for a single query
|
||||
@@ -63,7 +63,7 @@ Compute the embeddings for a single query
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`number`[] \| `Float32Array` \| `Float64Array`>
|
||||
`Promise`<`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`>
|
||||
|
||||
***
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ new TextEmbeddingFunction<M>(): TextEmbeddingFunction<M>
|
||||
### computeQueryEmbeddings()
|
||||
|
||||
```ts
|
||||
computeQueryEmbeddings(data): Promise<number[] | Float32Array | Float64Array>
|
||||
computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Float64Array>
|
||||
```
|
||||
|
||||
Compute the embeddings for a single query
|
||||
@@ -48,7 +48,7 @@ Compute the embeddings for a single query
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`number`[] \| `Float32Array` \| `Float64Array`>
|
||||
`Promise`<`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`>
|
||||
|
||||
#### Overrides
|
||||
|
||||
|
||||
@@ -7,5 +7,10 @@
|
||||
# Type Alias: IntoVector
|
||||
|
||||
```ts
|
||||
type IntoVector: Float32Array | Float64Array | number[] | Promise<Float32Array | Float64Array | number[]>;
|
||||
type IntoVector:
|
||||
| Float32Array
|
||||
| Float64Array
|
||||
| Uint8Array
|
||||
| number[]
|
||||
| Promise<Float32Array | Float64Array | Uint8Array | number[]>;
|
||||
```
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.27.2-beta.0</version>
|
||||
<version>0.27.2-beta.1</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.27.2-beta.0</version>
|
||||
<version>0.27.2-beta.1</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.27.2-beta.0"
|
||||
version = "0.27.2-beta.1"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
@@ -15,6 +15,8 @@ crate-type = ["cdylib"]
|
||||
async-trait.workspace = true
|
||||
arrow-ipc.workspace = true
|
||||
arrow-array.workspace = true
|
||||
arrow-buffer = "57.2"
|
||||
half.workspace = true
|
||||
arrow-schema.workspace = true
|
||||
env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
|
||||
110
nodejs/__test__/vector_types.test.ts
Normal file
110
nodejs/__test__/vector_types.test.ts
Normal file
@@ -0,0 +1,110 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import * as tmp from "tmp";
|
||||
|
||||
import { type Table, connect } from "../lancedb";
|
||||
import {
|
||||
Field,
|
||||
FixedSizeList,
|
||||
Float32,
|
||||
Int64,
|
||||
Schema,
|
||||
makeArrowTable,
|
||||
} from "../lancedb/arrow";
|
||||
|
||||
describe("Vector query with different typed arrays", () => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
|
||||
afterEach(() => {
|
||||
tmpDir?.removeCallback();
|
||||
});
|
||||
|
||||
async function createFloat32Table(): Promise<Table> {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
const db = await connect(tmpDir.name);
|
||||
const schema = new Schema([
|
||||
new Field("id", new Int64(), true),
|
||||
new Field(
|
||||
"vec",
|
||||
new FixedSizeList(2, new Field("item", new Float32())),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
const data = makeArrowTable(
|
||||
[
|
||||
{ id: 1n, vec: [1.0, 0.0] },
|
||||
{ id: 2n, vec: [0.0, 1.0] },
|
||||
{ id: 3n, vec: [1.0, 1.0] },
|
||||
],
|
||||
{ schema },
|
||||
);
|
||||
return db.createTable("test_f32", data);
|
||||
}
|
||||
|
||||
it("should search with Float32Array (baseline)", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo(new Float32Array([1.0, 0.0]))
|
||||
.limit(1)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBe(1);
|
||||
expect(Number(results[0].id)).toBe(1);
|
||||
});
|
||||
|
||||
it("should search with number[] (backward compat)", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo([1.0, 0.0])
|
||||
.limit(1)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBe(1);
|
||||
expect(Number(results[0].id)).toBe(1);
|
||||
});
|
||||
|
||||
it("should search with Float64Array via raw path", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo(new Float64Array([1.0, 0.0]))
|
||||
.limit(1)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBe(1);
|
||||
expect(Number(results[0].id)).toBe(1);
|
||||
});
|
||||
|
||||
it("should add multiple query vectors with Float64Array", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo(new Float64Array([1.0, 0.0]))
|
||||
.addQueryVector(new Float64Array([0.0, 1.0]))
|
||||
.limit(2)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
// Float16Array is only available in Node 22+; not in TypeScript's standard lib yet
|
||||
const float16ArrayCtor = (globalThis as unknown as Record<string, unknown>)
|
||||
.Float16Array as (new (values: number[]) => unknown) | undefined;
|
||||
const hasFloat16 = float16ArrayCtor !== undefined;
|
||||
const f16it = hasFloat16 ? it : it.skip;
|
||||
|
||||
f16it("should search with Float16Array via raw path", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo(new float16ArrayCtor!([1.0, 0.0]) as Float32Array)
|
||||
.limit(1)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBe(1);
|
||||
expect(Number(results[0].id)).toBe(1);
|
||||
});
|
||||
});
|
||||
@@ -117,8 +117,9 @@ export type TableLike =
|
||||
export type IntoVector =
|
||||
| Float32Array
|
||||
| Float64Array
|
||||
| Uint8Array
|
||||
| number[]
|
||||
| Promise<Float32Array | Float64Array | number[]>;
|
||||
| Promise<Float32Array | Float64Array | Uint8Array | number[]>;
|
||||
|
||||
export type MultiVector = IntoVector[];
|
||||
|
||||
@@ -126,14 +127,48 @@ export function isMultiVector(value: unknown): value is MultiVector {
|
||||
return Array.isArray(value) && isIntoVector(value[0]);
|
||||
}
|
||||
|
||||
// Float16Array is not in TypeScript's standard lib yet; access dynamically
|
||||
type Float16ArrayCtor = new (
|
||||
...args: unknown[]
|
||||
) => { buffer: ArrayBuffer; byteOffset: number; byteLength: number };
|
||||
const float16ArrayCtor = (globalThis as unknown as Record<string, unknown>)
|
||||
.Float16Array as Float16ArrayCtor | undefined;
|
||||
|
||||
export function isIntoVector(value: unknown): value is IntoVector {
|
||||
return (
|
||||
value instanceof Float32Array ||
|
||||
value instanceof Float64Array ||
|
||||
value instanceof Uint8Array ||
|
||||
(float16ArrayCtor !== undefined && value instanceof float16ArrayCtor) ||
|
||||
(Array.isArray(value) && !Array.isArray(value[0]))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the underlying byte buffer and data type from a typed array
|
||||
* for passing to the Rust NAPI layer without precision loss.
|
||||
*/
|
||||
export function extractVectorBuffer(
|
||||
vector: Float32Array | Float64Array | Uint8Array,
|
||||
): { data: Uint8Array; dtype: string } | null {
|
||||
if (float16ArrayCtor !== undefined && vector instanceof float16ArrayCtor) {
|
||||
return {
|
||||
data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength),
|
||||
dtype: "float16",
|
||||
};
|
||||
}
|
||||
if (vector instanceof Float64Array) {
|
||||
return {
|
||||
data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength),
|
||||
dtype: "float64",
|
||||
};
|
||||
}
|
||||
if (vector instanceof Uint8Array && !(vector instanceof Float32Array)) {
|
||||
return { data: vector, dtype: "uint8" };
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function isArrowTable(value: object): value is TableLike {
|
||||
if (value instanceof ArrowTable) return true;
|
||||
return "schema" in value && "batches" in value;
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
Table as ArrowTable,
|
||||
type IntoVector,
|
||||
RecordBatch,
|
||||
extractVectorBuffer,
|
||||
fromBufferToRecordBatch,
|
||||
fromRecordBatchToBuffer,
|
||||
tableFromIPC,
|
||||
@@ -661,10 +662,8 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
|
||||
const res = (async () => {
|
||||
try {
|
||||
const v = await vector;
|
||||
const arr = Float32Array.from(v);
|
||||
//
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||
const value: any = this.addQueryVector(arr);
|
||||
const value: any = this.addQueryVector(v);
|
||||
const inner = value.inner as
|
||||
| NativeVectorQuery
|
||||
| Promise<NativeVectorQuery>;
|
||||
@@ -676,7 +675,12 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
|
||||
return new VectorQuery(res);
|
||||
} else {
|
||||
super.doCall((inner) => {
|
||||
inner.addQueryVector(Float32Array.from(vector));
|
||||
const raw = Array.isArray(vector) ? null : extractVectorBuffer(vector);
|
||||
if (raw) {
|
||||
inner.addQueryVectorRaw(raw.data, raw.dtype);
|
||||
} else {
|
||||
inner.addQueryVector(Float32Array.from(vector as number[]));
|
||||
}
|
||||
});
|
||||
return this;
|
||||
}
|
||||
@@ -765,14 +769,23 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
||||
* a default `limit` of 10 will be used. @see {@link Query#limit}
|
||||
*/
|
||||
nearestTo(vector: IntoVector): VectorQuery {
|
||||
const callNearestTo = (
|
||||
inner: NativeQuery,
|
||||
resolved: Float32Array | Float64Array | Uint8Array | number[],
|
||||
): NativeVectorQuery => {
|
||||
const raw = Array.isArray(resolved)
|
||||
? null
|
||||
: extractVectorBuffer(resolved);
|
||||
if (raw) {
|
||||
return inner.nearestToRaw(raw.data, raw.dtype);
|
||||
}
|
||||
return inner.nearestTo(Float32Array.from(resolved as number[]));
|
||||
};
|
||||
|
||||
if (this.inner instanceof Promise) {
|
||||
const nativeQuery = this.inner.then(async (inner) => {
|
||||
if (vector instanceof Promise) {
|
||||
const arr = await vector.then((v) => Float32Array.from(v));
|
||||
return inner.nearestTo(arr);
|
||||
} else {
|
||||
return inner.nearestTo(Float32Array.from(vector));
|
||||
}
|
||||
const resolved = vector instanceof Promise ? await vector : vector;
|
||||
return callNearestTo(inner, resolved);
|
||||
});
|
||||
return new VectorQuery(nativeQuery);
|
||||
}
|
||||
@@ -780,10 +793,8 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
||||
const res = (async () => {
|
||||
try {
|
||||
const v = await vector;
|
||||
const arr = Float32Array.from(v);
|
||||
//
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||
const value: any = this.nearestTo(arr);
|
||||
const value: any = this.nearestTo(v);
|
||||
const inner = value.inner as
|
||||
| NativeVectorQuery
|
||||
| Promise<NativeVectorQuery>;
|
||||
@@ -794,7 +805,7 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
||||
})();
|
||||
return new VectorQuery(res);
|
||||
} else {
|
||||
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
|
||||
const vectorQuery = callNearestTo(this.inner, vector);
|
||||
return new VectorQuery(vectorQuery);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.27.2-beta.0",
|
||||
"version": "0.27.2-beta.1",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{
|
||||
Array, Float16Array as ArrowFloat16Array, Float32Array as ArrowFloat32Array,
|
||||
Float64Array as ArrowFloat64Array, UInt8Array as ArrowUInt8Array,
|
||||
};
|
||||
use arrow_buffer::ScalarBuffer;
|
||||
use half::f16;
|
||||
use lancedb::index::scalar::{
|
||||
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
||||
Operator, PhraseQuery,
|
||||
@@ -24,6 +30,33 @@ use crate::rerankers::RerankHybridCallbackArgs;
|
||||
use crate::rerankers::Reranker;
|
||||
use crate::util::{parse_distance_type, schema_to_buffer};
|
||||
|
||||
fn bytes_to_arrow_array(data: Uint8Array, dtype: String) -> napi::Result<Arc<dyn Array>> {
|
||||
let buf = arrow_buffer::Buffer::from(data.to_vec());
|
||||
let num_bytes = buf.len();
|
||||
match dtype.as_str() {
|
||||
"float16" => {
|
||||
let scalar_buf = ScalarBuffer::<f16>::new(buf, 0, num_bytes / 2);
|
||||
Ok(Arc::new(ArrowFloat16Array::new(scalar_buf, None)))
|
||||
}
|
||||
"float32" => {
|
||||
let scalar_buf = ScalarBuffer::<f32>::new(buf, 0, num_bytes / 4);
|
||||
Ok(Arc::new(ArrowFloat32Array::new(scalar_buf, None)))
|
||||
}
|
||||
"float64" => {
|
||||
let scalar_buf = ScalarBuffer::<f64>::new(buf, 0, num_bytes / 8);
|
||||
Ok(Arc::new(ArrowFloat64Array::new(scalar_buf, None)))
|
||||
}
|
||||
"uint8" => {
|
||||
let scalar_buf = ScalarBuffer::<u8>::new(buf, 0, num_bytes);
|
||||
Ok(Arc::new(ArrowUInt8Array::new(scalar_buf, None)))
|
||||
}
|
||||
_ => Err(napi::Error::from_reason(format!(
|
||||
"Unsupported vector dtype: {}. Expected one of: float16, float32, float64, uint8",
|
||||
dtype
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct Query {
|
||||
inner: LanceDbQuery,
|
||||
@@ -78,6 +111,13 @@ impl Query {
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn nearest_to_raw(&mut self, data: Uint8Array, dtype: String) -> Result<VectorQuery> {
|
||||
let array = bytes_to_arrow_array(data, dtype)?;
|
||||
let inner = self.inner.clone().nearest_to(array).default_error()?;
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
@@ -163,6 +203,13 @@ impl VectorQuery {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn add_query_vector_raw(&mut self, data: Uint8Array, dtype: String) -> Result<()> {
|
||||
let array = bytes_to_arrow_array(data, dtype)?;
|
||||
self.inner = self.inner.clone().add_query_vector(array).default_error()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
|
||||
@@ -10,6 +10,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
import weakref
|
||||
import logging
|
||||
from functools import wraps
|
||||
|
||||
@@ -70,7 +70,7 @@ def ensure_vector_query(
|
||||
) -> Union[List[float], List[List[float]], pa.Array, List[pa.Array]]:
|
||||
if isinstance(val, list):
|
||||
if len(val) == 0:
|
||||
return ValueError("Vector query must be a non-empty list")
|
||||
raise ValueError("Vector query must be a non-empty list")
|
||||
sample = val[0]
|
||||
else:
|
||||
if isinstance(val, float):
|
||||
@@ -83,7 +83,7 @@ def ensure_vector_query(
|
||||
return val
|
||||
if isinstance(sample, list):
|
||||
if len(sample) == 0:
|
||||
return ValueError("Vector query must be a non-empty list")
|
||||
raise ValueError("Vector query must be a non-empty list")
|
||||
if isinstance(sample[0], float):
|
||||
# val is list of list of floats
|
||||
return val
|
||||
|
||||
@@ -278,7 +278,7 @@ def _sanitize_data(
|
||||
|
||||
if metadata:
|
||||
new_metadata = target_schema.metadata or {}
|
||||
new_metadata = new_metadata.update(metadata)
|
||||
new_metadata.update(metadata)
|
||||
target_schema = target_schema.with_metadata(new_metadata)
|
||||
|
||||
_validate_schema(target_schema)
|
||||
@@ -3857,7 +3857,13 @@ class AsyncTable:
|
||||
|
||||
# _santitize_data is an old code path, but we will use it until the
|
||||
# new code path is ready.
|
||||
if on_bad_vectors != "error" or (
|
||||
if mode == "overwrite":
|
||||
# For overwrite, apply the same preprocessing as create_table
|
||||
# so vector columns are inferred as FixedSizeList.
|
||||
data, _ = sanitize_create_table(
|
||||
data, None, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
elif on_bad_vectors != "error" or (
|
||||
schema.metadata is not None and b"embedding_functions" in schema.metadata
|
||||
):
|
||||
data = _sanitize_data(
|
||||
|
||||
@@ -546,3 +546,24 @@ def test_openai_no_retry_on_401(mock_sleep):
|
||||
assert mock_func.call_count == 1
|
||||
# Verify that sleep was never called (no retries)
|
||||
assert mock_sleep.call_count == 0
|
||||
|
||||
|
||||
def test_url_retrieve_downloads_image():
|
||||
"""
|
||||
Embedding functions like open-clip, siglip, and jinaai use url_retrieve()
|
||||
to download images from HTTP URLs. For example, open_clip._to_pil() calls:
|
||||
|
||||
PIL_Image.open(io.BytesIO(url_retrieve(image)))
|
||||
|
||||
Verify that url_retrieve() can download an image and open it as PIL Image,
|
||||
matching the real usage pattern in embedding functions.
|
||||
"""
|
||||
import io
|
||||
|
||||
Image = pytest.importorskip("PIL.Image")
|
||||
from lancedb.embeddings.utils import url_retrieve
|
||||
|
||||
image_url = "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"
|
||||
image_bytes = url_retrieve(image_url)
|
||||
img = Image.open(io.BytesIO(image_bytes))
|
||||
assert img.size[0] > 0 and img.size[1] > 0
|
||||
|
||||
@@ -8,6 +8,7 @@ import shutil
|
||||
import pytest
|
||||
import pyarrow as pa
|
||||
import lancedb
|
||||
from lance_namespace.errors import NamespaceNotEmptyError, TableNotFoundError
|
||||
|
||||
|
||||
class TestNamespaceConnection:
|
||||
@@ -130,7 +131,7 @@ class TestNamespaceConnection:
|
||||
assert len(list(db.table_names(namespace=["test_ns"]))) == 0
|
||||
|
||||
# Should not be able to open dropped table
|
||||
with pytest.raises(RuntimeError):
|
||||
with pytest.raises(TableNotFoundError):
|
||||
db.open_table("table1", namespace=["test_ns"])
|
||||
|
||||
def test_create_table_with_schema(self):
|
||||
@@ -340,7 +341,7 @@ class TestNamespaceConnection:
|
||||
db.create_table("test_table", schema=schema, namespace=["test_namespace"])
|
||||
|
||||
# Try to drop namespace with tables - should fail
|
||||
with pytest.raises(RuntimeError, match="is not empty"):
|
||||
with pytest.raises(NamespaceNotEmptyError):
|
||||
db.drop_namespace(["test_namespace"])
|
||||
|
||||
# Drop table first
|
||||
|
||||
@@ -30,6 +30,7 @@ from lancedb.query import (
|
||||
PhraseQuery,
|
||||
Query,
|
||||
FullTextSearchQuery,
|
||||
ensure_vector_query,
|
||||
)
|
||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
@@ -1501,6 +1502,18 @@ def test_search_empty_table(mem_db):
|
||||
assert results == []
|
||||
|
||||
|
||||
def test_ensure_vector_query_empty_list():
|
||||
"""Regression: ensure_vector_query used to return instead of raise ValueError."""
|
||||
with pytest.raises(ValueError, match="non-empty"):
|
||||
ensure_vector_query([])
|
||||
|
||||
|
||||
def test_ensure_vector_query_nested_empty_list():
|
||||
"""Regression: ensure_vector_query used to return instead of raise ValueError."""
|
||||
with pytest.raises(ValueError, match="non-empty"):
|
||||
ensure_vector_query([[]])
|
||||
|
||||
|
||||
def test_fast_search(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
|
||||
@@ -527,6 +527,36 @@ async def test_add_async(mem_db_async: AsyncConnection):
|
||||
assert await table.count_rows() == 3
|
||||
|
||||
|
||||
def test_add_overwrite_infers_vector_schema(mem_db: DBConnection):
|
||||
"""Overwrite should infer vector columns the same way create_table does.
|
||||
|
||||
Regression test for https://github.com/lancedb/lancedb/issues/3183
|
||||
"""
|
||||
table = mem_db.create_table(
|
||||
"test_overwrite_vec",
|
||||
data=[
|
||||
{"vector": [1.0, 2.0, 3.0, 4.0], "item": "foo"},
|
||||
{"vector": [5.0, 6.0, 7.0, 8.0], "item": "bar"},
|
||||
],
|
||||
)
|
||||
# create_table infers vector as fixed_size_list<float32, 4>
|
||||
original_type = table.schema.field("vector").type
|
||||
assert pa.types.is_fixed_size_list(original_type)
|
||||
|
||||
# overwrite with plain Python lists (PyArrow infers list<double>)
|
||||
table.add(
|
||||
[
|
||||
{"vector": [10.0, 20.0, 30.0, 40.0], "item": "baz"},
|
||||
],
|
||||
mode="overwrite",
|
||||
)
|
||||
# overwrite should infer vector column the same way as create_table
|
||||
new_type = table.schema.field("vector").type
|
||||
assert pa.types.is_fixed_size_list(new_type), (
|
||||
f"Expected fixed_size_list after overwrite, got {new_type}"
|
||||
)
|
||||
|
||||
|
||||
def test_add_progress_callback(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
@@ -2143,3 +2173,33 @@ def test_table_uri(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = db.create_table("my_table", data=[{"x": 0}])
|
||||
assert table.uri == str(tmp_path / "my_table.lance")
|
||||
|
||||
|
||||
def test_sanitize_data_metadata_not_stripped():
|
||||
"""Regression test: dict.update() returns None, so assigning its result
|
||||
would silently replace metadata with None, causing with_metadata(None)
|
||||
to strip all schema metadata from the target schema."""
|
||||
from lancedb.table import _sanitize_data
|
||||
|
||||
schema = pa.schema(
|
||||
[pa.field("x", pa.int64())],
|
||||
metadata={b"existing_key": b"existing_value"},
|
||||
)
|
||||
batch = pa.record_batch([pa.array([1, 2, 3])], schema=schema)
|
||||
|
||||
# Use a different field type so the reader and target schemas differ,
|
||||
# forcing _cast_to_target_schema to rebuild the schema with the
|
||||
# target's metadata (instead of taking the fast-path).
|
||||
target_schema = pa.schema(
|
||||
[pa.field("x", pa.int32())],
|
||||
metadata={b"existing_key": b"existing_value"},
|
||||
)
|
||||
|
||||
reader = pa.RecordBatchReader.from_batches(schema, [batch])
|
||||
metadata = {b"new_key": b"new_value"}
|
||||
result = _sanitize_data(reader, target_schema=target_schema, metadata=metadata)
|
||||
|
||||
result_schema = result.schema
|
||||
assert result_schema.metadata is not None
|
||||
assert result_schema.metadata[b"existing_key"] == b"existing_value"
|
||||
assert result_schema.metadata[b"new_key"] == b"new_value"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.27.2-beta.0"
|
||||
version = "0.27.2-beta.1"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -240,7 +240,7 @@ impl Shuffler {
|
||||
.await?;
|
||||
// Need to read the entire file in a single batch for in-memory shuffling
|
||||
let batch = reader.read_record_batch(0, reader.num_rows()).await?;
|
||||
let mut rng = rng.lock().unwrap();
|
||||
let mut rng = rng.lock().unwrap_or_else(|e| e.into_inner());
|
||||
Self::shuffle_batch(&batch, &mut rng, clump_size)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -66,13 +66,13 @@ impl IoTrackingStore {
|
||||
}
|
||||
|
||||
fn record_read(&self, num_bytes: u64) {
|
||||
let mut stats = self.stats.lock().unwrap();
|
||||
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
|
||||
stats.read_iops += 1;
|
||||
stats.read_bytes += num_bytes;
|
||||
}
|
||||
|
||||
fn record_write(&self, num_bytes: u64) {
|
||||
let mut stats = self.stats.lock().unwrap();
|
||||
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
|
||||
stats.write_iops += 1;
|
||||
stats.write_bytes += num_bytes;
|
||||
}
|
||||
@@ -229,10 +229,63 @@ impl MultipartUpload for IoTrackingMultipartUpload {
|
||||
|
||||
fn put_part(&mut self, payload: PutPayload) -> UploadPart {
|
||||
{
|
||||
let mut stats = self.stats.lock().unwrap();
|
||||
let mut stats = self.stats.lock().unwrap_or_else(|e| e.into_inner());
|
||||
stats.write_iops += 1;
|
||||
stats.write_bytes += payload.content_length() as u64;
|
||||
}
|
||||
self.target.put_part(payload)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Helper: poison a Mutex<IoStats> by panicking while holding the lock.
|
||||
fn poison_stats(stats: &Arc<Mutex<IoStats>>) {
|
||||
let stats_clone = stats.clone();
|
||||
let handle = std::thread::spawn(move || {
|
||||
let _guard = stats_clone.lock().unwrap();
|
||||
panic!("intentional panic to poison stats mutex");
|
||||
});
|
||||
let _ = handle.join();
|
||||
assert!(stats.lock().is_err(), "mutex should be poisoned");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_read_recovers_from_poisoned_lock() {
|
||||
let stats = Arc::new(Mutex::new(IoStats::default()));
|
||||
let store = IoTrackingStore {
|
||||
target: Arc::new(object_store::memory::InMemory::new()),
|
||||
stats: stats.clone(),
|
||||
};
|
||||
|
||||
poison_stats(&stats);
|
||||
|
||||
// record_read should not panic
|
||||
store.record_read(1024);
|
||||
|
||||
// Verify the stats were updated despite poisoning
|
||||
let s = stats.lock().unwrap_or_else(|e| e.into_inner());
|
||||
assert_eq!(s.read_iops, 1);
|
||||
assert_eq!(s.read_bytes, 1024);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_write_recovers_from_poisoned_lock() {
|
||||
let stats = Arc::new(Mutex::new(IoStats::default()));
|
||||
let store = IoTrackingStore {
|
||||
target: Arc::new(object_store::memory::InMemory::new()),
|
||||
stats: stats.clone(),
|
||||
};
|
||||
|
||||
poison_stats(&stats);
|
||||
|
||||
// record_write should not panic
|
||||
store.record_write(2048);
|
||||
|
||||
let s = stats.lock().unwrap_or_else(|e| e.into_inner());
|
||||
assert_eq!(s.write_iops, 1);
|
||||
assert_eq!(s.write_bytes, 2048);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::sync::Arc;
|
||||
use std::{future::Future, time::Duration};
|
||||
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
|
||||
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
|
||||
use arrow_schema::{DataType, SchemaRef};
|
||||
use datafusion_expr::Expr;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
@@ -17,15 +17,17 @@ use lance_datafusion::exec::execute_plan;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
use lance_index::scalar::inverted::SCORE_COL;
|
||||
use lance_index::vector::DIST_COL;
|
||||
use lance_io::stream::RecordBatchStreamAdapter;
|
||||
|
||||
use crate::DistanceType;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::rerankers::rrf::RRFReranker;
|
||||
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
|
||||
use crate::table::BaseTable;
|
||||
use crate::utils::TimeoutStream;
|
||||
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
|
||||
use crate::utils::{MaxBatchLengthStream, TimeoutStream};
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
table::AnyQuery,
|
||||
};
|
||||
|
||||
mod hybrid;
|
||||
|
||||
@@ -604,6 +606,14 @@ impl Default for QueryExecutionOptions {
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryExecutionOptions {
|
||||
fn without_output_batch_length_limit(&self) -> Self {
|
||||
let mut options = self.clone();
|
||||
options.max_batch_length = 0;
|
||||
options
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait for a query object that can be executed to get results
|
||||
///
|
||||
/// There are various kinds of queries but they all return results
|
||||
@@ -1180,6 +1190,8 @@ impl VectorQuery {
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let max_batch_length = options.max_batch_length as usize;
|
||||
let internal_options = options.without_output_batch_length_limit();
|
||||
// clone query and specify we want to include row IDs, which can be needed for reranking
|
||||
let mut fts_query = Query::new(self.parent.clone());
|
||||
fts_query.request = self.request.base.clone();
|
||||
@@ -1189,8 +1201,8 @@ impl VectorQuery {
|
||||
|
||||
vector_query.request.base.full_text_search = None;
|
||||
let (fts_results, vec_results) = try_join!(
|
||||
fts_query.execute_with_options(options.clone()),
|
||||
vector_query.inner_execute_with_options(options)
|
||||
fts_query.execute_with_options(internal_options.clone()),
|
||||
vector_query.inner_execute_with_options(internal_options)
|
||||
)?;
|
||||
|
||||
let (fts_results, vec_results) = try_join!(
|
||||
@@ -1245,9 +1257,7 @@ impl VectorQuery {
|
||||
results = results.drop_column(ROW_ID)?;
|
||||
}
|
||||
|
||||
Ok(SendableRecordBatchStream::from(
|
||||
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
|
||||
))
|
||||
Ok(single_batch_stream(results, max_batch_length))
|
||||
}
|
||||
|
||||
async fn inner_execute_with_options(
|
||||
@@ -1256,6 +1266,7 @@ impl VectorQuery {
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let plan = self.create_plan(options.clone()).await?;
|
||||
let inner = execute_plan(plan, Default::default())?;
|
||||
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
|
||||
let inner = if let Some(timeout) = options.timeout {
|
||||
TimeoutStream::new_boxed(inner, timeout)
|
||||
} else {
|
||||
@@ -1265,6 +1276,25 @@ impl VectorQuery {
|
||||
}
|
||||
}
|
||||
|
||||
fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
|
||||
let schema = batch.schema();
|
||||
if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
|
||||
return Box::pin(SimpleRecordBatchStream::new(
|
||||
stream::iter([Ok(batch)]),
|
||||
schema,
|
||||
));
|
||||
}
|
||||
|
||||
let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
|
||||
let mut offset = 0;
|
||||
while offset < batch.num_rows() {
|
||||
let length = (batch.num_rows() - offset).min(max_batch_length);
|
||||
batches.push(Ok(batch.slice(offset, length)));
|
||||
offset += length;
|
||||
}
|
||||
Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
|
||||
}
|
||||
|
||||
impl ExecutableQuery for VectorQuery {
|
||||
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let query = AnyQuery::VectorQuery(self.request.clone());
|
||||
@@ -1753,6 +1783,50 @@ mod tests {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
|
||||
let dataset_path = tmp_dir.path().join("large_test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("id", DataType::Utf8, false),
|
||||
ArrowField::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
||||
4,
|
||||
),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
|
||||
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
(0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
|
||||
4,
|
||||
);
|
||||
let batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
|
||||
|
||||
let conn = connect(uri).execute().await.unwrap();
|
||||
conn.create_table("my_table", vec![batch])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn assert_stream_batches_at_most(
|
||||
mut results: SendableRecordBatchStream,
|
||||
max_batch_length: usize,
|
||||
) {
|
||||
let mut saw_batch = false;
|
||||
while let Some(batch) = results.next().await {
|
||||
let batch = batch.unwrap();
|
||||
saw_batch = true;
|
||||
assert!(batch.num_rows() <= max_batch_length);
|
||||
}
|
||||
assert!(saw_batch);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_with_options() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
@@ -1772,6 +1846,83 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_query_execute_with_options_respects_max_batch_length() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let table = make_large_vector_table(&tmp_dir, 10_000).await;
|
||||
|
||||
let results = table
|
||||
.query()
|
||||
.nearest_to(vec![0.0, 1.0, 2.0, 3.0])
|
||||
.unwrap()
|
||||
.limit(10_000)
|
||||
.execute_with_options(QueryExecutionOptions {
|
||||
max_batch_length: 100,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_stream_batches_at_most(results, 100).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path();
|
||||
let conn = connect(dataset_path.to_str().unwrap())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let dims = 2;
|
||||
let rows = 512;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("text", DataType::Utf8, false),
|
||||
ArrowField::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
||||
dims,
|
||||
),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
|
||||
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
(0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
|
||||
dims,
|
||||
);
|
||||
let record_batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
|
||||
let table = conn
|
||||
.create_table("my_table", record_batch)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
table
|
||||
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
|
||||
.replace(true)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = table
|
||||
.query()
|
||||
.full_text_search(FullTextSearchQuery::new("match".to_string()))
|
||||
.limit(rows)
|
||||
.nearest_to(&[0.0, 0.0])
|
||||
.unwrap()
|
||||
.execute_with_options(QueryExecutionOptions {
|
||||
max_batch_length: 100,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_stream_batches_at_most(results, 100).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_plan() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
|
||||
@@ -130,7 +130,10 @@ impl<S: HttpSend + 'static> RemoteInsertExec<S> {
|
||||
// TODO: this will be used when we wire this up to Table::add().
|
||||
#[allow(dead_code)]
|
||||
pub fn add_result(&self) -> Option<AddResult> {
|
||||
self.add_result.lock().unwrap().clone()
|
||||
self.add_result
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Stream the input into an HTTP body as an Arrow IPC stream, capturing any
|
||||
|
||||
@@ -204,7 +204,9 @@ impl ExecutionPlan for InsertExec {
|
||||
|
||||
let to_commit = {
|
||||
// Don't hold the lock over an await point.
|
||||
let mut txns = partial_transactions.lock().unwrap();
|
||||
let mut txns = partial_transactions
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
txns.push(transaction);
|
||||
if txns.len() == total_partitions {
|
||||
Some(std::mem::take(&mut *txns))
|
||||
|
||||
@@ -82,7 +82,7 @@ impl DatasetConsistencyWrapper {
|
||||
/// pinned dataset regardless of consistency mode.
|
||||
pub async fn get(&self) -> Result<Arc<Dataset>> {
|
||||
{
|
||||
let state = self.state.lock().unwrap();
|
||||
let state = self.state.lock()?;
|
||||
if state.pinned_version.is_some() {
|
||||
return Ok(state.dataset.clone());
|
||||
}
|
||||
@@ -101,7 +101,7 @@ impl DatasetConsistencyWrapper {
|
||||
}
|
||||
ConsistencyMode::Strong => refresh_latest(self.state.clone()).await,
|
||||
ConsistencyMode::Lazy => {
|
||||
let state = self.state.lock().unwrap();
|
||||
let state = self.state.lock()?;
|
||||
Ok(state.dataset.clone())
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ impl DatasetConsistencyWrapper {
|
||||
/// concurrent [`as_time_travel`](Self::as_time_travel) call), the update
|
||||
/// is silently ignored — the write already committed to storage.
|
||||
pub fn update(&self, dataset: Dataset) {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if state.pinned_version.is_some() {
|
||||
// A concurrent as_time_travel() beat us here. The write succeeded
|
||||
// in storage, but since we're now pinned we don't advance the
|
||||
@@ -139,7 +139,7 @@ impl DatasetConsistencyWrapper {
|
||||
|
||||
/// Check that the dataset is in a mutable mode (Latest).
|
||||
pub fn ensure_mutable(&self) -> Result<()> {
|
||||
let state = self.state.lock().unwrap();
|
||||
let state = self.state.lock()?;
|
||||
if state.pinned_version.is_some() {
|
||||
Err(crate::Error::InvalidInput {
|
||||
message: "table cannot be modified when a specific version is checked out"
|
||||
@@ -152,13 +152,16 @@ impl DatasetConsistencyWrapper {
|
||||
|
||||
/// Returns the version, if in time travel mode, or None otherwise.
|
||||
pub fn time_travel_version(&self) -> Option<u64> {
|
||||
self.state.lock().unwrap().pinned_version
|
||||
self.state
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.pinned_version
|
||||
}
|
||||
|
||||
/// Convert into a wrapper in latest version mode.
|
||||
pub async fn as_latest(&self) -> Result<()> {
|
||||
let dataset = {
|
||||
let state = self.state.lock().unwrap();
|
||||
let state = self.state.lock()?;
|
||||
if state.pinned_version.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -168,7 +171,7 @@ impl DatasetConsistencyWrapper {
|
||||
let latest_version = dataset.latest_version_id().await?;
|
||||
let new_dataset = dataset.checkout_version(latest_version).await?;
|
||||
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let mut state = self.state.lock()?;
|
||||
if state.pinned_version.is_some() {
|
||||
state.dataset = Arc::new(new_dataset);
|
||||
state.pinned_version = None;
|
||||
@@ -184,7 +187,7 @@ impl DatasetConsistencyWrapper {
|
||||
let target_ref = target_version.into();
|
||||
|
||||
let (should_checkout, dataset) = {
|
||||
let state = self.state.lock().unwrap();
|
||||
let state = self.state.lock()?;
|
||||
let should = match state.pinned_version {
|
||||
None => true,
|
||||
Some(version) => match &target_ref {
|
||||
@@ -204,7 +207,7 @@ impl DatasetConsistencyWrapper {
|
||||
let new_dataset = dataset.checkout_version(target_ref).await?;
|
||||
let version_value = new_dataset.version().version;
|
||||
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let mut state = self.state.lock()?;
|
||||
state.dataset = Arc::new(new_dataset);
|
||||
state.pinned_version = Some(version_value);
|
||||
Ok(())
|
||||
@@ -212,7 +215,7 @@ impl DatasetConsistencyWrapper {
|
||||
|
||||
pub async fn reload(&self) -> Result<()> {
|
||||
let (dataset, pinned_version) = {
|
||||
let state = self.state.lock().unwrap();
|
||||
let state = self.state.lock()?;
|
||||
(state.dataset.clone(), state.pinned_version)
|
||||
};
|
||||
|
||||
@@ -230,7 +233,7 @@ impl DatasetConsistencyWrapper {
|
||||
|
||||
let new_dataset = dataset.checkout_version(version).await?;
|
||||
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let mut state = self.state.lock()?;
|
||||
if state.pinned_version == Some(version) {
|
||||
state.dataset = Arc::new(new_dataset);
|
||||
}
|
||||
@@ -242,14 +245,14 @@ impl DatasetConsistencyWrapper {
|
||||
}
|
||||
|
||||
async fn refresh_latest(state: Arc<Mutex<DatasetState>>) -> Result<Arc<Dataset>> {
|
||||
let dataset = { state.lock().unwrap().dataset.clone() };
|
||||
let dataset = { state.lock()?.dataset.clone() };
|
||||
|
||||
let mut ds = (*dataset).clone();
|
||||
ds.checkout_latest().await?;
|
||||
let new_arc = Arc::new(ds);
|
||||
|
||||
{
|
||||
let mut state = state.lock().unwrap();
|
||||
let mut state = state.lock()?;
|
||||
if state.pinned_version.is_none()
|
||||
&& new_arc.manifest().version >= state.dataset.manifest().version
|
||||
{
|
||||
@@ -612,4 +615,108 @@ mod tests {
|
||||
let s = io_stats.incremental_stats();
|
||||
assert_eq!(s.read_iops, 0, "step 5, elapsed={:?}", start.elapsed());
|
||||
}
|
||||
|
||||
/// Helper: poison the mutex inside a DatasetConsistencyWrapper.
|
||||
fn poison_state(wrapper: &DatasetConsistencyWrapper) {
|
||||
let state = wrapper.state.clone();
|
||||
let handle = std::thread::spawn(move || {
|
||||
let _guard = state.lock().unwrap();
|
||||
panic!("intentional panic to poison mutex");
|
||||
});
|
||||
let _ = handle.join(); // join collects the panic
|
||||
assert!(wrapper.state.lock().is_err(), "mutex should be poisoned");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_returns_error_on_poisoned_lock() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let ds = create_test_dataset(uri).await;
|
||||
|
||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
||||
poison_state(&wrapper);
|
||||
|
||||
// get() should return Err, not panic
|
||||
let result = wrapper.get().await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ensure_mutable_returns_error_on_poisoned_lock() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let ds = create_test_dataset(uri).await;
|
||||
|
||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
||||
poison_state(&wrapper);
|
||||
|
||||
let result = wrapper.ensure_mutable();
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_recovers_from_poisoned_lock() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let ds = create_test_dataset(uri).await;
|
||||
let ds_v2 = append_to_dataset(uri).await;
|
||||
|
||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
||||
poison_state(&wrapper);
|
||||
|
||||
// update() returns (), should not panic
|
||||
wrapper.update(ds_v2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_travel_version_recovers_from_poisoned_lock() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let ds = create_test_dataset(uri).await;
|
||||
|
||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
||||
poison_state(&wrapper);
|
||||
|
||||
// Should not panic, returns whatever was in the mutex
|
||||
let _version = wrapper.time_travel_version();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_as_latest_returns_error_on_poisoned_lock() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let ds = create_test_dataset(uri).await;
|
||||
|
||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
||||
poison_state(&wrapper);
|
||||
|
||||
let result = wrapper.as_latest().await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_as_time_travel_returns_error_on_poisoned_lock() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let ds = create_test_dataset(uri).await;
|
||||
|
||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
||||
poison_state(&wrapper);
|
||||
|
||||
let result = wrapper.as_time_travel(1u64).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reload_returns_error_on_poisoned_lock() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let ds = create_test_dataset(uri).await;
|
||||
|
||||
let wrapper = DatasetConsistencyWrapper::new_latest(ds, None);
|
||||
poison_state(&wrapper);
|
||||
|
||||
let result = wrapper.reload().await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
|
||||
use crate::query::{
|
||||
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
|
||||
};
|
||||
use crate::utils::{TimeoutStream, default_vector_column};
|
||||
use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
|
||||
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
|
||||
use arrow::datatypes::{Float32Type, UInt8Type};
|
||||
use arrow_array::Array;
|
||||
@@ -66,6 +66,7 @@ async fn execute_generic_query(
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
let plan = create_plan(table, query, options.clone()).await?;
|
||||
let inner = execute_plan(plan, Default::default())?;
|
||||
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
|
||||
let inner = if let Some(timeout) = options.timeout {
|
||||
TimeoutStream::new_boxed(inner, timeout)
|
||||
} else {
|
||||
@@ -200,7 +201,9 @@ pub async fn create_plan(
|
||||
scanner.with_row_id();
|
||||
}
|
||||
|
||||
scanner.batch_size(options.max_batch_length as usize);
|
||||
if options.max_batch_length > 0 {
|
||||
scanner.batch_size(options.max_batch_length as usize);
|
||||
}
|
||||
|
||||
if query.base.fast_search {
|
||||
scanner.fast_search();
|
||||
|
||||
@@ -130,8 +130,11 @@ impl WriteProgressTracker {
|
||||
pub fn record_batch(&self, rows: usize, bytes: usize) {
|
||||
// Lock order: callback first, then rows_and_bytes. This is the only
|
||||
// order used anywhere, so deadlocks cannot occur.
|
||||
let mut cb = self.callback.lock().unwrap();
|
||||
let mut guard = self.rows_and_bytes.lock().unwrap();
|
||||
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let mut guard = self
|
||||
.rows_and_bytes
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
guard.0 += rows;
|
||||
guard.1 += bytes;
|
||||
let progress = self.snapshot(guard.0, guard.1, false);
|
||||
@@ -151,8 +154,11 @@ impl WriteProgressTracker {
|
||||
/// `total_rows` is always `Some` on the final callback: it uses the known
|
||||
/// total if available, or falls back to the number of rows actually written.
|
||||
pub fn finish(&self) {
|
||||
let mut cb = self.callback.lock().unwrap();
|
||||
let guard = self.rows_and_bytes.lock().unwrap();
|
||||
let mut cb = self.callback.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let guard = self
|
||||
.rows_and_bytes
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
let mut snap = self.snapshot(guard.0, guard.1, true);
|
||||
snap.total_rows = Some(self.total_rows.unwrap_or(guard.0));
|
||||
drop(guard);
|
||||
@@ -376,4 +382,50 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_batch_recovers_from_poisoned_callback_lock() {
|
||||
use super::{ProgressCallback, WriteProgressTracker};
|
||||
use std::sync::Mutex;
|
||||
|
||||
let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {}));
|
||||
|
||||
// Poison the callback mutex
|
||||
let cb_clone = callback.clone();
|
||||
let handle = std::thread::spawn(move || {
|
||||
let _guard = cb_clone.lock().unwrap();
|
||||
panic!("intentional panic to poison callback mutex");
|
||||
});
|
||||
let _ = handle.join();
|
||||
assert!(
|
||||
callback.lock().is_err(),
|
||||
"callback mutex should be poisoned"
|
||||
);
|
||||
|
||||
let tracker = WriteProgressTracker::new(callback, Some(100));
|
||||
|
||||
// record_batch should not panic
|
||||
tracker.record_batch(10, 1024);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_finish_recovers_from_poisoned_callback_lock() {
|
||||
use super::{ProgressCallback, WriteProgressTracker};
|
||||
use std::sync::Mutex;
|
||||
|
||||
let callback: ProgressCallback = Arc::new(Mutex::new(|_: &super::WriteProgress| {}));
|
||||
|
||||
// Poison the callback mutex
|
||||
let cb_clone = callback.clone();
|
||||
let handle = std::thread::spawn(move || {
|
||||
let _guard = cb_clone.lock().unwrap();
|
||||
panic!("intentional panic to poison callback mutex");
|
||||
});
|
||||
let _ = handle.join();
|
||||
|
||||
let tracker = WriteProgressTracker::new(callback, Some(100));
|
||||
|
||||
// finish should not panic
|
||||
tracker.finish();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ where
|
||||
/// This is a cheap synchronous check useful as a fast path before
|
||||
/// constructing a fetch closure for [`get()`](Self::get).
|
||||
pub fn try_get(&self) -> Option<V> {
|
||||
let cache = self.inner.lock().unwrap();
|
||||
let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
||||
cache.state.fresh_value(self.ttl, self.refresh_window)
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ where
|
||||
{
|
||||
// Fast path: check if cache is fresh
|
||||
{
|
||||
let cache = self.inner.lock().unwrap();
|
||||
let cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(value) = cache.state.fresh_value(self.ttl, self.refresh_window) {
|
||||
return Ok(value);
|
||||
}
|
||||
@@ -147,7 +147,7 @@ where
|
||||
// Slow path
|
||||
let mut fetch = Some(fetch);
|
||||
let action = {
|
||||
let mut cache = self.inner.lock().unwrap();
|
||||
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
||||
self.determine_action(&mut cache, &mut fetch)
|
||||
};
|
||||
|
||||
@@ -161,7 +161,7 @@ where
|
||||
///
|
||||
/// This avoids a blocking fetch on the first [`get()`](Self::get) call.
|
||||
pub fn seed(&self, value: V) {
|
||||
let mut cache = self.inner.lock().unwrap();
|
||||
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
||||
cache.state = State::Current(value, clock::now());
|
||||
}
|
||||
|
||||
@@ -170,7 +170,7 @@ where
|
||||
/// Any in-flight background fetch from before this call will not update the
|
||||
/// cache (the generation counter prevents stale writes).
|
||||
pub fn invalidate(&self) {
|
||||
let mut cache = self.inner.lock().unwrap();
|
||||
let mut cache = self.inner.lock().unwrap_or_else(|e| e.into_inner());
|
||||
cache.state = State::Empty;
|
||||
cache.generation += 1;
|
||||
}
|
||||
@@ -267,7 +267,7 @@ where
|
||||
let fut_for_spawn = shared.clone();
|
||||
tokio::spawn(async move {
|
||||
let result = fut_for_spawn.await;
|
||||
let mut cache = inner.lock().unwrap();
|
||||
let mut cache = inner.lock().unwrap_or_else(|e| e.into_inner());
|
||||
// Only update if no invalidation has happened since we started
|
||||
if cache.generation != generation {
|
||||
return;
|
||||
@@ -590,4 +590,67 @@ mod tests {
|
||||
let v = cache.get(ok_fetcher(count.clone(), "fresh")).await.unwrap();
|
||||
assert_eq!(v, "fresh");
|
||||
}
|
||||
|
||||
/// Helper: poison the inner mutex of a BackgroundCache.
|
||||
fn poison_cache(cache: &BackgroundCache<String, TestError>) {
|
||||
let inner = cache.inner.clone();
|
||||
let handle = std::thread::spawn(move || {
|
||||
let _guard = inner.lock().unwrap();
|
||||
panic!("intentional panic to poison mutex");
|
||||
});
|
||||
let _ = handle.join();
|
||||
assert!(cache.inner.lock().is_err(), "mutex should be poisoned");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_try_get_recovers_from_poisoned_lock() {
|
||||
let cache = new_cache();
|
||||
let count = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
// Seed a value first
|
||||
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
|
||||
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap(); // peek
|
||||
|
||||
poison_cache(&cache);
|
||||
|
||||
// try_get() should not panic — it recovers via unwrap_or_else
|
||||
let result = cache.try_get();
|
||||
// The value may or may not be fresh depending on timing, but it must not panic
|
||||
let _ = result;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_recovers_from_poisoned_lock() {
|
||||
let cache = new_cache();
|
||||
let count = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
poison_cache(&cache);
|
||||
|
||||
// get() should not panic — it recovers and can still fetch
|
||||
let result = cache.get(ok_fetcher(count.clone(), "recovered")).await;
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), "recovered");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_seed_recovers_from_poisoned_lock() {
|
||||
let cache = new_cache();
|
||||
poison_cache(&cache);
|
||||
|
||||
// seed() should not panic
|
||||
cache.seed("seeded".to_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalidate_recovers_from_poisoned_lock() {
|
||||
let cache = new_cache();
|
||||
let count = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
cache.get(ok_fetcher(count.clone(), "hello")).await.unwrap();
|
||||
|
||||
poison_cache(&cache);
|
||||
|
||||
// invalidate() should not panic
|
||||
cache.invalidate();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,6 +335,85 @@ impl Stream for TimeoutStream {
|
||||
}
|
||||
}
|
||||
|
||||
/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
|
||||
pub struct MaxBatchLengthStream {
|
||||
inner: SendableRecordBatchStream,
|
||||
max_batch_length: Option<usize>,
|
||||
buffered_batch: Option<RecordBatch>,
|
||||
buffered_offset: usize,
|
||||
}
|
||||
|
||||
impl MaxBatchLengthStream {
|
||||
pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
|
||||
buffered_batch: None,
|
||||
buffered_offset: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_boxed(
|
||||
inner: SendableRecordBatchStream,
|
||||
max_batch_length: usize,
|
||||
) -> SendableRecordBatchStream {
|
||||
if max_batch_length == 0 {
|
||||
inner
|
||||
} else {
|
||||
Box::pin(Self::new(inner, max_batch_length))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RecordBatchStream for MaxBatchLengthStream {
|
||||
fn schema(&self) -> SchemaRef {
|
||||
self.inner.schema()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for MaxBatchLengthStream {
|
||||
type Item = DataFusionResult<RecordBatch>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
let Some(max_batch_length) = self.max_batch_length else {
|
||||
return Pin::new(&mut self.inner).poll_next(cx);
|
||||
};
|
||||
|
||||
if let Some(batch) = self.buffered_batch.clone() {
|
||||
if self.buffered_offset < batch.num_rows() {
|
||||
let remaining = batch.num_rows() - self.buffered_offset;
|
||||
let length = remaining.min(max_batch_length);
|
||||
let sliced = batch.slice(self.buffered_offset, length);
|
||||
self.buffered_offset += length;
|
||||
if self.buffered_offset >= batch.num_rows() {
|
||||
self.buffered_batch = None;
|
||||
self.buffered_offset = 0;
|
||||
}
|
||||
return std::task::Poll::Ready(Some(Ok(sliced)));
|
||||
}
|
||||
|
||||
self.buffered_batch = None;
|
||||
self.buffered_offset = 0;
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.inner).poll_next(cx) {
|
||||
std::task::Poll::Ready(Some(Ok(batch))) => {
|
||||
if batch.num_rows() <= max_batch_length {
|
||||
return std::task::Poll::Ready(Some(Ok(batch)));
|
||||
}
|
||||
self.buffered_batch = Some(batch);
|
||||
self.buffered_offset = 0;
|
||||
}
|
||||
other => return other,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow_array::Int32Array;
|
||||
@@ -470,7 +549,7 @@ mod tests {
|
||||
assert_eq!(string_to_datatype(string), Some(expected));
|
||||
}
|
||||
|
||||
fn sample_batch() -> RecordBatch {
|
||||
fn sample_batch(num_rows: i32) -> RecordBatch {
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"col1",
|
||||
DataType::Int32,
|
||||
@@ -478,14 +557,14 @@ mod tests {
|
||||
)]));
|
||||
RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream() {
|
||||
let batch = sample_batch();
|
||||
let batch = sample_batch(3);
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
@@ -515,7 +594,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream_zero_duration() {
|
||||
let batch = sample_batch();
|
||||
let batch = sample_batch(3);
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
@@ -534,7 +613,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream_completes_normally() {
|
||||
let batch = sample_batch();
|
||||
let batch = sample_batch(3);
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
@@ -552,4 +631,35 @@ mod tests {
|
||||
// Stream should be empty now
|
||||
assert!(timeout_stream.next().await.is_none());
|
||||
}
|
||||
|
||||
async fn collect_batch_sizes(
|
||||
stream: SendableRecordBatchStream,
|
||||
max_batch_length: usize,
|
||||
) -> Vec<usize> {
|
||||
let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
|
||||
sliced_stream
|
||||
.by_ref()
|
||||
.map(|batch| batch.unwrap().num_rows())
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_max_batch_length_stream_behaviors() {
|
||||
let schema = sample_batch(7).schema();
|
||||
let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
|
||||
|
||||
let sendable_stream: SendableRecordBatchStream =
|
||||
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
|
||||
assert_eq!(
|
||||
collect_batch_sizes(sendable_stream, 3).await,
|
||||
vec![2, 3, 3, 1]
|
||||
);
|
||||
|
||||
let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
|
||||
schema,
|
||||
stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
|
||||
));
|
||||
assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user