mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-01 05:20:40 +00:00
Compare commits
9 Commits
codex/upda
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ba46135a5 | ||
|
|
f903d07887 | ||
|
|
5d550124bd | ||
|
|
c57cb310a2 | ||
|
|
97754f5123 | ||
|
|
7b1c063848 | ||
|
|
c7f189f27b | ||
|
|
a0a2942ad5 | ||
|
|
e3d53dd185 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.27.2-beta.1"
|
||||
current_version = "0.27.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
123
Cargo.lock
generated
123
Cargo.lock
generated
@@ -108,7 +108,7 @@ version = "1.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
|
||||
dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -119,7 +119,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"once_cell_polyfill",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2682,7 +2682,7 @@ dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2876,7 +2876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3072,8 +3072,9 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "fsst"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2195cc7f87e84bd695586137de99605e7e9579b26ec5e01b82960ddb4d0922f2"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"rand 0.9.2",
|
||||
@@ -3736,7 +3737,7 @@ dependencies = [
|
||||
"libc",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2 0.5.10",
|
||||
"socket2 0.6.3",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
@@ -4037,7 +4038,7 @@ dependencies = [
|
||||
"portable-atomic",
|
||||
"portable-atomic-util",
|
||||
"serde_core",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4123,8 +4124,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "efe6c3ddd79cdfd2b7e1c23cafae52806906bc40fbd97de9e8cf2f8c7a75fc04"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4190,8 +4192,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-arrow"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d9f5d95bdda2a2b790f1fb8028b5b6dcf661abeb3133a8bca0f3d24b054af87"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4211,8 +4214,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-bitpacking"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f827d6ab9f8f337a9509d5ad66a12f3314db8713868260521c344ef6135eb4e4"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"paste",
|
||||
@@ -4221,8 +4225,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-core"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f1e25df6a79bf72ee6bcde0851f19b1cd36c5848c1b7db83340882d3c9fdecb"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4259,8 +4264,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datafusion"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93146de8ae720cb90edef81c2f2d0a1b065fc2f23ecff2419546f389b0fa70a4"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4290,8 +4296,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datagen"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ccec8ce4d8e0a87a99c431dab2364398029f2ffb649c1a693c60c79e05ed30dd"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4309,8 +4316,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-encoding"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c1aec0bbbac6bce829bc10f1ba066258126100596c375fb71908ecf11c2c2a5"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4347,8 +4355,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-file"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14a8c548804f5b17486dc2d3282356ed1957095a852780283bc401fdd69e9075"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4380,8 +4389,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-index"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2da212f0090ea59f79ac3686660f596520c167fe1cb5f408900cf71d215f0e03"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4445,8 +4455,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-io"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41d958eb4b56f03bbe0f5f85eb2b4e9657882812297b6f711f201ffc995f259f"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4487,8 +4498,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-linalg"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0285b70da35def7ed95e150fae1d5308089554e1290470403ed3c50cb235bc5e"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4504,8 +4516,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f78e2a828b654e062a495462c6e3eb4fcf0e7e907d761b8f217fc09ccd3ceac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -4518,8 +4531,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace-impls"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2392314f3da38f00d166295e44244208a65ccfc256e274fa8631849fc3f4d94"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-ipc",
|
||||
@@ -4563,8 +4577,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-table"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3df9c4adca3eb2074b3850432a9fb34248a3d90c3d6427d158b13ff9355664ee"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4603,8 +4618,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-testing"
|
||||
version = "4.0.0-rc.3"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v4.0.0-rc.3#b27462427380a2e942019fb28776695d9c8a67be"
|
||||
version = "4.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ed7119bdd6983718387b4ac44af873a165262ca94f181b104cd6f97912eb3bf"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -4615,7 +4631,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.27.2-beta.1"
|
||||
version = "0.27.2"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
@@ -4697,7 +4713,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-nodejs"
|
||||
version = "0.27.2-beta.1"
|
||||
version = "0.27.2"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4719,7 +4735,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-python"
|
||||
version = "0.30.2-beta.1"
|
||||
version = "0.30.2"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -5307,7 +5323,7 @@ version = "0.50.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||
dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6286,7 +6302,7 @@ version = "0.14.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"itertools 0.14.0",
|
||||
"log",
|
||||
"multimap",
|
||||
@@ -6473,7 +6489,7 @@ dependencies = [
|
||||
"quinn-udp",
|
||||
"rustc-hash",
|
||||
"rustls 0.23.37",
|
||||
"socket2 0.5.10",
|
||||
"socket2 0.6.3",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -6510,9 +6526,9 @@ dependencies = [
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"socket2 0.5.10",
|
||||
"socket2 0.6.3",
|
||||
"tracing",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7041,7 +7057,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.15",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7054,7 +7070,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.12.1",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7574,7 +7590,7 @@ version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -7586,7 +7602,7 @@ version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -7609,7 +7625,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7713,7 +7729,6 @@ dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"psm",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
@@ -8074,7 +8089,7 @@ dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"once_cell",
|
||||
"rustix 1.1.4",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8879,7 +8894,7 @@ version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -15,20 +15,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=4.0.0-rc.3", default-features = false, "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=4.0.0-rc.3", "tag" = "v4.0.0-rc.3", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance = { version = "=4.0.0", default-features = false }
|
||||
lance-core = { version = "=4.0.0" }
|
||||
lance-datagen = { version = "=4.0.0" }
|
||||
lance-file = { version = "=4.0.0" }
|
||||
lance-io = { version = "=4.0.0", default-features = false }
|
||||
lance-index = { version = "=4.0.0" }
|
||||
lance-linalg = { version = "=4.0.0" }
|
||||
lance-namespace = { version = "=4.0.0" }
|
||||
lance-namespace-impls = { version = "=4.0.0", default-features = false }
|
||||
lance-table = { version = "=4.0.0" }
|
||||
lance-testing = { version = "=4.0.0" }
|
||||
lance-datafusion = { version = "=4.0.0" }
|
||||
lance-encoding = { version = "=4.0.0" }
|
||||
lance-arrow = { version = "=4.0.0" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "57.2", optional = false }
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.27.2-beta.1</version>
|
||||
<version>0.27.2</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -36,6 +36,20 @@ is also an [asynchronous API client](#connections-asynchronous).
|
||||
|
||||
::: lancedb.table.Tags
|
||||
|
||||
## Expressions
|
||||
|
||||
Type-safe expression builder for filters and projections. Use these instead
|
||||
of raw SQL strings with [where][lancedb.query.LanceQueryBuilder.where] and
|
||||
[select][lancedb.query.LanceQueryBuilder.select].
|
||||
|
||||
::: lancedb.expr.Expr
|
||||
|
||||
::: lancedb.expr.col
|
||||
|
||||
::: lancedb.expr.lit
|
||||
|
||||
::: lancedb.expr.func
|
||||
|
||||
## Querying (Synchronous)
|
||||
|
||||
::: lancedb.query.Query
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.27.2-beta.1</version>
|
||||
<version>0.27.2-final.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.27.2-beta.1</version>
|
||||
<version>0.27.2-final.0</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.27.2-beta.1"
|
||||
version = "0.27.2"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.27.2-beta.1",
|
||||
"version": "0.27.2",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.30.2-beta.1"
|
||||
current_version = "0.30.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
2
python/.gitignore
vendored
2
python/.gitignore
vendored
@@ -1,3 +1,5 @@
|
||||
# Test data created by some example tests
|
||||
data/
|
||||
_lancedb.pyd
|
||||
# macOS debug symbols bundle generated during build
|
||||
*.dSYM/
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.30.2-beta.1"
|
||||
version = "0.30.2"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -18,6 +18,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .io import StorageOptionsProvider
|
||||
from .remote import ClientConfig
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .expr import Expr, col, lit, func
|
||||
from .schema import vector
|
||||
from .table import AsyncTable, Table
|
||||
from ._lancedb import Session
|
||||
@@ -271,6 +272,10 @@ __all__ = [
|
||||
"AsyncConnection",
|
||||
"AsyncLanceNamespaceDBConnection",
|
||||
"AsyncTable",
|
||||
"col",
|
||||
"Expr",
|
||||
"func",
|
||||
"lit",
|
||||
"URI",
|
||||
"sanitize_uri",
|
||||
"vector",
|
||||
|
||||
@@ -27,6 +27,32 @@ from .remote import ClientConfig
|
||||
IvfHnswPq: type[HnswPq] = HnswPq
|
||||
IvfHnswSq: type[HnswSq] = HnswSq
|
||||
|
||||
class PyExpr:
    """A type-safe DataFusion expression node (Rust-side handle)."""

    # Comparison builders: combine this node with ``other`` into a new node.
    def eq(self, other: "PyExpr") -> "PyExpr": ...
    def ne(self, other: "PyExpr") -> "PyExpr": ...
    def lt(self, other: "PyExpr") -> "PyExpr": ...
    def lte(self, other: "PyExpr") -> "PyExpr": ...
    def gt(self, other: "PyExpr") -> "PyExpr": ...
    def gte(self, other: "PyExpr") -> "PyExpr": ...
    # Logical connectives (trailing underscore avoids the Python keywords).
    def and_(self, other: "PyExpr") -> "PyExpr": ...
    def or_(self, other: "PyExpr") -> "PyExpr": ...
    def not_(self) -> "PyExpr": ...
    # Arithmetic builders.
    def add(self, other: "PyExpr") -> "PyExpr": ...
    def sub(self, other: "PyExpr") -> "PyExpr": ...
    def mul(self, other: "PyExpr") -> "PyExpr": ...
    def div(self, other: "PyExpr") -> "PyExpr": ...
    # String helpers.
    def lower(self) -> "PyExpr": ...
    def upper(self) -> "PyExpr": ...
    def contains(self, substr: "PyExpr") -> "PyExpr": ...
    # Cast to the given PyArrow data type.
    def cast(self, data_type: pa.DataType) -> "PyExpr": ...
    # NOTE(review): presumably renders the expression as SQL text -- confirm
    # against the Rust implementation.
    def to_sql(self) -> str: ...
|
||||
|
||||
# Module-level constructors that return PyExpr handles.
# NOTE(review): presumably these back lancedb.expr.col / lit / func -- confirm
# against python/python/lancedb/expr.py.
def expr_col(name: str) -> PyExpr: ...
def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ...
def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ...
|
||||
|
||||
class Session:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -225,7 +251,9 @@ class RecordBatchStream:
|
||||
|
||||
class Query:
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: Tuple[str, str]): ...
|
||||
def where_expr(self, expr: PyExpr): ...
|
||||
def select(self, columns: List[Tuple[str, str]]): ...
|
||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
||||
def select_columns(self, columns: List[str]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def offset(self, offset: int): ...
|
||||
@@ -251,7 +279,9 @@ class TakeQuery:
|
||||
|
||||
class FTSQuery:
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
def where_expr(self, expr: PyExpr): ...
|
||||
def select(self, columns: List[Tuple[str, str]]): ...
|
||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def offset(self, offset: int): ...
|
||||
def fast_search(self): ...
|
||||
@@ -270,7 +300,9 @@ class VectorQuery:
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(self) -> RecordBatchStream: ...
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
def where_expr(self, expr: PyExpr): ...
|
||||
def select(self, columns: List[Tuple[str, str]]): ...
|
||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
||||
def select_with_projection(self, columns: Tuple[str, str]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def offset(self, offset: int): ...
|
||||
@@ -287,7 +319,9 @@ class VectorQuery:
|
||||
|
||||
class HybridQuery:
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
def where_expr(self, expr: PyExpr): ...
|
||||
def select(self, columns: List[Tuple[str, str]]): ...
|
||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def offset(self, offset: int): ...
|
||||
def fast_search(self): ...
|
||||
|
||||
298
python/python/lancedb/expr.py
Normal file
298
python/python/lancedb/expr.py
Normal file
@@ -0,0 +1,298 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Type-safe expression builder for filters and projections.
|
||||
|
||||
Instead of writing raw SQL strings you can build expressions with Python
|
||||
operators::
|
||||
|
||||
from lancedb.expr import col, lit
|
||||
|
||||
# filter: age > 18 AND status = 'active'
|
||||
filt = (col("age") > lit(18)) & (col("status") == lit("active"))
|
||||
|
||||
# projection: compute a derived column
|
||||
proj = {"score": col("raw_score") * lit(1.5)}
|
||||
|
||||
table.search().where(filt).select(proj).to_list()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb._lancedb import PyExpr, expr_col, expr_lit, expr_func
|
||||
|
||||
__all__ = ["Expr", "col", "lit", "func"]
|
||||
|
||||
# Lookup table from type-name strings to PyArrow data types. Several keys are
# aliases that resolve to the same type (e.g. "bool"/"boolean",
# "float"/"float32", "double"/"float64", "string"/"utf8"/"str").
_STR_TO_PA_TYPE: dict = {
    "bool": pa.bool_(),
    "boolean": pa.bool_(),
    "int8": pa.int8(),
    "int16": pa.int16(),
    "int32": pa.int32(),
    "int64": pa.int64(),
    "uint8": pa.uint8(),
    "uint16": pa.uint16(),
    "uint32": pa.uint32(),
    "uint64": pa.uint64(),
    "float16": pa.float16(),
    "float32": pa.float32(),
    "float": pa.float32(),
    "float64": pa.float64(),
    "double": pa.float64(),
    "string": pa.string(),
    "utf8": pa.string(),
    "str": pa.string(),
    "large_string": pa.large_utf8(),
    "large_utf8": pa.large_utf8(),
    "date32": pa.date32(),
    "date": pa.date32(),
    "date64": pa.date64(),
}
|
||||
|
||||
|
||||
def _coerce(value: "ExprLike") -> "Expr":
    """Normalize *value* to an :class:`Expr`.

    An existing :class:`Expr` is passed through untouched; any other value is
    handed to :func:`lit` to be wrapped.
    """
    return value if isinstance(value, Expr) else lit(value)
|
||||
|
||||
|
||||
# Type alias used in annotations.
|
||||
ExprLike = Union["Expr", bool, int, float, str]
|
||||
|
||||
|
||||
class Expr:
|
||||
"""A type-safe expression node.
|
||||
|
||||
Construct instances with :func:`col` and :func:`lit`, then combine them
|
||||
using Python operators or the named methods below.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from lancedb.expr import col, lit
|
||||
>>> filt = (col("age") > lit(18)) & (col("name").lower() == lit("alice"))
|
||||
>>> proj = {"double": col("x") * lit(2)}
|
||||
"""
|
||||
|
||||
# Make Expr unhashable so that == returns an Expr rather than being used
|
||||
# for dict keys / set membership.
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
|
||||
def __init__(self, inner: PyExpr) -> None:
    """Wrap *inner*, the ``PyExpr`` handle that every method delegates to."""
    self._inner = inner
|
||||
|
||||
# ── comparisons ──────────────────────────────────────────────────────────
|
||||
|
||||
def __eq__(self, other: ExprLike) -> "Expr":  # type: ignore[override]
    """Build an equality predicate (``col("x") == 1``).

    *other* is coerced to an :class:`Expr` first, so plain Python values
    are accepted.
    """
    rhs = _coerce(other)._inner
    return Expr(self._inner.eq(rhs))
|
||||
|
||||
def __ne__(self, other: ExprLike) -> "Expr":  # type: ignore[override]
    """Build an inequality predicate (``col("x") != 1``).

    *other* is coerced to an :class:`Expr` first, so plain Python values
    are accepted.
    """
    rhs = _coerce(other)._inner
    return Expr(self._inner.ne(rhs))
|
||||
|
||||
def __lt__(self, other: ExprLike) -> "Expr":
    """Build a strict less-than predicate (``col("x") < 1``)."""
    rhs = _coerce(other)._inner
    return Expr(self._inner.lt(rhs))
|
||||
|
||||
def __le__(self, other: ExprLike) -> "Expr":
    """Build a less-than-or-equal predicate (``col("x") <= 1``)."""
    rhs = _coerce(other)._inner
    return Expr(self._inner.lte(rhs))
|
||||
|
||||
def __gt__(self, other: ExprLike) -> "Expr":
    """Build a strict greater-than predicate (``col("x") > 1``)."""
    rhs = _coerce(other)._inner
    return Expr(self._inner.gt(rhs))
|
||||
|
||||
def __ge__(self, other: ExprLike) -> "Expr":
    """Build a greater-than-or-equal predicate (``col("x") >= 1``)."""
    rhs = _coerce(other)._inner
    return Expr(self._inner.gte(rhs))
|
||||
|
||||
# ── logical ──────────────────────────────────────────────────────────────
|
||||
|
||||
def __and__(self, other: ExprLike) -> "Expr":
    """Logical AND (``expr_a & expr_b``).

    *other* goes through ``_coerce`` like every other binary operator, so a
    plain Python value is accepted as well as an :class:`Expr`; the
    annotation reflects that.
    """
    rhs = _coerce(other)._inner
    return Expr(self._inner.and_(rhs))
|
||||
|
||||
def __or__(self, other: ExprLike) -> "Expr":
    """Logical OR (``expr_a | expr_b``).

    *other* goes through ``_coerce`` like every other binary operator, so a
    plain Python value is accepted as well as an :class:`Expr`; the
    annotation reflects that.
    """
    rhs = _coerce(other)._inner
    return Expr(self._inner.or_(rhs))
|
||||
|
||||
def __invert__(self) -> "Expr":
    """Negate this expression logically (``~expr``)."""
    negated = self._inner.not_()
    return Expr(negated)
|
||||
|
||||
# ── arithmetic ───────────────────────────────────────────────────────────
|
||||
|
||||
def __add__(self, other: ExprLike) -> "Expr":
    """Addition with this expression on the left (``col("x") + 1``)."""
    rhs = _coerce(other)._inner
    return Expr(self._inner.add(rhs))
|
||||
|
||||
def __radd__(self, other: ExprLike) -> "Expr":
    """Addition with this expression on the right (``1 + col("x")``)."""
    lhs = _coerce(other)._inner
    return Expr(lhs.add(self._inner))
|
||||
|
||||
def __sub__(self, other: ExprLike) -> "Expr":
    """Subtraction with this expression on the left (``col("x") - 1``)."""
    rhs = _coerce(other)._inner
    return Expr(self._inner.sub(rhs))
|
||||
|
||||
def __rsub__(self, other: ExprLike) -> "Expr":
    """Subtraction with this expression on the right (``1 - col("x")``)."""
    lhs = _coerce(other)._inner
    return Expr(lhs.sub(self._inner))
|
||||
|
||||
def __mul__(self, other: ExprLike) -> "Expr":
    """Multiplication with this expression on the left (``col("x") * 2``)."""
    rhs = _coerce(other)._inner
    return Expr(self._inner.mul(rhs))
|
||||
|
||||
def __rmul__(self, other: ExprLike) -> "Expr":
    """Multiplication with this expression on the right (``2 * col("x")``)."""
    lhs = _coerce(other)._inner
    return Expr(lhs.mul(self._inner))
|
||||
|
||||
def __truediv__(self, other: ExprLike) -> "Expr":
    """Division with this expression as the dividend (``col("x") / 2``)."""
    rhs = _coerce(other)._inner
    return Expr(self._inner.div(rhs))
|
||||
|
||||
def __rtruediv__(self, other: ExprLike) -> "Expr":
    """Division with this expression as the divisor (``1 / col("x")``)."""
    lhs = _coerce(other)._inner
    return Expr(lhs.div(self._inner))
|
||||
|
||||
# ── string methods ───────────────────────────────────────────────────────
|
||||
|
||||
def lower(self) -> "Expr":
    """Return an expression that lowercases string column values."""
    lowered = self._inner.lower()
    return Expr(lowered)
|
||||
|
||||
def upper(self) -> "Expr":
    """Return an expression that uppercases string column values."""
    uppered = self._inner.upper()
    return Expr(uppered)
|
||||
|
||||
def contains(self, substr: "ExprLike") -> "Expr":
    """Return an expression that is true where the string contains *substr*.

    *substr* is passed through ``_coerce`` before its inner expression is used.
    """
    needle = _coerce(substr)._inner
    return Expr(self._inner.contains(needle))
|
||||
|
||||
# ── type cast ────────────────────────────────────────────────────────────
|
||||
|
||||
def cast(self, data_type: Union[str, "pa.DataType"]) -> "Expr":
    """Cast values to *data_type*.

    Parameters
    ----------
    data_type:
        A PyArrow ``DataType`` (e.g. ``pa.int32()``) or one of the type
        name strings: ``"bool"``, ``"int8"``, ``"int16"``, ``"int32"``,
        ``"int64"``, ``"uint8"``–``"uint64"``, ``"float32"``,
        ``"float64"``, ``"string"``, ``"date32"``, ``"date64"``.

    Returns
    -------
    Expr
        A new expression wrapping the cast of this expression.

    Raises
    ------
    ValueError
        If *data_type* is a string that is not a key of
        ``_STR_TO_PA_TYPE``.
    """
    if isinstance(data_type, str):
        try:
            data_type = _STR_TO_PA_TYPE[data_type]
        except KeyError:
            # The failed dictionary lookup is an implementation detail, so
            # suppress it as the implicit exception context: callers see a
            # single ValueError naming the supported type strings.
            raise ValueError(
                f"unsupported data type: '{data_type}'. Supported: "
                f"{', '.join(_STR_TO_PA_TYPE)}"
            ) from None
    return Expr(self._inner.cast(data_type))
|
||||
|
||||
# ── named comparison helpers (alternative to operators) ──────────────────
|
||||
|
||||
def eq(self, other: ExprLike) -> "Expr":
    """Equal to; named alternative that delegates to ``__eq__``."""
    compare = self.__eq__
    return compare(other)
|
||||
|
||||
def ne(self, other: ExprLike) -> "Expr":
    """Not equal to; named alternative that delegates to ``__ne__``."""
    compare = self.__ne__
    return compare(other)
|
||||
|
||||
def lt(self, other: ExprLike) -> "Expr":
    """Less than; named alternative that delegates to ``__lt__``."""
    compare = self.__lt__
    return compare(other)
|
||||
|
||||
def lte(self, other: ExprLike) -> "Expr":
    """Less than or equal to; named alternative that delegates to ``__le__``."""
    compare = self.__le__
    return compare(other)
|
||||
|
||||
def gt(self, other: ExprLike) -> "Expr":
    """Greater than; named alternative that delegates to ``__gt__``."""
    compare = self.__gt__
    return compare(other)
|
||||
|
||||
def gte(self, other: ExprLike) -> "Expr":
    """Greater than or equal to; named alternative that delegates to ``__ge__``."""
    compare = self.__ge__
    return compare(other)
|
||||
|
||||
def and_(self, other: "Expr") -> "Expr":
    """Logical AND; named alternative that delegates to ``__and__``."""
    combine = self.__and__
    return combine(other)
|
||||
|
||||
def or_(self, other: "Expr") -> "Expr":
    """Logical OR; named alternative that delegates to ``__or__``."""
    combine = self.__or__
    return combine(other)
|
||||
|
||||
# ── utilities ────────────────────────────────────────────────────────────
|
||||
|
||||
def to_sql(self) -> str:
    """Return the SQL string produced by the inner expression (useful for debugging)."""
    sql = self._inner.to_sql()
    return sql
|
||||
|
||||
def __repr__(self) -> str:
    """Return ``Expr(<sql>)`` where ``<sql>`` is the inner expression's SQL string."""
    sql = self._inner.to_sql()
    return f"Expr({sql})"
|
||||
|
||||
|
||||
# ── free functions ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def col(name: str) -> Expr:
    """Build an expression that references a table column by name.

    Parameters
    ----------
    name:
        The column name.

    Returns
    -------
    Expr
        Wrapper around the column reference returned by ``expr_col``.

    Examples
    --------
    >>> from lancedb.expr import col, lit
    >>> col("age") > lit(18)
    Expr((age > 18))
    """
    column_ref = expr_col(name)
    return Expr(column_ref)
|
||||
|
||||
|
||||
def lit(value: Union[bool, int, float, str]) -> Expr:
    """Build a literal (constant) value expression.

    Parameters
    ----------
    value:
        A Python ``bool``, ``int``, ``float``, or ``str``.

    Returns
    -------
    Expr
        Wrapper around the literal returned by ``expr_lit``.

    Examples
    --------
    >>> from lancedb.expr import col, lit
    >>> col("price") * lit(1.1)
    Expr((price * 1.1))
    """
    literal = expr_lit(value)
    return Expr(literal)
|
||||
|
||||
|
||||
def func(name: str, *args: ExprLike) -> Expr:
    """Build a call to an arbitrary SQL function, identified by name.

    Parameters
    ----------
    name:
        The SQL function name (e.g. ``"lower"``, ``"upper"``).
    *args:
        The function arguments as :class:`Expr` or plain Python literals.

    Returns
    -------
    Expr
        Wrapper around the call expression returned by ``expr_func``.

    Examples
    --------
    >>> from lancedb.expr import col, func
    >>> func("lower", col("name"))
    Expr(lower(name))
    """
    inner_args = []
    for argument in args:
        # Each argument goes through _coerce so callers may mix Expr objects
        # and plain Python values.
        inner_args.append(_coerce(argument)._inner)
    call = expr_func(name, inner_args)
    return Expr(call)
|
||||
@@ -38,6 +38,7 @@ from .rerankers.base import Reranker
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import flatten_columns
|
||||
from .expr import Expr
|
||||
from lancedb._lancedb import fts_query_to_json
|
||||
from typing_extensions import Annotated
|
||||
|
||||
@@ -449,8 +450,8 @@ class Query(pydantic.BaseModel):
|
||||
ensure_vector_query,
|
||||
] = None
|
||||
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
# sql filter or type-safe Expr to refine the query with
|
||||
filter: Optional[Union[str, Expr]] = None
|
||||
|
||||
# if True then apply the filter after vector search
|
||||
postfilter: Optional[bool] = None
|
||||
@@ -464,8 +465,8 @@ class Query(pydantic.BaseModel):
|
||||
# distance type to use for vector search
|
||||
distance_type: Optional[str] = None
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[Union[List[str], Dict[str, str]]] = None
|
||||
# which columns to return in the results (dict values may be str or Expr)
|
||||
columns: Optional[Union[List[str], Dict[str, Union[str, Expr]]]] = None
|
||||
|
||||
# minimum number of IVF partitions to search
|
||||
#
|
||||
@@ -856,14 +857,15 @@ class LanceQueryBuilder(ABC):
|
||||
self._offset = offset
|
||||
return self
|
||||
|
||||
def select(self, columns: Union[list[str], dict[str, str]]) -> Self:
|
||||
def select(self, columns: Union[list[str], dict[str, Union[str, Expr]]]) -> Self:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list of str, or dict of str to str default None
|
||||
columns: list of str, or dict of str to str or Expr
|
||||
List of column names to be fetched.
|
||||
Or a dictionary of column names to SQL expressions.
|
||||
Or a dictionary of column names to SQL expressions or
|
||||
:class:`~lancedb.expr.Expr` objects.
|
||||
All columns are fetched if None or unspecified.
|
||||
|
||||
Returns
|
||||
@@ -877,15 +879,15 @@ class LanceQueryBuilder(ABC):
|
||||
raise ValueError("columns must be a list or a dictionary")
|
||||
return self
|
||||
|
||||
def where(self, where: str, prefilter: bool = True) -> Self:
|
||||
def where(self, where: Union[str, Expr], prefilter: bool = True) -> Self:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
where: str or :class:`~lancedb.expr.Expr`
|
||||
The filter condition. Can be a SQL string or a type-safe
|
||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
||||
and :func:`~lancedb.expr.lit`.
|
||||
prefilter: bool, default True
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
@@ -1355,15 +1357,17 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
return result_set
|
||||
|
||||
def where(self, where: str, prefilter: bool = None) -> LanceVectorQueryBuilder:
|
||||
def where(
|
||||
self, where: Union[str, Expr], prefilter: bool = None
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
where: str or :class:`~lancedb.expr.Expr`
|
||||
The filter condition. Can be a SQL string or a type-safe
|
||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
||||
and :func:`~lancedb.expr.lit`.
|
||||
prefilter: bool, default True
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
@@ -2286,10 +2290,20 @@ class AsyncQueryBase(object):
|
||||
"""
|
||||
if isinstance(columns, list) and all(isinstance(c, str) for c in columns):
|
||||
self._inner.select_columns(columns)
|
||||
elif isinstance(columns, dict) and all(
|
||||
isinstance(k, str) and isinstance(v, str) for k, v in columns.items()
|
||||
):
|
||||
self._inner.select(list(columns.items()))
|
||||
elif isinstance(columns, dict) and all(isinstance(k, str) for k in columns):
|
||||
if any(isinstance(v, Expr) for v in columns.values()):
|
||||
# At least one value is an Expr — use the type-safe path.
|
||||
from .expr import _coerce
|
||||
|
||||
pairs = [(k, _coerce(v)._inner) for k, v in columns.items()]
|
||||
self._inner.select_expr(pairs)
|
||||
elif all(isinstance(v, str) for v in columns.values()):
|
||||
self._inner.select(list(columns.items()))
|
||||
else:
|
||||
raise TypeError(
|
||||
"dict values must be str or Expr, got "
|
||||
+ str({k: type(v) for k, v in columns.items()})
|
||||
)
|
||||
else:
|
||||
raise TypeError("columns must be a list of column names or a dict")
|
||||
return self
|
||||
@@ -2529,11 +2543,13 @@ class AsyncStandardQuery(AsyncQueryBase):
|
||||
"""
|
||||
super().__init__(inner)
|
||||
|
||||
def where(self, predicate: str) -> Self:
|
||||
def where(self, predicate: Union[str, Expr]) -> Self:
|
||||
"""
|
||||
Only return rows matching the given predicate
|
||||
|
||||
The predicate should be supplied as an SQL query string.
|
||||
The predicate can be a SQL string or a type-safe
|
||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
||||
and :func:`~lancedb.expr.lit`.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -2545,7 +2561,10 @@ class AsyncStandardQuery(AsyncQueryBase):
|
||||
Filtering performance can often be improved by creating a scalar index
|
||||
on the filter column(s).
|
||||
"""
|
||||
self._inner.where(predicate)
|
||||
if isinstance(predicate, Expr):
|
||||
self._inner.where_expr(predicate._inner)
|
||||
else:
|
||||
self._inner.where(predicate)
|
||||
return self
|
||||
|
||||
def limit(self, limit: int) -> Self:
|
||||
|
||||
@@ -568,4 +568,4 @@ class RemoteDBConnection(DBConnection):
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the database."""
|
||||
self._client.close()
|
||||
self._conn.close()
|
||||
|
||||
@@ -4211,7 +4211,7 @@ class AsyncTable:
|
||||
async_query = async_query.offset(query.offset)
|
||||
if query.columns:
|
||||
async_query = async_query.select(query.columns)
|
||||
if query.filter:
|
||||
if query.filter is not None:
|
||||
async_query = async_query.where(query.filter)
|
||||
if query.fast_search:
|
||||
async_query = async_query.fast_search()
|
||||
|
||||
@@ -559,7 +559,8 @@ def test_url_retrieve_downloads_image():
|
||||
matching the real usage pattern in embedding functions.
|
||||
"""
|
||||
import io
|
||||
from PIL import Image
|
||||
|
||||
Image = pytest.importorskip("PIL.Image")
|
||||
from lancedb.embeddings.utils import url_retrieve
|
||||
|
||||
image_url = "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg"
|
||||
|
||||
@@ -1201,6 +1201,18 @@ async def test_header_provider_overrides_static_headers():
|
||||
await db.table_names()
|
||||
|
||||
|
||||
def test_close():
|
||||
"""Test that close() works without AttributeError."""
|
||||
import asyncio
|
||||
|
||||
def handler(req):
|
||||
req.send_response(200)
|
||||
req.end_headers()
|
||||
|
||||
with mock_lancedb_connection(handler) as db:
|
||||
asyncio.run(db.close())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exception", [KeyboardInterrupt, SystemExit, GeneratorExit])
|
||||
def test_background_loop_cancellation(exception):
|
||||
"""Test that BackgroundEventLoop.run() cancels the future on interrupt."""
|
||||
|
||||
175
python/src/expr.rs
Normal file
175
python/src/expr.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! PyO3 bindings for the LanceDB expression builder API.
|
||||
//!
|
||||
//! This module exposes [`PyExpr`] and helper free functions so Python can
|
||||
//! build type-safe filter / projection expressions that map directly to
|
||||
//! DataFusion [`Expr`] nodes, bypassing SQL string parsing.
|
||||
|
||||
use arrow::{datatypes::DataType, pyarrow::PyArrowType};
|
||||
use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper};
|
||||
use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction};
|
||||
|
||||
/// A type-safe DataFusion expression.
|
||||
///
|
||||
/// Instances are constructed via the free functions [`expr_col`] and
|
||||
/// [`expr_lit`] and combined with the methods on this struct. On the Python
|
||||
/// side a thin wrapper class (`lancedb.expr.Expr`) delegates to these methods
|
||||
/// and adds Python operator overloads.
|
||||
#[pyclass(name = "PyExpr")]
|
||||
#[derive(Clone)]
|
||||
pub struct PyExpr(pub DfExpr);
|
||||
|
||||
#[pymethods]
|
||||
impl PyExpr {
|
||||
// ── comparisons ──────────────────────────────────────────────────────────
|
||||
|
||||
fn eq(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().eq(other.0.clone()))
|
||||
}
|
||||
|
||||
fn ne(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().not_eq(other.0.clone()))
|
||||
}
|
||||
|
||||
fn lt(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().lt(other.0.clone()))
|
||||
}
|
||||
|
||||
fn lte(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().lt_eq(other.0.clone()))
|
||||
}
|
||||
|
||||
fn gt(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().gt(other.0.clone()))
|
||||
}
|
||||
|
||||
fn gte(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().gt_eq(other.0.clone()))
|
||||
}
|
||||
|
||||
// ── logical ──────────────────────────────────────────────────────────────
|
||||
|
||||
fn and_(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().and(other.0.clone()))
|
||||
}
|
||||
|
||||
fn or_(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().or(other.0.clone()))
|
||||
}
|
||||
|
||||
fn not_(&self) -> Self {
|
||||
use std::ops::Not;
|
||||
Self(self.0.clone().not())
|
||||
}
|
||||
|
||||
// ── arithmetic ───────────────────────────────────────────────────────────
|
||||
|
||||
fn add(&self, other: &Self) -> Self {
|
||||
use std::ops::Add;
|
||||
Self(self.0.clone().add(other.0.clone()))
|
||||
}
|
||||
|
||||
fn sub(&self, other: &Self) -> Self {
|
||||
use std::ops::Sub;
|
||||
Self(self.0.clone().sub(other.0.clone()))
|
||||
}
|
||||
|
||||
fn mul(&self, other: &Self) -> Self {
|
||||
use std::ops::Mul;
|
||||
Self(self.0.clone().mul(other.0.clone()))
|
||||
}
|
||||
|
||||
fn div(&self, other: &Self) -> Self {
|
||||
use std::ops::Div;
|
||||
Self(self.0.clone().div(other.0.clone()))
|
||||
}
|
||||
|
||||
// ── string functions ─────────────────────────────────────────────────────
|
||||
|
||||
/// Convert string column to lowercase.
|
||||
fn lower(&self) -> Self {
|
||||
Self(lower(self.0.clone()))
|
||||
}
|
||||
|
||||
/// Convert string column to uppercase.
|
||||
fn upper(&self) -> Self {
|
||||
Self(upper(self.0.clone()))
|
||||
}
|
||||
|
||||
/// Test whether the string contains `substr`.
|
||||
fn contains(&self, substr: &Self) -> Self {
|
||||
Self(contains(self.0.clone(), substr.0.clone()))
|
||||
}
|
||||
|
||||
// ── type cast ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Cast the expression to `data_type`.
|
||||
///
|
||||
/// `data_type` must be a PyArrow `DataType` (e.g. `pa.int32()`).
|
||||
/// On the Python side, `lancedb.expr.Expr.cast` also accepts type name
|
||||
/// strings via `pa.lib.ensure_type` before forwarding here.
|
||||
fn cast(&self, data_type: PyArrowType<DataType>) -> Self {
|
||||
Self(expr_cast(self.0.clone(), data_type.0))
|
||||
}
|
||||
|
||||
// ── utilities ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Render the expression as a SQL string (useful for debugging).
|
||||
fn to_sql(&self) -> PyResult<String> {
|
||||
lancedb::expr::expr_to_sql_string(&self.0).map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
let sql =
|
||||
lancedb::expr::expr_to_sql_string(&self.0).unwrap_or_else(|_| "<expr>".to_string());
|
||||
Ok(format!("PyExpr({})", sql))
|
||||
}
|
||||
}
|
||||
|
||||
// ── free functions ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Create a column reference expression.
|
||||
///
|
||||
/// The column name is preserved exactly as given (case-sensitive), so
|
||||
/// `col("firstName")` correctly references a field named `firstName`.
|
||||
#[pyfunction]
|
||||
pub fn expr_col(name: &str) -> PyExpr {
|
||||
PyExpr(ldb_col(name))
|
||||
}
|
||||
|
||||
/// Create a literal value expression.
|
||||
///
|
||||
/// Supported Python types: `bool`, `int`, `float`, `str`.
|
||||
#[pyfunction]
|
||||
pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
|
||||
// bool must be checked before int because bool is a subclass of int in Python
|
||||
if let Ok(b) = value.extract::<bool>() {
|
||||
return Ok(PyExpr(df_lit(b)));
|
||||
}
|
||||
if let Ok(i) = value.extract::<i64>() {
|
||||
return Ok(PyExpr(df_lit(i)));
|
||||
}
|
||||
if let Ok(f) = value.extract::<f64>() {
|
||||
return Ok(PyExpr(df_lit(f)));
|
||||
}
|
||||
if let Ok(s) = value.extract::<String>() {
|
||||
return Ok(PyExpr(df_lit(s)));
|
||||
}
|
||||
Err(PyValueError::new_err(format!(
|
||||
"unsupported literal type: {}. Supported: bool, int, float, str",
|
||||
value.get_type().name()?
|
||||
)))
|
||||
}
|
||||
|
||||
/// Call an arbitrary registered SQL function by name.
|
||||
///
|
||||
/// See `lancedb::expr::func` for the list of supported function names.
|
||||
#[pyfunction]
|
||||
pub fn expr_func(name: &str, args: Vec<PyExpr>) -> PyResult<PyExpr> {
|
||||
let df_args: Vec<DfExpr> = args.into_iter().map(|e| e.0).collect();
|
||||
lancedb::expr::func(name, df_args)
|
||||
.map(PyExpr)
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
}
|
||||
@@ -4,6 +4,7 @@
|
||||
use arrow::RecordBatchStream;
|
||||
use connection::{Connection, connect};
|
||||
use env_logger::Env;
|
||||
use expr::{PyExpr, expr_col, expr_func, expr_lit};
|
||||
use index::IndexConfig;
|
||||
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
|
||||
use pyo3::{
|
||||
@@ -21,6 +22,7 @@ use table::{
|
||||
pub mod arrow;
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
pub mod expr;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod namespace;
|
||||
@@ -55,10 +57,14 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<UpdateResult>()?;
|
||||
m.add_class::<PyAsyncPermutationBuilder>()?;
|
||||
m.add_class::<PyPermutationReader>()?;
|
||||
m.add_class::<PyExpr>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(expr_col, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(expr_lit, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(expr_func, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -35,12 +35,10 @@ use pyo3::types::PyList;
|
||||
use pyo3::types::{PyDict, PyString};
|
||||
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
|
||||
use pyo3::{PyErr, pyclass};
|
||||
use pyo3::{
|
||||
exceptions::{PyNotImplementedError, PyValueError},
|
||||
intern,
|
||||
};
|
||||
use pyo3::{exceptions::PyValueError, intern};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::expr::PyExpr;
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
@@ -344,9 +342,13 @@ impl<'py> IntoPyObject<'py> for PyQueryFilter {
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||
match self.0 {
|
||||
QueryFilter::Datafusion(_) => Err(PyNotImplementedError::new_err(
|
||||
"Datafusion filter has no conversion to Python",
|
||||
)),
|
||||
QueryFilter::Datafusion(expr) => {
|
||||
// Serialize the DataFusion expression to a SQL string so that
|
||||
// callers (e.g. remote tables) see the same format as Sql.
|
||||
let sql = lancedb::expr::expr_to_sql_string(&expr)
|
||||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
|
||||
Ok(sql.into_pyobject(py)?.into_any())
|
||||
}
|
||||
QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()),
|
||||
QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()),
|
||||
}
|
||||
@@ -370,10 +372,20 @@ impl Query {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
||||
}
|
||||
|
||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||
}
|
||||
@@ -607,10 +619,20 @@ impl FTSQuery {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
||||
}
|
||||
|
||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||
}
|
||||
@@ -725,6 +747,10 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
||||
}
|
||||
|
||||
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
|
||||
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
|
||||
let array = make_array(data);
|
||||
@@ -736,6 +762,12 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
||||
}
|
||||
|
||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||
}
|
||||
@@ -890,11 +922,21 @@ impl HybridQuery {
|
||||
self.inner_fts.r#where(predicate);
|
||||
}
|
||||
|
||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
||||
self.inner_vec.where_expr(expr.clone());
|
||||
self.inner_fts.where_expr(expr);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner_vec.select(columns.clone());
|
||||
self.inner_fts.select(columns);
|
||||
}
|
||||
|
||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
||||
self.inner_vec.select_expr(columns.clone());
|
||||
self.inner_fts.select_expr(columns);
|
||||
}
|
||||
|
||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||
self.inner_vec.select_columns(columns.clone());
|
||||
self.inner_fts.select_columns(columns);
|
||||
|
||||
387
python/tests/test_expr.py
Normal file
387
python/tests/test_expr.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Tests for the type-safe expression builder API."""
|
||||
|
||||
import pytest
|
||||
import pyarrow as pa
|
||||
import lancedb
|
||||
from lancedb.expr import Expr, col, lit, func
|
||||
|
||||
|
||||
# ── unit tests for Expr construction ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestExprConstruction:
    """Construction of expressions through ``col``, ``lit`` and ``func``."""

    def test_col_returns_expr(self):
        e = col("age")
        assert isinstance(e, Expr)

    def test_lit_int(self):
        e = lit(42)
        assert isinstance(e, Expr)

    def test_lit_float(self):
        e = lit(3.14)
        assert isinstance(e, Expr)

    def test_lit_str(self):
        e = lit("hello")
        assert isinstance(e, Expr)

    def test_lit_bool(self):
        e = lit(True)
        assert isinstance(e, Expr)

    def test_lit_unsupported_type_raises(self):
        # The native binding reports an unsupported literal type as a
        # ValueError; pinning the type keeps unrelated failures (import or
        # attribute errors) from satisfying this test.
        with pytest.raises(ValueError):
            lit([1, 2, 3])

    def test_func(self):
        e = func("lower", col("name"))
        assert isinstance(e, Expr)
        assert e.to_sql() == "lower(name)"

    def test_func_unknown_raises(self):
        # The native binding maps function-lookup failures to ValueError.
        with pytest.raises(ValueError):
            func("not_a_real_function", col("x"))
|
||||
|
||||
|
||||
class TestExprOperators:
    """Operator overloads on ``Expr`` and the SQL they render to."""

    def test_eq_operator(self):
        expr = col("x") == lit(1)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(x = 1)"

    def test_ne_operator(self):
        expr = col("x") != lit(1)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(x <> 1)"

    def test_lt_operator(self):
        expr = col("age") < lit(18)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(age < 18)"

    def test_le_operator(self):
        expr = col("age") <= lit(18)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(age <= 18)"

    def test_gt_operator(self):
        expr = col("age") > lit(18)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(age > 18)"

    def test_ge_operator(self):
        expr = col("age") >= lit(18)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(age >= 18)"

    def test_and_operator(self):
        adult = col("age") > lit(18)
        active = col("status") == lit("active")
        expr = adult & active
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "((age > 18) AND (status = 'active'))"

    def test_or_operator(self):
        first = col("a") == lit(1)
        second = col("b") == lit(2)
        expr = first | second
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "((a = 1) OR (b = 2))"

    def test_invert_operator(self):
        is_active = col("active") == lit(True)
        expr = ~is_active
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "NOT (active = true)"

    def test_add_operator(self):
        expr = col("x") + lit(1)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(x + 1)"

    def test_sub_operator(self):
        expr = col("x") - lit(1)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(x - 1)"

    def test_mul_operator(self):
        expr = col("price") * lit(1.1)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(price * 1.1)"

    def test_div_operator(self):
        expr = col("total") / lit(2)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(total / 2)"

    def test_radd(self):
        expr = lit(1) + col("x")
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(1 + x)"

    def test_rmul(self):
        expr = lit(2) * col("x")
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(2 * x)"

    def test_coerce_plain_int(self):
        # Operators should auto-wrap plain Python values via lit()
        expr = col("age") > 18
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(age > 18)"

    def test_coerce_plain_str(self):
        expr = col("name") == "alice"
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(name = 'alice')"
|
||||
|
||||
|
||||
class TestExprStringMethods:
    """String helpers on ``Expr`` and the SQL they render to."""

    def test_lower(self):
        expr = col("name").lower()
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "lower(name)"

    def test_upper(self):
        expr = col("name").upper()
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "upper(name)"

    def test_contains(self):
        needle = lit("hello")
        expr = col("text").contains(needle)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "contains(text, 'hello')"

    def test_contains_with_str_coerce(self):
        expr = col("text").contains("hello")
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "contains(text, 'hello')"

    def test_chained_lower_eq(self):
        lowered = col("name").lower()
        expr = lowered == lit("alice")
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(lower(name) = 'alice')"
|
||||
|
||||
|
||||
class TestExprCast:
    """``Expr.cast`` with type-name strings and with PyArrow data types."""

    def test_cast_string(self):
        expr = col("id").cast("string")
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "CAST(id AS VARCHAR)"

    def test_cast_int32(self):
        expr = col("score").cast("int32")
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "CAST(score AS INTEGER)"

    def test_cast_float64(self):
        expr = col("val").cast("float64")
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "CAST(val AS DOUBLE)"

    def test_cast_pyarrow_type(self):
        expr = col("score").cast(pa.int32())
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "CAST(score AS INTEGER)"

    def test_cast_pyarrow_float64(self):
        expr = col("val").cast(pa.float64())
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "CAST(val AS DOUBLE)"

    def test_cast_pyarrow_string(self):
        expr = col("id").cast(pa.string())
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "CAST(id AS VARCHAR)"

    def test_cast_pyarrow_and_string_equivalent(self):
        # pa.int32() and "int32" should produce equivalent SQL
        from_name = col("x").cast("int32").to_sql()
        from_type = col("x").cast(pa.int32()).to_sql()
        assert from_name == from_type
|
||||
|
||||
|
||||
class TestExprNamedMethods:
    """Named comparison and logical helpers used instead of operators."""

    def test_eq_method(self):
        expr = col("x").eq(lit(1))
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(x = 1)"

    def test_gt_method(self):
        expr = col("x").gt(lit(0))
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "(x > 0)"

    def test_and_method(self):
        lower_bound = col("x").gt(lit(0))
        upper_bound = col("y").lt(lit(10))
        expr = lower_bound.and_(upper_bound)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "((x > 0) AND (y < 10))"

    def test_or_method(self):
        is_one = col("x").eq(lit(1))
        is_two = col("x").eq(lit(2))
        expr = is_one.or_(is_two)
        assert isinstance(expr, Expr)
        assert expr.to_sql() == "((x = 1) OR (x = 2))"
|
||||
|
||||
|
||||
class TestExprRepr:
    """``repr``, ``to_sql`` and hashing behaviour of ``Expr``."""

    def test_repr(self):
        expr = col("age") > lit(18)
        assert repr(expr) == "Expr((age > 18))"

    def test_to_sql(self):
        expr = col("age") > 18
        assert expr.to_sql() == "(age > 18)"

    def test_unhashable(self):
        expr = col("x")
        # Using the expression as a dict key must fail with TypeError.
        with pytest.raises(TypeError):
            {expr: 1}
|
||||
|
||||
|
||||
# ── integration tests: end-to-end query against a real table ─────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def simple_table(tmp_path):
    """Create a five-row table named "test" in a database under tmp_path."""
    columns = {
        "id": [1, 2, 3, 4, 5],
        "name": ["Alice", "Bob", "Charlie", "alice", "BOB"],
        "age": [25, 17, 30, 22, 15],
        "score": [1.5, 2.0, 3.5, 4.0, 0.5],
    }
    connection = lancedb.connect(str(tmp_path))
    return connection.create_table("test", pa.table(columns))
|
||||
|
||||
|
||||
class TestExprFilter:
    """End-to-end filter tests run against the `simple_table` fixture."""

    def test_simple_gt_filter(self, simple_table):
        predicate = col("age") > lit(20)
        rows = simple_table.search().where(predicate).to_arrow()
        # Matching ages: 25, 30, 22.
        assert rows.num_rows == 3

    def test_compound_and_filter(self, simple_table):
        predicate = (col("age") > lit(18)) & (col("score") > lit(2.0))
        rows = simple_table.search().where(predicate).to_arrow()
        # Matching (age, score) pairs: (30, 3.5) and (22, 4.0).
        assert rows.num_rows == 2

    def test_string_equality_filter(self, simple_table):
        predicate = col("name") == lit("Bob")
        rows = simple_table.search().where(predicate).to_arrow()
        assert rows.num_rows == 1

    def test_or_filter(self, simple_table):
        predicate = (col("age") < lit(18)) | (col("age") > lit(28))
        rows = simple_table.search().where(predicate).to_arrow()
        # Matching ages: 17, 30, 15.
        assert rows.num_rows == 3

    def test_coercion_no_lit(self, simple_table):
        # A bare Python int on the right-hand side is coerced automatically.
        predicate = col("age") > 20
        rows = simple_table.search().where(predicate).to_arrow()
        assert rows.num_rows == 3

    def test_string_sql_still_works(self, simple_table):
        # Plain SQL strings remain accepted for backwards compatibility.
        rows = simple_table.search().where("age > 20").to_arrow()
        assert rows.num_rows == 3
|
||||
|
||||
|
||||
class TestExprProjection:
    """End-to-end projection tests run against the `simple_table` fixture."""

    def test_select_with_expr(self, simple_table):
        projection = {"double_score": col("score") * lit(2)}
        rows = simple_table.search().select(projection).to_arrow()
        assert "double_score" in rows.schema.names

    def test_select_mixed_str_and_expr(self, simple_table):
        # One value is a plain column name, the other an Expr.
        projection = {"id": "id", "double_score": col("score") * lit(2)}
        rows = simple_table.search().select(projection).to_arrow()
        output_columns = rows.schema.names
        assert "id" in output_columns
        assert "double_score" in output_columns

    def test_select_list_of_columns(self, simple_table):
        # A plain list of column names is still supported.
        wanted = ["id", "name"]
        rows = simple_table.search().select(wanted).to_arrow()
        assert rows.schema.names == ["id", "name"]
|
||||
|
||||
|
||||
# ── column name edge cases ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestColNaming:
    """Unit tests checking that col() keeps identifiers exactly as written.

    Names that cannot be emitted bare (camelCase, embedded spaces, a leading
    digit, non-ASCII characters) are expected to be wrapped in backticks,
    which is the quoting style of the lance SQL parser's dialect.
    """

    def test_camel_case_preserved_in_sql(self):
        # Backtick quoting lets the mixed case survive a round trip.
        rendered = col("firstName").to_sql()
        assert rendered == "`firstName`"

    def test_camel_case_in_expression(self):
        rendered = (col("firstName") > lit(18)).to_sql()
        assert rendered == "(`firstName` > 18)"

    def test_space_in_name_quoted(self):
        rendered = col("first name").to_sql()
        assert rendered == "`first name`"

    def test_space_in_expression(self):
        rendered = (col("first name") == lit("A")).to_sql()
        assert rendered == "(`first name` = 'A')"

    def test_leading_digit_quoted(self):
        rendered = col("2fast").to_sql()
        assert rendered == "`2fast`"

    def test_unicode_quoted(self):
        rendered = col("名前").to_sql()
        assert rendered == "`名前`"

    def test_snake_case_unquoted(self):
        # Lower-case snake_case is emitted without quotes.
        rendered = col("first_name").to_sql()
        assert rendered == "first_name"
|
||||
|
||||
|
||||
@pytest.fixture
def special_col_table(tmp_path):
    """Create a table named "special" with a camelCase and a spaced column name."""
    columns = {
        "firstName": ["Alice", "Bob", "Charlie"],
        "first name": ["A", "B", "C"],
        "score": [10, 20, 30],
    }
    connection = lancedb.connect(str(tmp_path))
    return connection.create_table("special", pa.table(columns))
|
||||
|
||||
|
||||
class TestColNamingIntegration:
    """End-to-end checks that unusual column names work in filters and projections."""

    def test_camel_case_filter(self, special_col_table):
        predicate = col("firstName") == lit("Alice")
        rows = special_col_table.search().where(predicate).to_arrow()
        assert rows.num_rows == 1
        assert rows["firstName"][0].as_py() == "Alice"

    def test_space_in_col_filter(self, special_col_table):
        predicate = col("first name") == lit("B")
        rows = special_col_table.search().where(predicate).to_arrow()
        assert rows.num_rows == 1

    def test_camel_case_projection(self, special_col_table):
        projection = {"upper_name": col("firstName").upper()}
        rows = special_col_table.search().select(projection).to_arrow()
        assert "upper_name" in rows.schema.names
        # Sort before comparing so the check does not depend on row order.
        upper_names = sorted(rows["upper_name"].to_pylist())
        assert upper_names == ["ALICE", "BOB", "CHARLIE"]
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.27.2-beta.1"
|
||||
version = "0.27.2"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
|
||||
@@ -27,7 +27,17 @@ use arrow_schema::DataType;
|
||||
use datafusion_expr::{Expr, ScalarUDF, expr_fn::cast};
|
||||
use datafusion_functions::string::expr_fn as string_expr_fn;
|
||||
|
||||
pub use datafusion_expr::{col, lit};
|
||||
pub use datafusion_expr::lit;
|
||||
|
||||
/// Create a column reference expression, preserving the name exactly as given.
|
||||
///
|
||||
/// Unlike DataFusion's built-in [`col`][datafusion_expr::col], this function
|
||||
/// does **not** normalise the identifier to lower-case, so
|
||||
/// `col("firstName")` correctly references a field named `firstName`.
|
||||
pub fn col(name: impl Into<String>) -> DfExpr {
|
||||
use datafusion_common::Column;
|
||||
DfExpr::Column(Column::new_unqualified(name))
|
||||
}
|
||||
|
||||
pub use datafusion_expr::Expr as DfExpr;
|
||||
|
||||
|
||||
@@ -2,11 +2,37 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use datafusion_expr::Expr;
|
||||
use datafusion_sql::unparser;
|
||||
use datafusion_sql::unparser::{self, dialect::Dialect};
|
||||
|
||||
/// Unparser dialect that matches the quoting style expected by the Lance SQL
|
||||
/// parser. Lance uses backtick (`` ` ``) as the only delimited-identifier
|
||||
/// quote character, so we must produce `` `firstName` `` rather than
|
||||
/// `"firstName"` for identifiers that require quoting.
|
||||
///
|
||||
/// We quote an identifier when it:
|
||||
/// * is a SQL reserved word, OR
|
||||
/// * contains characters outside `[a-zA-Z0-9_]`, OR
|
||||
/// * starts with a digit, OR
|
||||
/// * contains upper-case letters (unquoted identifiers are normalised to
|
||||
/// lower-case by the SQL parser, which would break case-sensitive schemas).
|
||||
struct LanceSqlDialect;
|
||||
|
||||
impl Dialect for LanceSqlDialect {
|
||||
fn identifier_quote_style(&self, identifier: &str) -> Option<char> {
|
||||
let needs_quote = identifier.chars().any(|c| c.is_ascii_uppercase())
|
||||
|| !identifier
|
||||
.chars()
|
||||
.enumerate()
|
||||
.all(|(i, c)| c == '_' || c.is_ascii_alphabetic() || (i > 0 && c.is_ascii_digit()));
|
||||
if needs_quote { Some('`') } else { None }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
|
||||
let ast = unparser::expr_to_sql(expr).map_err(|e| crate::Error::InvalidInput {
|
||||
message: format!("failed to serialize expression to SQL: {}", e),
|
||||
})?;
|
||||
let ast = unparser::Unparser::new(&LanceSqlDialect)
|
||||
.expr_to_sql(expr)
|
||||
.map_err(|e| crate::Error::InvalidInput {
|
||||
message: format!("failed to serialize expression to SQL: {}", e),
|
||||
})?;
|
||||
Ok(ast.to_string())
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::sync::Arc;
|
||||
use std::{future::Future, time::Duration};
|
||||
|
||||
use arrow::compute::concat_batches;
|
||||
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, make_array};
|
||||
use arrow_array::{Array, Float16Array, Float32Array, Float64Array, RecordBatch, make_array};
|
||||
use arrow_schema::{DataType, SchemaRef};
|
||||
use datafusion_expr::Expr;
|
||||
use datafusion_physical_plan::ExecutionPlan;
|
||||
@@ -17,15 +17,17 @@ use lance_datafusion::exec::execute_plan;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
use lance_index::scalar::inverted::SCORE_COL;
|
||||
use lance_index::vector::DIST_COL;
|
||||
use lance_io::stream::RecordBatchStreamAdapter;
|
||||
|
||||
use crate::DistanceType;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::rerankers::rrf::RRFReranker;
|
||||
use crate::rerankers::{NormalizeMethod, Reranker, check_reranker_result};
|
||||
use crate::table::BaseTable;
|
||||
use crate::utils::TimeoutStream;
|
||||
use crate::{arrow::SendableRecordBatchStream, table::AnyQuery};
|
||||
use crate::utils::{MaxBatchLengthStream, TimeoutStream};
|
||||
use crate::{
|
||||
arrow::{SendableRecordBatchStream, SimpleRecordBatchStream},
|
||||
table::AnyQuery,
|
||||
};
|
||||
|
||||
mod hybrid;
|
||||
|
||||
@@ -604,6 +606,14 @@ impl Default for QueryExecutionOptions {
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryExecutionOptions {
|
||||
fn without_output_batch_length_limit(&self) -> Self {
|
||||
let mut options = self.clone();
|
||||
options.max_batch_length = 0;
|
||||
options
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait for a query object that can be executed to get results
|
||||
///
|
||||
/// There are various kinds of queries but they all return results
|
||||
@@ -1180,6 +1190,8 @@ impl VectorQuery {
|
||||
&self,
|
||||
options: QueryExecutionOptions,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let max_batch_length = options.max_batch_length as usize;
|
||||
let internal_options = options.without_output_batch_length_limit();
|
||||
// clone query and specify we want to include row IDs, which can be needed for reranking
|
||||
let mut fts_query = Query::new(self.parent.clone());
|
||||
fts_query.request = self.request.base.clone();
|
||||
@@ -1189,8 +1201,8 @@ impl VectorQuery {
|
||||
|
||||
vector_query.request.base.full_text_search = None;
|
||||
let (fts_results, vec_results) = try_join!(
|
||||
fts_query.execute_with_options(options.clone()),
|
||||
vector_query.inner_execute_with_options(options)
|
||||
fts_query.execute_with_options(internal_options.clone()),
|
||||
vector_query.inner_execute_with_options(internal_options)
|
||||
)?;
|
||||
|
||||
let (fts_results, vec_results) = try_join!(
|
||||
@@ -1245,9 +1257,7 @@ impl VectorQuery {
|
||||
results = results.drop_column(ROW_ID)?;
|
||||
}
|
||||
|
||||
Ok(SendableRecordBatchStream::from(
|
||||
RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
|
||||
))
|
||||
Ok(single_batch_stream(results, max_batch_length))
|
||||
}
|
||||
|
||||
async fn inner_execute_with_options(
|
||||
@@ -1256,6 +1266,7 @@ impl VectorQuery {
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let plan = self.create_plan(options.clone()).await?;
|
||||
let inner = execute_plan(plan, Default::default())?;
|
||||
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
|
||||
let inner = if let Some(timeout) = options.timeout {
|
||||
TimeoutStream::new_boxed(inner, timeout)
|
||||
} else {
|
||||
@@ -1265,6 +1276,25 @@ impl VectorQuery {
|
||||
}
|
||||
}
|
||||
|
||||
fn single_batch_stream(batch: RecordBatch, max_batch_length: usize) -> SendableRecordBatchStream {
|
||||
let schema = batch.schema();
|
||||
if max_batch_length == 0 || batch.num_rows() <= max_batch_length {
|
||||
return Box::pin(SimpleRecordBatchStream::new(
|
||||
stream::iter([Ok(batch)]),
|
||||
schema,
|
||||
));
|
||||
}
|
||||
|
||||
let mut batches = Vec::with_capacity(batch.num_rows().div_ceil(max_batch_length));
|
||||
let mut offset = 0;
|
||||
while offset < batch.num_rows() {
|
||||
let length = (batch.num_rows() - offset).min(max_batch_length);
|
||||
batches.push(Ok(batch.slice(offset, length)));
|
||||
offset += length;
|
||||
}
|
||||
Box::pin(SimpleRecordBatchStream::new(stream::iter(batches), schema))
|
||||
}
|
||||
|
||||
impl ExecutableQuery for VectorQuery {
|
||||
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
|
||||
let query = AnyQuery::VectorQuery(self.request.clone());
|
||||
@@ -1753,6 +1783,50 @@ mod tests {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn make_large_vector_table(tmp_dir: &tempfile::TempDir, rows: usize) -> Table {
|
||||
let dataset_path = tmp_dir.path().join("large_test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("id", DataType::Utf8, false),
|
||||
ArrowField::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
||||
4,
|
||||
),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let ids = StringArray::from_iter_values((0..rows).map(|i| format!("row-{i}")));
|
||||
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
(0..rows).map(|i| Some(vec![Some(i as f32), Some(1.0), Some(2.0), Some(3.0)])),
|
||||
4,
|
||||
);
|
||||
let batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)]).unwrap();
|
||||
|
||||
let conn = connect(uri).execute().await.unwrap();
|
||||
conn.create_table("my_table", vec![batch])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn assert_stream_batches_at_most(
|
||||
mut results: SendableRecordBatchStream,
|
||||
max_batch_length: usize,
|
||||
) {
|
||||
let mut saw_batch = false;
|
||||
while let Some(batch) = results.next().await {
|
||||
let batch = batch.unwrap();
|
||||
saw_batch = true;
|
||||
assert!(batch.num_rows() <= max_batch_length);
|
||||
}
|
||||
assert!(saw_batch);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_with_options() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
@@ -1772,6 +1846,83 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_query_execute_with_options_respects_max_batch_length() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let table = make_large_vector_table(&tmp_dir, 10_000).await;
|
||||
|
||||
let results = table
|
||||
.query()
|
||||
.nearest_to(vec![0.0, 1.0, 2.0, 3.0])
|
||||
.unwrap()
|
||||
.limit(10_000)
|
||||
.execute_with_options(QueryExecutionOptions {
|
||||
max_batch_length: 100,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_stream_batches_at_most(results, 100).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hybrid_query_execute_with_options_respects_max_batch_length() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path();
|
||||
let conn = connect(dataset_path.to_str().unwrap())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let dims = 2;
|
||||
let rows = 512;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("text", DataType::Utf8, false),
|
||||
ArrowField::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
||||
dims,
|
||||
),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let text = StringArray::from_iter_values((0..rows).map(|_| "match"));
|
||||
let vectors = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||
(0..rows).map(|i| Some(vec![Some(i as f32), Some(0.0)])),
|
||||
dims,
|
||||
);
|
||||
let record_batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vectors)]).unwrap();
|
||||
let table = conn
|
||||
.create_table("my_table", record_batch)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
table
|
||||
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
|
||||
.replace(true)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = table
|
||||
.query()
|
||||
.full_text_search(FullTextSearchQuery::new("match".to_string()))
|
||||
.limit(rows)
|
||||
.nearest_to(&[0.0, 0.0])
|
||||
.unwrap()
|
||||
.execute_with_options(QueryExecutionOptions {
|
||||
max_batch_length: 100,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert_stream_batches_at_most(results, 100).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_analyze_plan() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::expr::expr_to_sql_string;
|
||||
use crate::query::{
|
||||
DEFAULT_TOP_K, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest,
|
||||
};
|
||||
use crate::utils::{TimeoutStream, default_vector_column};
|
||||
use crate::utils::{MaxBatchLengthStream, TimeoutStream, default_vector_column};
|
||||
use arrow::array::{AsArray, FixedSizeListBuilder, Float32Builder};
|
||||
use arrow::datatypes::{Float32Type, UInt8Type};
|
||||
use arrow_array::Array;
|
||||
@@ -66,6 +66,7 @@ async fn execute_generic_query(
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
let plan = create_plan(table, query, options.clone()).await?;
|
||||
let inner = execute_plan(plan, Default::default())?;
|
||||
let inner = MaxBatchLengthStream::new_boxed(inner, options.max_batch_length as usize);
|
||||
let inner = if let Some(timeout) = options.timeout {
|
||||
TimeoutStream::new_boxed(inner, timeout)
|
||||
} else {
|
||||
@@ -200,7 +201,9 @@ pub async fn create_plan(
|
||||
scanner.with_row_id();
|
||||
}
|
||||
|
||||
scanner.batch_size(options.max_batch_length as usize);
|
||||
if options.max_batch_length > 0 {
|
||||
scanner.batch_size(options.max_batch_length as usize);
|
||||
}
|
||||
|
||||
if query.base.fast_search {
|
||||
scanner.fast_search();
|
||||
|
||||
@@ -335,6 +335,85 @@ impl Stream for TimeoutStream {
|
||||
}
|
||||
}
|
||||
|
||||
/// A `Stream` wrapper that slices oversized batches to enforce a maximum batch length.
|
||||
pub struct MaxBatchLengthStream {
|
||||
inner: SendableRecordBatchStream,
|
||||
max_batch_length: Option<usize>,
|
||||
buffered_batch: Option<RecordBatch>,
|
||||
buffered_offset: usize,
|
||||
}
|
||||
|
||||
impl MaxBatchLengthStream {
|
||||
pub fn new(inner: SendableRecordBatchStream, max_batch_length: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
max_batch_length: (max_batch_length > 0).then_some(max_batch_length),
|
||||
buffered_batch: None,
|
||||
buffered_offset: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_boxed(
|
||||
inner: SendableRecordBatchStream,
|
||||
max_batch_length: usize,
|
||||
) -> SendableRecordBatchStream {
|
||||
if max_batch_length == 0 {
|
||||
inner
|
||||
} else {
|
||||
Box::pin(Self::new(inner, max_batch_length))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RecordBatchStream for MaxBatchLengthStream {
|
||||
fn schema(&self) -> SchemaRef {
|
||||
self.inner.schema()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for MaxBatchLengthStream {
|
||||
type Item = DataFusionResult<RecordBatch>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
let Some(max_batch_length) = self.max_batch_length else {
|
||||
return Pin::new(&mut self.inner).poll_next(cx);
|
||||
};
|
||||
|
||||
if let Some(batch) = self.buffered_batch.clone() {
|
||||
if self.buffered_offset < batch.num_rows() {
|
||||
let remaining = batch.num_rows() - self.buffered_offset;
|
||||
let length = remaining.min(max_batch_length);
|
||||
let sliced = batch.slice(self.buffered_offset, length);
|
||||
self.buffered_offset += length;
|
||||
if self.buffered_offset >= batch.num_rows() {
|
||||
self.buffered_batch = None;
|
||||
self.buffered_offset = 0;
|
||||
}
|
||||
return std::task::Poll::Ready(Some(Ok(sliced)));
|
||||
}
|
||||
|
||||
self.buffered_batch = None;
|
||||
self.buffered_offset = 0;
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.inner).poll_next(cx) {
|
||||
std::task::Poll::Ready(Some(Ok(batch))) => {
|
||||
if batch.num_rows() <= max_batch_length {
|
||||
return std::task::Poll::Ready(Some(Ok(batch)));
|
||||
}
|
||||
self.buffered_batch = Some(batch);
|
||||
self.buffered_offset = 0;
|
||||
}
|
||||
other => return other,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow_array::Int32Array;
|
||||
@@ -470,7 +549,7 @@ mod tests {
|
||||
assert_eq!(string_to_datatype(string), Some(expected));
|
||||
}
|
||||
|
||||
fn sample_batch() -> RecordBatch {
|
||||
fn sample_batch(num_rows: i32) -> RecordBatch {
|
||||
let schema = Arc::new(Schema::new(vec![Field::new(
|
||||
"col1",
|
||||
DataType::Int32,
|
||||
@@ -478,14 +557,14 @@ mod tests {
|
||||
)]));
|
||||
RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
vec![Arc::new(Int32Array::from_iter_values(0..num_rows))],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream() {
|
||||
let batch = sample_batch();
|
||||
let batch = sample_batch(3);
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
@@ -515,7 +594,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream_zero_duration() {
|
||||
let batch = sample_batch();
|
||||
let batch = sample_batch(3);
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
@@ -534,7 +613,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_stream_completes_normally() {
|
||||
let batch = sample_batch();
|
||||
let batch = sample_batch(3);
|
||||
let schema = batch.schema();
|
||||
let mock_stream = stream::iter(vec![Ok(batch.clone()), Ok(batch.clone())]);
|
||||
|
||||
@@ -552,4 +631,35 @@ mod tests {
|
||||
// Stream should be empty now
|
||||
assert!(timeout_stream.next().await.is_none());
|
||||
}
|
||||
|
||||
async fn collect_batch_sizes(
|
||||
stream: SendableRecordBatchStream,
|
||||
max_batch_length: usize,
|
||||
) -> Vec<usize> {
|
||||
let mut sliced_stream = MaxBatchLengthStream::new(stream, max_batch_length);
|
||||
sliced_stream
|
||||
.by_ref()
|
||||
.map(|batch| batch.unwrap().num_rows())
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_max_batch_length_stream_behaviors() {
|
||||
let schema = sample_batch(7).schema();
|
||||
let mock_stream = stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]);
|
||||
|
||||
let sendable_stream: SendableRecordBatchStream =
|
||||
Box::pin(RecordBatchStreamAdapter::new(schema.clone(), mock_stream));
|
||||
assert_eq!(
|
||||
collect_batch_sizes(sendable_stream, 3).await,
|
||||
vec![2, 3, 3, 1]
|
||||
);
|
||||
|
||||
let sendable_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
|
||||
schema,
|
||||
stream::iter(vec![Ok(sample_batch(2)), Ok(sample_batch(7))]),
|
||||
));
|
||||
assert_eq!(collect_batch_sizes(sendable_stream, 0).await, vec![2, 7]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user