mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-07 16:30:41 +00:00
Compare commits
4 Commits
justin/oss
...
feature/wa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0d3fadfc0 | ||
|
|
fd9dd390fc | ||
|
|
931f19b737 | ||
|
|
cde0814bbc |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.28.0-beta.1"
|
||||
current_version = "0.28.0-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
235
Cargo.lock
generated
235
Cargo.lock
generated
@@ -290,34 +290,6 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arrow-flight"
|
||||
version = "57.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "58c5b083668e6230eae3eab2fc4b5fb989974c845d0aa538dde61a4327c78675"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
"arrow-cast",
|
||||
"arrow-data",
|
||||
"arrow-ipc",
|
||||
"arrow-ord",
|
||||
"arrow-row",
|
||||
"arrow-schema",
|
||||
"arrow-select",
|
||||
"arrow-string",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"futures",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"prost",
|
||||
"prost-types",
|
||||
"tonic",
|
||||
"tonic-prost",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arrow-ipc"
|
||||
version = "57.3.0"
|
||||
@@ -1097,7 +1069,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core 0.4.5",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.4.0",
|
||||
@@ -1106,7 +1078,7 @@ dependencies = [
|
||||
"hyper 1.8.1",
|
||||
"hyper-util",
|
||||
"itoa",
|
||||
"matchit 0.7.3",
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
@@ -1124,31 +1096,6 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
||||
dependencies = [
|
||||
"axum-core 0.5.6",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"itoa",
|
||||
"matchit 0.8.4",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"serde_core",
|
||||
"sync_wrapper",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.4.5"
|
||||
@@ -1170,24 +1117,6 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
version = "0.4.0"
|
||||
@@ -1365,7 +1294,7 @@ version = "3.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c"
|
||||
dependencies = [
|
||||
"darling 0.23.0",
|
||||
"darling 0.20.11",
|
||||
"ident_case",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
@@ -2753,7 +2682,7 @@ dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2947,7 +2876,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3143,8 +3072,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "fsst"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"rand 0.9.2",
|
||||
@@ -3790,19 +3719,6 @@ dependencies = [
|
||||
"webpki-roots 1.0.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-timeout"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0"
|
||||
dependencies = [
|
||||
"hyper 1.8.1",
|
||||
"hyper-util",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.20"
|
||||
@@ -4132,7 +4048,7 @@ dependencies = [
|
||||
"portable-atomic",
|
||||
"portable-atomic-util",
|
||||
"serde_core",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4218,8 +4134,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4285,8 +4201,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-arrow"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4306,8 +4222,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-bitpacking"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"paste",
|
||||
@@ -4316,8 +4232,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-core"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4354,8 +4270,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datafusion"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4385,8 +4301,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datagen"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4404,8 +4320,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-encoding"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4442,8 +4358,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-file"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4475,8 +4391,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-index"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4540,8 +4456,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-io"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4585,8 +4501,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-linalg"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4602,8 +4518,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -4616,14 +4532,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace-impls"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-ipc",
|
||||
"arrow-schema",
|
||||
"async-trait",
|
||||
"axum 0.7.9",
|
||||
"axum",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"futures",
|
||||
@@ -4662,8 +4578,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-table"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4702,8 +4618,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-testing"
|
||||
version = "5.0.0-beta.5"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.5#d630106da5a238b3adfb8c5dea3b3921f3519945"
|
||||
version = "5.0.0-beta.4"
|
||||
source = "git+https://github.com/lance-format/lance.git?tag=v5.0.0-beta.4#d9068e76a301df9e21d7282419f24f61a11375ac"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
@@ -4714,7 +4630,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.28.0-beta.1"
|
||||
version = "0.28.0-beta.0"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
@@ -4722,7 +4638,6 @@ dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-cast",
|
||||
"arrow-data",
|
||||
"arrow-flight",
|
||||
"arrow-ipc",
|
||||
"arrow-ord",
|
||||
"arrow-schema",
|
||||
@@ -4776,7 +4691,6 @@ dependencies = [
|
||||
"pin-project",
|
||||
"polars",
|
||||
"polars-arrow",
|
||||
"prost",
|
||||
"rand 0.9.2",
|
||||
"random_word 0.4.3",
|
||||
"regex",
|
||||
@@ -4791,8 +4705,6 @@ dependencies = [
|
||||
"test-log",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tonic",
|
||||
"url",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
@@ -4800,7 +4712,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-nodejs"
|
||||
version = "0.28.0-beta.1"
|
||||
version = "0.28.0-beta.0"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4822,7 +4734,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lancedb-python"
|
||||
version = "0.31.0-beta.1"
|
||||
version = "0.31.0-beta.0"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -5097,12 +5009,6 @@ version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.10"
|
||||
@@ -5416,7 +5322,7 @@ version = "0.50.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6396,7 +6302,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"itertools 0.14.0",
|
||||
"itertools 0.11.0",
|
||||
"log",
|
||||
"multimap",
|
||||
"petgraph",
|
||||
@@ -6415,7 +6321,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.14.0",
|
||||
"itertools 0.11.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -6621,7 +6527,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"socket2 0.6.3",
|
||||
"tracing",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7163,7 +7069,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.12.1",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8182,7 +8088,7 @@ dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"once_cell",
|
||||
"rustix 1.1.4",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8464,46 +8370,6 @@ dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tonic"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum 0.8.8",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"h2 0.4.13",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.8.1",
|
||||
"hyper-timeout",
|
||||
"hyper-util",
|
||||
"percent-encoding",
|
||||
"pin-project",
|
||||
"socket2 0.6.3",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tonic-prost"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"prost",
|
||||
"tonic",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.3"
|
||||
@@ -8512,12 +8378,9 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"indexmap 2.13.0",
|
||||
"pin-project-lite",
|
||||
"slab",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
@@ -9030,7 +8893,7 @@ version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
30
Cargo.toml
30
Cargo.toml
@@ -15,20 +15,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.91.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=5.0.0-beta.5", default-features = false, "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=5.0.0-beta.5", default-features = false, "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=5.0.0-beta.5", default-features = false, "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=5.0.0-beta.5", "tag" = "v5.0.0-beta.5", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance = { "version" = "=5.0.0-beta.4", default-features = false, "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-core = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datagen = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-file = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-io = { "version" = "=5.0.0-beta.4", default-features = false, "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-index = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-linalg = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=5.0.0-beta.4", default-features = false, "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-table = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-testing = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-datafusion = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-encoding = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
lance-arrow = { "version" = "=5.0.0-beta.4", "tag" = "v5.0.0-beta.4", "git" = "https://github.com/lance-format/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "57.2", optional = false }
|
||||
@@ -39,8 +39,6 @@ arrow-ord = "57.2"
|
||||
arrow-schema = "57.2"
|
||||
arrow-select = "57.2"
|
||||
arrow-cast = "57.2"
|
||||
arrow-flight = { version = "57.2", features = ["flight-sql-experimental"] }
|
||||
tonic = "0.14"
|
||||
async-trait = "0"
|
||||
datafusion = { version = "52.1", default-features = false }
|
||||
datafusion-catalog = "52.1"
|
||||
|
||||
@@ -14,7 +14,7 @@ Add the following dependency to your `pom.xml`:
|
||||
<dependency>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-core</artifactId>
|
||||
<version>0.28.0-beta.1</version>
|
||||
<version>0.28.0-beta.0</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
|
||||
@@ -53,18 +53,3 @@ optional tlsConfig: TlsConfig;
|
||||
```ts
|
||||
optional userAgent: string;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### userId?
|
||||
|
||||
```ts
|
||||
optional userId: string;
|
||||
```
|
||||
|
||||
User identifier for tracking purposes.
|
||||
|
||||
This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||
It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||
Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||
variable that contains the user ID value.
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.28.0-beta.1</version>
|
||||
<version>0.28.0-beta.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.28.0-beta.1</version>
|
||||
<version>0.28.0-beta.0</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
@@ -28,7 +28,7 @@
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<arrow.version>15.0.0</arrow.version>
|
||||
<lance-core.version>5.0.0-beta.5</lance-core.version>
|
||||
<lance-core.version>5.0.0-beta.4</lance-core.version>
|
||||
<spotless.skip>false</spotless.skip>
|
||||
<spotless.version>2.30.0</spotless.version>
|
||||
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.28.0-beta.1"
|
||||
version = "0.28.0-beta.0"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
4
nodejs/package-lock.json
generated
4
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.28.0-beta.1",
|
||||
"version": "0.28.0-beta.0",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -92,13 +92,6 @@ pub struct ClientConfig {
|
||||
pub extra_headers: Option<HashMap<String, String>>,
|
||||
pub id_delimiter: Option<String>,
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
/// User identifier for tracking purposes.
|
||||
///
|
||||
/// This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||
/// It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||
/// Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||
/// variable that contains the user ID value.
|
||||
pub user_id: Option<String>,
|
||||
}
|
||||
|
||||
impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
|
||||
@@ -152,7 +145,6 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
id_delimiter: config.id_delimiter,
|
||||
tls_config: config.tls_config.map(Into::into),
|
||||
header_provider: None, // the header provider is set separately later
|
||||
user_id: config.user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.31.0-beta.1"
|
||||
current_version = "0.31.0-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.31.0-beta.1"
|
||||
version = "0.31.0-beta.0"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -45,7 +45,7 @@ repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
pylance = [
|
||||
"pylance>=5.0.0b5",
|
||||
"pylance>=5.0.0b3",
|
||||
]
|
||||
tests = [
|
||||
"aiohttp>=3.9.0",
|
||||
@@ -59,7 +59,7 @@ tests = [
|
||||
"polars>=0.19, <=1.3.0",
|
||||
"tantivy>=0.20.0",
|
||||
"pyarrow-stubs>=16.0",
|
||||
"pylance>=5.0.0b5",
|
||||
"pylance>=5.0.0b3",
|
||||
"requests>=2.31.0",
|
||||
"datafusion>=52,<53",
|
||||
]
|
||||
|
||||
@@ -151,9 +151,6 @@ class Connection(object):
|
||||
async def drop_all_tables(
|
||||
self, namespace_path: Optional[List[str]] = None
|
||||
) -> None: ...
|
||||
async def namespace_client_config(
|
||||
self,
|
||||
) -> Dict[str, Any]: ...
|
||||
|
||||
class Table:
|
||||
def name(self) -> str: ...
|
||||
|
||||
@@ -23,13 +23,11 @@ from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||
from lancedb.common import data_to_reader, sanitize_uri, validate_schema
|
||||
from lancedb.background_loop import LOOP
|
||||
from lance_namespace import (
|
||||
LanceNamespace,
|
||||
ListNamespacesResponse,
|
||||
CreateNamespaceResponse,
|
||||
DropNamespaceResponse,
|
||||
DescribeNamespaceResponse,
|
||||
ListTablesResponse,
|
||||
connect as namespace_connect,
|
||||
)
|
||||
|
||||
from . import __version__
|
||||
@@ -509,26 +507,6 @@ class DBConnection(EnforceOverrides):
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
def namespace_client(self) -> LanceNamespace:
|
||||
"""Get the equivalent namespace client for this connection.
|
||||
|
||||
For native storage connections, this returns a DirectoryNamespace
|
||||
pointing to the same root with the same storage options.
|
||||
|
||||
For namespace connections, this returns the backing namespace client.
|
||||
|
||||
For enterprise (remote) connections, this returns a RestNamespace
|
||||
with the same URI and authentication headers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceNamespace
|
||||
The namespace client for this connection.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"namespace_client is not supported for this connection type"
|
||||
)
|
||||
|
||||
|
||||
class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
@@ -1066,20 +1044,6 @@ class LanceDBConnection(DBConnection):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def namespace_client(self) -> LanceNamespace:
|
||||
"""Get the equivalent namespace client for this connection.
|
||||
|
||||
Returns a DirectoryNamespace pointing to the same root with the
|
||||
same storage options.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceNamespace
|
||||
The namespace client for this connection.
|
||||
"""
|
||||
return LOOP.run(self._conn.namespace_client())
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.15.1",
|
||||
removed_in="0.17",
|
||||
@@ -1752,25 +1716,6 @@ class AsyncConnection(object):
|
||||
namespace_path = []
|
||||
await self._inner.drop_all_tables(namespace_path=namespace_path)
|
||||
|
||||
async def namespace_client(self) -> LanceNamespace:
|
||||
"""Get the equivalent namespace client for this connection.
|
||||
|
||||
For native storage connections, this returns a DirectoryNamespace
|
||||
pointing to the same root with the same storage options.
|
||||
|
||||
For namespace connections, this returns the backing namespace client.
|
||||
|
||||
For enterprise (remote) connections, this returns a RestNamespace
|
||||
with the same URI and authentication headers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceNamespace
|
||||
The namespace client for this connection.
|
||||
"""
|
||||
config = await self._inner.namespace_client_config()
|
||||
return namespace_connect(config["impl"], config["properties"])
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.15.1",
|
||||
removed_in="0.17",
|
||||
|
||||
@@ -890,20 +890,6 @@ class LanceNamespaceDBConnection(DBConnection):
|
||||
pushdown_operations=self._pushdown_operations,
|
||||
)
|
||||
|
||||
@override
|
||||
def namespace_client(self) -> LanceNamespace:
|
||||
"""Get the namespace client for this connection.
|
||||
|
||||
For namespace connections, this returns the backing namespace client
|
||||
that was provided during construction.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceNamespace
|
||||
The namespace client for this connection.
|
||||
"""
|
||||
return self._namespace_client
|
||||
|
||||
|
||||
class AsyncLanceNamespaceDBConnection:
|
||||
"""
|
||||
@@ -1401,19 +1387,6 @@ class AsyncLanceNamespaceDBConnection:
|
||||
page_token=response.page_token,
|
||||
)
|
||||
|
||||
async def namespace_client(self) -> LanceNamespace:
|
||||
"""Get the namespace client for this connection.
|
||||
|
||||
For namespace connections, this returns the backing namespace client
|
||||
that was provided during construction.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceNamespace
|
||||
The namespace client for this connection.
|
||||
"""
|
||||
return self._namespace_client
|
||||
|
||||
|
||||
def connect_namespace(
|
||||
namespace_client_impl: str,
|
||||
|
||||
@@ -10,7 +10,6 @@ import sys
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -315,19 +314,6 @@ def _pydantic_type_to_arrow_type(tp: Any, field: FieldInfo) -> pa.DataType:
|
||||
return pa.list_(pa.list_(tp.value_arrow_type(), tp.dim()))
|
||||
# For regular Vector
|
||||
return pa.list_(tp.value_arrow_type(), tp.dim())
|
||||
if _safe_issubclass(tp, Enum):
|
||||
# Map Enum to the Arrow type of its value.
|
||||
# For string-valued enums, use dictionary encoding for efficiency.
|
||||
# For integer enums, use the native type.
|
||||
# Fall back to utf8 for mixed-type or empty enums.
|
||||
value_types = {type(m.value) for m in tp}
|
||||
if len(value_types) == 1:
|
||||
value_type = value_types.pop()
|
||||
if value_type is str:
|
||||
# Use dictionary encoding for string enums
|
||||
return pa.dictionary(pa.int32(), pa.utf8())
|
||||
return _py_type_to_arrow_type(value_type, field)
|
||||
return pa.utf8()
|
||||
return _py_type_to_arrow_type(tp, field)
|
||||
|
||||
|
||||
|
||||
@@ -145,33 +145,6 @@ class TlsConfig:
|
||||
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
"""Configuration for the LanceDB Cloud HTTP client.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
user_agent: str
|
||||
User agent string sent with requests.
|
||||
retry_config: RetryConfig
|
||||
Configuration for retrying failed requests.
|
||||
timeout_config: Optional[TimeoutConfig]
|
||||
Configuration for request timeouts.
|
||||
extra_headers: Optional[dict]
|
||||
Additional headers to include in requests.
|
||||
id_delimiter: Optional[str]
|
||||
The delimiter to use when constructing object identifiers.
|
||||
tls_config: Optional[TlsConfig]
|
||||
TLS/mTLS configuration for secure connections.
|
||||
header_provider: Optional[HeaderProvider]
|
||||
Provider for dynamic headers to be added to each request.
|
||||
user_id: Optional[str]
|
||||
User identifier for tracking purposes. This is sent as the
|
||||
`x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||
|
||||
This can also be set via the `LANCEDB_USER_ID` environment variable.
|
||||
Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another
|
||||
environment variable that contains the user ID value.
|
||||
"""
|
||||
|
||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
||||
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
||||
@@ -179,7 +152,6 @@ class ClientConfig:
|
||||
id_delimiter: Optional[str] = None
|
||||
tls_config: Optional[TlsConfig] = None
|
||||
header_provider: Optional["HeaderProvider"] = None
|
||||
user_id: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.retry_config, dict):
|
||||
|
||||
@@ -24,7 +24,6 @@ from ..common import DATA
|
||||
from ..db import DBConnection, LOOP
|
||||
from ..embeddings import EmbeddingFunctionConfig
|
||||
from lance_namespace import (
|
||||
LanceNamespace,
|
||||
CreateNamespaceResponse,
|
||||
DescribeNamespaceResponse,
|
||||
DropNamespaceResponse,
|
||||
@@ -571,19 +570,6 @@ class RemoteDBConnection(DBConnection):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def namespace_client(self) -> LanceNamespace:
|
||||
"""Get the equivalent namespace client for this connection.
|
||||
|
||||
Returns a RestNamespace with the same URI and authentication headers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceNamespace
|
||||
The namespace client for this connection.
|
||||
"""
|
||||
return LOOP.run(self._conn.namespace_client())
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the database."""
|
||||
self._conn.close()
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
|
||||
import re
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
import os
|
||||
|
||||
@@ -1049,59 +1048,3 @@ def test_clone_table_deep_clone_fails(tmp_path):
|
||||
source_uri = os.path.join(tmp_path, "source.lance")
|
||||
with pytest.raises(Exception, match="Deep clone is not yet implemented"):
|
||||
db.clone_table("cloned", source_uri, is_shallow=False)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Namespace client issues")
|
||||
def test_namespace_client_native_storage(tmp_path):
|
||||
"""Test namespace_client() returns DirectoryNamespace for native storage."""
|
||||
from lance.namespace import DirectoryNamespace
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
ns_client = db.namespace_client()
|
||||
|
||||
assert isinstance(ns_client, DirectoryNamespace)
|
||||
assert str(tmp_path) in ns_client.namespace_id()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Namespace client issues")
|
||||
def test_namespace_client_with_storage_options(tmp_path):
|
||||
"""Test namespace_client() preserves storage options."""
|
||||
from lance.namespace import DirectoryNamespace
|
||||
|
||||
storage_options = {"timeout": "10s"}
|
||||
db = lancedb.connect(tmp_path, storage_options=storage_options)
|
||||
ns_client = db.namespace_client()
|
||||
|
||||
assert isinstance(ns_client, DirectoryNamespace)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Namespace client issues")
|
||||
def test_namespace_client_operations(tmp_path):
|
||||
"""Test that namespace_client() returns a functional namespace client."""
|
||||
db = lancedb.connect(tmp_path)
|
||||
ns_client = db.namespace_client()
|
||||
|
||||
# Create a table through the main db connection
|
||||
data = [{"id": 1, "text": "hello", "vector": [1.0, 2.0]}]
|
||||
db.create_table("test_table", data=data)
|
||||
|
||||
# Verify the namespace client can see the table
|
||||
from lance_namespace import ListTablesRequest
|
||||
|
||||
# id=[] means root namespace
|
||||
response = ns_client.list_tables(ListTablesRequest(id=[]))
|
||||
# Tables can be strings or objects with name attribute
|
||||
table_names = [t.name if hasattr(t, "name") else t for t in response.tables]
|
||||
assert "test_table" in table_names
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Namespace client issues")
|
||||
def test_namespace_client_namespace_connection(tmp_path):
|
||||
"""Test namespace_client() returns the backing client for namespace connections."""
|
||||
from lance.namespace import DirectoryNamespace
|
||||
|
||||
db = lancedb.connect_namespace("dir", {"root": str(tmp_path)})
|
||||
ns_client = db.namespace_client()
|
||||
|
||||
assert isinstance(ns_client, DirectoryNamespace)
|
||||
assert str(tmp_path) in ns_client.namespace_id()
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import json
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pyarrow as pa
|
||||
@@ -674,29 +673,3 @@ async def test_aliases_in_lance_model_async(mem_db_async):
|
||||
assert hasattr(model, "name")
|
||||
assert hasattr(model, "distance")
|
||||
assert model.distance < 0.01
|
||||
|
||||
|
||||
def test_enum_types():
|
||||
"""Enum fields should map to the Arrow type of their value (issue #1846)."""
|
||||
|
||||
class StrStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
DONE = "done"
|
||||
|
||||
class IntPriority(int, Enum):
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
|
||||
class TestModel(pydantic.BaseModel):
|
||||
status: StrStatus
|
||||
priority: IntPriority
|
||||
opt_status: Optional[StrStatus] = None
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
|
||||
assert schema.field("status").type == pa.dictionary(pa.int32(), pa.utf8())
|
||||
assert schema.field("priority").type == pa.int64()
|
||||
assert schema.field("opt_status").type == pa.dictionary(pa.int32(), pa.utf8())
|
||||
assert schema.field("opt_status").nullable
|
||||
|
||||
@@ -474,25 +474,6 @@ impl Connection {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the configuration for constructing an equivalent namespace client.
|
||||
/// Returns a dict with:
|
||||
/// - "impl": "dir" for DirectoryNamespace, "rest" for RestNamespace
|
||||
/// - "properties": configuration properties for the namespace
|
||||
#[pyo3(signature = ())]
|
||||
pub fn namespace_client_config(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
let py = self_.py();
|
||||
future_into_py(py, async move {
|
||||
let (impl_type, properties) = inner.namespace_client_config().await.infer_error()?;
|
||||
Python::attach(|py| -> PyResult<Py<PyDict>> {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("impl", impl_type)?;
|
||||
dict.set_item("properties", properties)?;
|
||||
Ok(dict.unbind())
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
@@ -547,7 +528,6 @@ pub struct PyClientConfig {
|
||||
id_delimiter: Option<String>,
|
||||
tls_config: Option<PyClientTlsConfig>,
|
||||
header_provider: Option<Py<PyAny>>,
|
||||
user_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
@@ -632,7 +612,6 @@ impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
||||
id_delimiter: value.id_delimiter,
|
||||
tls_config: value.tls_config.map(Into::into),
|
||||
header_provider,
|
||||
user_id: value.user_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.28.0-beta.1"
|
||||
version = "0.28.0-beta.0"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -88,10 +88,6 @@ candle-transformers = { version = "0.9.1", optional = true }
|
||||
candle-nn = { version = "0.9.1", optional = true }
|
||||
tokenizers = { version = "0.19.1", optional = true }
|
||||
semver = { workspace = true }
|
||||
# For flight feature (Arrow Flight SQL server)
|
||||
arrow-flight = { workspace = true, optional = true }
|
||||
tonic = { workspace = true, optional = true }
|
||||
prost = { version = "0.14", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1"
|
||||
@@ -108,7 +104,6 @@ datafusion.workspace = true
|
||||
http-body = "1" # Matching reqwest
|
||||
rstest = "0.23.0"
|
||||
test-log = "0.2"
|
||||
tokio-stream = "0.1"
|
||||
|
||||
|
||||
[features]
|
||||
@@ -136,7 +131,6 @@ sentence-transformers = [
|
||||
"dep:candle-nn",
|
||||
"dep:tokenizers",
|
||||
]
|
||||
flight = ["dep:arrow-flight", "dep:tonic", "dep:prost"]
|
||||
|
||||
[[example]]
|
||||
name = "openai"
|
||||
@@ -163,9 +157,5 @@ name = "ivf_pq"
|
||||
name = "hybrid_search"
|
||||
required-features = ["sentence-transformers"]
|
||||
|
||||
[[example]]
|
||||
name = "flight_sql"
|
||||
required-features = ["flight"]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Example: LanceDB Arrow Flight SQL Server
|
||||
//!
|
||||
//! This example demonstrates how to:
|
||||
//! 1. Create a LanceDB database with sample data
|
||||
//! 2. Start an Arrow Flight SQL server
|
||||
//! 3. Connect with a Flight SQL client
|
||||
//! 4. Run SQL queries including vector_search and fts table functions
|
||||
//!
|
||||
//! Run with: `cargo run --features flight --example flight_sql`
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance_arrow::FixedSizeListArrayExt;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Initialize logging if env_logger is available
|
||||
let _ = std::env::var("RUST_LOG").ok();
|
||||
|
||||
// 1. Create an in-memory LanceDB database
|
||||
let db = lancedb::connect("memory://flight_sql_demo")
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
// 2. Create a table with text and vector data
|
||||
let dim = 4i32;
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
|
||||
true,
|
||||
),
|
||||
]));
|
||||
|
||||
let ids = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let texts = StringArray::from(vec![
|
||||
"the quick brown fox jumps over the lazy dog",
|
||||
"a fast red fox leaps across the sleeping hound",
|
||||
"machine learning models process natural language",
|
||||
"neural networks learn from training data",
|
||||
"the brown dog chases the red fox through the forest",
|
||||
"deep learning algorithms improve with more data",
|
||||
"a lazy cat sleeps on the warm windowsill",
|
||||
"vector databases enable fast similarity search",
|
||||
]);
|
||||
let flat_values = Float32Array::from(vec![
|
||||
1.0, 0.0, 0.0, 0.0, // fox-like
|
||||
0.9, 0.1, 0.0, 0.0, // similar to fox
|
||||
0.0, 1.0, 0.0, 0.0, // ML-like
|
||||
0.0, 0.9, 0.1, 0.0, // similar to ML
|
||||
0.7, 0.3, 0.0, 0.0, // fox+dog mix
|
||||
0.0, 0.8, 0.2, 0.0, // ML-like
|
||||
0.1, 0.0, 0.0, 0.9, // cat-like
|
||||
0.0, 0.5, 0.5, 0.0, // tech-like
|
||||
]);
|
||||
let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim)?;
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(ids), Arc::new(texts), Arc::new(vector_array)],
|
||||
)?;
|
||||
|
||||
let table = db.create_table("documents", batch).execute().await?;
|
||||
|
||||
// 3. Create indices
|
||||
println!("Creating FTS index on 'text' column...");
|
||||
table
|
||||
.create_index(&["text"], lancedb::index::Index::FTS(Default::default()))
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
println!("Database ready with {} rows", table.count_rows(None).await?);
|
||||
|
||||
// 4. Start Flight SQL server
|
||||
let addr = "0.0.0.0:50051".parse()?;
|
||||
println!("Starting Arrow Flight SQL server on {}...", addr);
|
||||
println!();
|
||||
println!("Connect with any ADBC Flight SQL client:");
|
||||
println!(" URI: grpc://localhost:50051");
|
||||
println!();
|
||||
println!("Example SQL queries:");
|
||||
println!(" SELECT * FROM documents LIMIT 5;");
|
||||
println!(" SELECT * FROM vector_search('documents', '[1.0, 0.0, 0.0, 0.0]', 3);");
|
||||
println!(
|
||||
" SELECT * FROM fts('documents', '{{\"match\": {{\"column\": \"text\", \"terms\": \"fox\"}}}}');"
|
||||
);
|
||||
println!();
|
||||
|
||||
lancedb::flight::serve(db, addr).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -30,7 +30,10 @@ use crate::error::{Error, Result};
|
||||
#[cfg(feature = "remote")]
|
||||
use crate::remote::{
|
||||
client::ClientConfig,
|
||||
db::{OPT_REMOTE_API_KEY, OPT_REMOTE_HOST_OVERRIDE, OPT_REMOTE_REGION},
|
||||
db::{
|
||||
OPT_REMOTE_API_KEY, OPT_REMOTE_HOST_OVERRIDE, OPT_REMOTE_REGION,
|
||||
OPT_REMOTE_WAL_HOST_OVERRIDE,
|
||||
},
|
||||
};
|
||||
use lance::io::ObjectStoreParams;
|
||||
pub use lance_encoding::version::LanceFileVersion;
|
||||
@@ -541,16 +544,6 @@ impl Connection {
|
||||
self.internal.namespace_client().await
|
||||
}
|
||||
|
||||
/// Get the configuration for constructing an equivalent namespace client.
|
||||
/// Returns (impl_type, properties) where:
|
||||
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
|
||||
/// - properties: configuration properties for the namespace
|
||||
pub async fn namespace_client_config(
|
||||
&self,
|
||||
) -> Result<(String, std::collections::HashMap<String, String>)> {
|
||||
self.internal.namespace_client_config().await
|
||||
}
|
||||
|
||||
/// List tables with pagination support
|
||||
pub async fn list_tables(&self, request: ListTablesRequest) -> Result<ListTablesResponse> {
|
||||
self.internal.list_tables(request).await
|
||||
@@ -676,6 +669,24 @@ impl ConnectBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the WAL host override for routing merge_insert requests
|
||||
/// to a separate WAL/ingest service.
|
||||
///
|
||||
/// This option is only used when connecting to LanceDB Cloud (db:// URIs)
|
||||
/// and will be ignored for other URIs.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `wal_host_override` - The WAL host override to use for the connection
|
||||
#[cfg(feature = "remote")]
|
||||
pub fn wal_host_override(mut self, wal_host_override: &str) -> Self {
|
||||
self.request.options.insert(
|
||||
OPT_REMOTE_WAL_HOST_OVERRIDE.to_string(),
|
||||
wal_host_override.to_string(),
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the database specific options
|
||||
///
|
||||
/// See [crate::database::listing::ListingDatabaseOptions] for the options available for
|
||||
@@ -829,6 +840,7 @@ impl ConnectBuilder {
|
||||
&api_key,
|
||||
®ion,
|
||||
options.host_override,
|
||||
options.wal_host_override,
|
||||
self.request.client_config,
|
||||
storage_options.into(),
|
||||
)?);
|
||||
|
||||
@@ -265,13 +265,4 @@ pub trait Database:
|
||||
/// For ListingDatabase, it is the equivalent DirectoryNamespace.
|
||||
/// For RemoteDatabase, it is the equivalent RestNamespace.
|
||||
async fn namespace_client(&self) -> Result<Arc<dyn LanceNamespace>>;
|
||||
|
||||
/// Get the configuration for constructing an equivalent namespace client.
|
||||
/// Returns (impl_type, properties) where:
|
||||
/// - impl_type: "dir" for DirectoryNamespace, "rest" for RestNamespace
|
||||
/// - properties: configuration properties for the namespace
|
||||
///
|
||||
/// This is useful for Python bindings where we want to return a Python
|
||||
/// namespace object rather than a Rust trait object.
|
||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)>;
|
||||
}
|
||||
|
||||
@@ -1099,15 +1099,6 @@ impl Database for ListingDatabase {
|
||||
})?;
|
||||
Ok(Arc::new(namespace) as Arc<dyn lance_namespace::LanceNamespace>)
|
||||
}
|
||||
|
||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert("root".to_string(), self.uri.clone());
|
||||
for (key, value) in &self.storage_options {
|
||||
properties.insert(format!("storage.{}", key), value.clone());
|
||||
}
|
||||
Ok(("dir".to_string(), properties))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -45,10 +45,6 @@ pub struct LanceNamespaceDatabase {
|
||||
uri: String,
|
||||
// Operations to push down to the namespace server
|
||||
pushdown_operations: HashSet<PushdownOperation>,
|
||||
// Namespace implementation type (e.g., "dir", "rest")
|
||||
ns_impl: String,
|
||||
// Namespace properties used to construct the namespace client
|
||||
ns_properties: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl LanceNamespaceDatabase {
|
||||
@@ -78,8 +74,6 @@ impl LanceNamespaceDatabase {
|
||||
session,
|
||||
uri: format!("namespace://{}", ns_impl),
|
||||
pushdown_operations,
|
||||
ns_impl: ns_impl.to_string(),
|
||||
ns_properties,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -351,10 +345,6 @@ impl Database for LanceNamespaceDatabase {
|
||||
async fn namespace_client(&self) -> Result<Arc<dyn LanceNamespace>> {
|
||||
Ok(self.namespace.clone())
|
||||
}
|
||||
|
||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
||||
Ok((self.ns_impl.clone(), self.ns_properties.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,606 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Arrow Flight SQL server for LanceDB.
|
||||
//!
|
||||
//! This module provides an Arrow Flight SQL server that exposes a LanceDB
|
||||
//! [`Connection`] over the Flight SQL protocol. Any ADBC, ODBC (via bridge),
|
||||
//! or JDBC Flight SQL client can connect and run SQL queries — including
|
||||
//! LanceDB's search table functions (`vector_search`, `fts`, `hybrid_search`).
|
||||
//!
|
||||
//! # Quick Start
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
|
||||
//! let db = lancedb::connect("data/my-db").execute().await?;
|
||||
//! let addr = "0.0.0.0:50051".parse()?;
|
||||
//! lancedb::flight::serve(db, addr).await?;
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{ArrayRef, RecordBatch, StringArray};
|
||||
use arrow_flight::encode::FlightDataEncoderBuilder;
|
||||
use arrow_flight::error::FlightError;
|
||||
use arrow_flight::flight_service_server::FlightServiceServer;
|
||||
use arrow_flight::sql::server::FlightSqlService;
|
||||
use arrow_flight::sql::{
|
||||
Any, CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables,
|
||||
CommandStatementQuery, SqlInfo, TicketStatementQuery,
|
||||
};
|
||||
use arrow_flight::{FlightDescriptor, FlightEndpoint, FlightInfo, Ticket};
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
|
||||
use datafusion::prelude::SessionContext;
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
||||
use futures::StreamExt;
|
||||
use futures::stream;
|
||||
use log;
|
||||
use prost::Message;
|
||||
use tonic::transport::Server;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use crate::connection::Connection;
|
||||
use crate::table::datafusion::BaseTableAdapter;
|
||||
use crate::table::datafusion::udtf::fts::FtsTableFunction;
|
||||
use crate::table::datafusion::udtf::hybrid_search::HybridSearchTableFunction;
|
||||
use crate::table::datafusion::udtf::vector_search::VectorSearchTableFunction;
|
||||
use crate::table::datafusion::udtf::{SearchQuery, TableResolver};
|
||||
|
||||
/// Start an Arrow Flight SQL server exposing the given LanceDB connection.
|
||||
///
|
||||
/// This is a convenience function that creates a server and starts listening.
|
||||
/// It blocks until the server is shut down.
|
||||
pub async fn serve(connection: Connection, addr: SocketAddr) -> crate::Result<()> {
|
||||
let service = LanceFlightSqlService::try_new(connection).await?;
|
||||
let flight_svc = FlightServiceServer::new(service);
|
||||
|
||||
Server::builder()
|
||||
.add_service(flight_svc)
|
||||
.serve(addr)
|
||||
.await
|
||||
.map_err(|e| crate::Error::Runtime {
|
||||
message: format!("Flight SQL server error: {}", e),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A table resolver that looks up tables from a pre-built HashMap.
|
||||
#[derive(Debug)]
|
||||
struct ConnectionTableResolver {
|
||||
tables: HashMap<String, Arc<BaseTableAdapter>>,
|
||||
}
|
||||
|
||||
impl TableResolver for ConnectionTableResolver {
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
search: Option<SearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
let adapter = self
|
||||
.tables
|
||||
.get(name)
|
||||
.ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
|
||||
|
||||
match search {
|
||||
None => Ok(adapter.clone() as Arc<dyn TableProvider>),
|
||||
Some(SearchQuery::Fts(fts)) => Ok(Arc::new(adapter.with_fts_query(fts))),
|
||||
Some(SearchQuery::Vector(vq)) => Ok(Arc::new(adapter.with_vector_query(vq))),
|
||||
Some(SearchQuery::Hybrid { fts, vector }) => {
|
||||
Ok(Arc::new(adapter.with_hybrid_query(fts, vector)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Arrow Flight SQL service backed by a LanceDB connection.
|
||||
struct LanceFlightSqlService {
|
||||
/// Kept for future use (e.g., refreshing table list).
|
||||
_connection: Connection,
|
||||
/// Pre-built table adapters (refreshed on creation)
|
||||
tables: HashMap<String, Arc<BaseTableAdapter>>,
|
||||
}
|
||||
|
||||
impl LanceFlightSqlService {
|
||||
async fn try_new(connection: Connection) -> crate::Result<Self> {
|
||||
let table_names = connection.table_names().execute().await?;
|
||||
let mut tables = HashMap::new();
|
||||
|
||||
for name in &table_names {
|
||||
let table = connection.open_table(name).execute().await?;
|
||||
let adapter = BaseTableAdapter::try_new(table.base_table().clone()).await?;
|
||||
tables.insert(name.clone(), Arc::new(adapter));
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"LanceDB Flight SQL service initialized with {} tables",
|
||||
tables.len()
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
_connection: connection,
|
||||
tables,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a DataFusion SessionContext with all tables and UDTFs registered.
|
||||
fn create_session_context(&self) -> SessionContext {
|
||||
let ctx = SessionContext::new();
|
||||
|
||||
// Register all tables
|
||||
for (name, adapter) in &self.tables {
|
||||
if let Err(e) = ctx.register_table(name, adapter.clone() as Arc<dyn TableProvider>) {
|
||||
log::warn!("Failed to register table '{}': {}", name, e);
|
||||
}
|
||||
}
|
||||
|
||||
// Create resolver for UDTFs
|
||||
let resolver = Arc::new(ConnectionTableResolver {
|
||||
tables: self.tables.clone(),
|
||||
});
|
||||
|
||||
// Register search UDTFs
|
||||
ctx.register_udtf("fts", Arc::new(FtsTableFunction::new(resolver.clone())));
|
||||
ctx.register_udtf(
|
||||
"vector_search",
|
||||
Arc::new(VectorSearchTableFunction::new(resolver.clone())),
|
||||
);
|
||||
ctx.register_udtf(
|
||||
"hybrid_search",
|
||||
Arc::new(HybridSearchTableFunction::new(resolver)),
|
||||
);
|
||||
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Execute a SQL query and return the results as a stream of FlightData.
|
||||
async fn execute_sql(
|
||||
&self,
|
||||
sql: &str,
|
||||
) -> Result<
|
||||
Pin<Box<dyn futures::Stream<Item = Result<arrow_flight::FlightData, Status>> + Send>>,
|
||||
Status,
|
||||
> {
|
||||
let ctx = self.create_session_context();
|
||||
|
||||
let df = ctx
|
||||
.sql(sql)
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("SQL planning error: {}", e)))?;
|
||||
|
||||
let schema: SchemaRef = df.schema().inner().clone();
|
||||
let stream = df
|
||||
.execute_stream()
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("SQL execution error: {}", e)))?;
|
||||
|
||||
// Use FlightDataEncoderBuilder to properly encode batches with schema
|
||||
let batch_stream = stream.map(|r| r.map_err(|e| FlightError::ExternalError(Box::new(e))));
|
||||
let flight_data_stream = FlightDataEncoderBuilder::new()
|
||||
.with_schema(schema)
|
||||
.build(batch_stream)
|
||||
.map(|result| result.map_err(|e| Status::internal(format!("Encoding error: {}", e))));
|
||||
|
||||
Ok(Box::pin(flight_data_stream))
|
||||
}
|
||||
|
||||
/// Encode a single RecordBatch into a FlightData stream (schema + data).
|
||||
fn batch_to_flight_stream(
|
||||
batch: RecordBatch,
|
||||
) -> Pin<Box<dyn futures::Stream<Item = Result<arrow_flight::FlightData, Status>> + Send>> {
|
||||
let schema = batch.schema();
|
||||
let stream = FlightDataEncoderBuilder::new()
|
||||
.with_schema(schema)
|
||||
.build(stream::once(async move { Ok(batch) }))
|
||||
.map(|result| result.map_err(|e| Status::internal(format!("Encoding error: {}", e))));
|
||||
Box::pin(stream)
|
||||
}
|
||||
|
||||
/// Get the schema for a SQL query without executing it.
|
||||
async fn get_sql_schema(&self, sql: &str) -> Result<ArrowSchema, Status> {
|
||||
let ctx = self.create_session_context();
|
||||
let df = ctx
|
||||
.sql(sql)
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("SQL planning error: {}", e)))?;
|
||||
Ok(df.schema().inner().as_ref().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
/// Arrow Flight SQL service implementation backed by LanceDB.
///
/// Metadata handlers (catalogs / schemas / tables / table types) report a
/// single fixed catalog ("lancedb") and schema ("default"); statement
/// handlers delegate to the service's SQL helpers (`get_sql_schema`,
/// `execute_sql`).
impl FlightSqlService for LanceFlightSqlService {
    type FlightService = LanceFlightSqlService;

    /// Handle SQL query: return FlightInfo with schema and ticket.
    ///
    /// The ticket embeds the raw SQL text (a `TicketStatementQuery` wrapped
    /// in a protobuf `Any`), so `do_get_statement` can recover and execute
    /// the query without any server-side session state.
    async fn get_flight_info_statement(
        &self,
        query: CommandStatementQuery,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        let sql = query.query;
        log::info!("get_flight_info_statement: {}", sql);

        // Plan the query once up front so the client learns the result
        // schema before fetching any data.
        let schema = self.get_sql_schema(&sql).await?;

        // Encode the query as an Any-wrapped TicketStatementQuery for do_get_statement
        let ticket = TicketStatementQuery {
            statement_handle: sql.into_bytes().into(),
        };
        let any_msg = Any::pack(&ticket)
            .map_err(|e| Status::internal(format!("Ticket encoding error: {}", e)))?;
        let mut ticket_bytes = Vec::new();
        any_msg
            .encode(&mut ticket_bytes)
            .map_err(|e| Status::internal(format!("Ticket encoding error: {}", e)))?;

        // Single endpoint: this server both plans and serves the data.
        let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
        let flight_info = FlightInfo::new()
            .try_with_schema(&schema)
            .map_err(|e| Status::internal(format!("Schema error: {}", e)))?
            .with_endpoint(endpoint)
            .with_descriptor(request.into_inner());

        Ok(Response::new(flight_info))
    }

    /// Execute a SQL query and stream results.
    ///
    /// The SQL text is recovered from the ticket's `statement_handle`
    /// (UTF-8 bytes written by `get_flight_info_statement`).
    async fn do_get_statement(
        &self,
        ticket: TicketStatementQuery,
        _request: Request<Ticket>,
    ) -> Result<
        Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
        Status,
    > {
        let sql = String::from_utf8(ticket.statement_handle.to_vec())
            .map_err(|e| Status::internal(format!("Invalid ticket: {}", e)))?;
        log::info!("do_get_statement: {}", sql);

        let stream = self.execute_sql(&sql).await?;
        Ok(Response::new(stream))
    }

    /// List tables in the database.
    ///
    /// Returns the standard four-column Flight SQL "tables" schema; the
    /// incoming query's filters are ignored (`do_get_tables` always lists
    /// every table).
    async fn get_flight_info_tables(
        &self,
        _query: CommandGetTables,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        let schema = ArrowSchema::new(vec![
            Field::new("catalog_name", DataType::Utf8, true),
            Field::new("db_schema_name", DataType::Utf8, true),
            Field::new("table_name", DataType::Utf8, false),
            Field::new("table_type", DataType::Utf8, false),
        ]);

        // The ticket carries a default command; do_get_tables does not use it.
        let cmd = CommandGetTables::default();
        let any_msg =
            Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
        let mut ticket_bytes = Vec::new();
        any_msg
            .encode(&mut ticket_bytes)
            .map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;

        let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
        let flight_info = FlightInfo::new()
            .try_with_schema(&schema)
            .map_err(|e| Status::internal(format!("Schema error: {}", e)))?
            .with_endpoint(endpoint)
            .with_descriptor(request.into_inner());

        Ok(Response::new(flight_info))
    }

    /// Produce one record batch listing every table known to the service,
    /// all reported under catalog "lancedb" / schema "default" with type
    /// "TABLE".
    async fn do_get_tables(
        &self,
        _query: CommandGetTables,
        _request: Request<Ticket>,
    ) -> Result<
        Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
        Status,
    > {
        let table_names: Vec<&str> = self.tables.keys().map(|s| s.as_str()).collect();
        let num_tables = table_names.len();

        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("catalog_name", DataType::Utf8, true),
            Field::new("db_schema_name", DataType::Utf8, true),
            Field::new("table_name", DataType::Utf8, false),
            Field::new("table_type", DataType::Utf8, false),
        ]));

        // Catalog/schema/type are constant across all rows.
        let catalog_names: ArrayRef =
            Arc::new(StringArray::from(vec![Some("lancedb"); num_tables]));
        let schema_names: ArrayRef = Arc::new(StringArray::from(vec![Some("default"); num_tables]));
        let table_name_array: ArrayRef = Arc::new(StringArray::from(table_names));
        let table_types: ArrayRef = Arc::new(StringArray::from(vec!["TABLE"; num_tables]));

        let batch = RecordBatch::try_new(
            schema,
            vec![catalog_names, schema_names, table_name_array, table_types],
        )
        .map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;

        Ok(Response::new(Self::batch_to_flight_stream(batch)))
    }

    /// List table types.
    async fn get_flight_info_table_types(
        &self,
        _query: CommandGetTableTypes,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        let schema = ArrowSchema::new(vec![Field::new("table_type", DataType::Utf8, false)]);

        let cmd = CommandGetTableTypes::default();
        let any_msg =
            Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
        let mut ticket_bytes = Vec::new();
        any_msg
            .encode(&mut ticket_bytes)
            .map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;

        let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
        let flight_info = FlightInfo::new()
            .try_with_schema(&schema)
            .map_err(|e| Status::internal(format!("Schema error: {}", e)))?
            .with_endpoint(endpoint)
            .with_descriptor(request.into_inner());

        Ok(Response::new(flight_info))
    }

    /// Return the single table type this service exposes: "TABLE".
    async fn do_get_table_types(
        &self,
        _query: CommandGetTableTypes,
        _request: Request<Ticket>,
    ) -> Result<
        Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
        Status,
    > {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "table_type",
            DataType::Utf8,
            false,
        )]));
        let table_types: ArrayRef = Arc::new(StringArray::from(vec!["TABLE"]));
        let batch = RecordBatch::try_new(schema, vec![table_types])
            .map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;

        Ok(Response::new(Self::batch_to_flight_stream(batch)))
    }

    /// List catalogs.
    async fn get_flight_info_catalogs(
        &self,
        _query: CommandGetCatalogs,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        let schema = ArrowSchema::new(vec![Field::new("catalog_name", DataType::Utf8, false)]);

        let cmd = CommandGetCatalogs::default();
        let any_msg =
            Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
        let mut ticket_bytes = Vec::new();
        any_msg
            .encode(&mut ticket_bytes)
            .map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;

        let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
        let flight_info = FlightInfo::new()
            .try_with_schema(&schema)
            .map_err(|e| Status::internal(format!("Schema error: {}", e)))?
            .with_endpoint(endpoint)
            .with_descriptor(request.into_inner());

        Ok(Response::new(flight_info))
    }

    /// Return the single catalog this service exposes: "lancedb".
    async fn do_get_catalogs(
        &self,
        _query: CommandGetCatalogs,
        _request: Request<Ticket>,
    ) -> Result<
        Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
        Status,
    > {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "catalog_name",
            DataType::Utf8,
            false,
        )]));
        let catalogs: ArrayRef = Arc::new(StringArray::from(vec!["lancedb"]));
        let batch = RecordBatch::try_new(schema, vec![catalogs])
            .map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;

        Ok(Response::new(Self::batch_to_flight_stream(batch)))
    }

    /// List schemas.
    async fn get_flight_info_schemas(
        &self,
        _query: CommandGetDbSchemas,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        let schema = ArrowSchema::new(vec![
            Field::new("catalog_name", DataType::Utf8, true),
            Field::new("db_schema_name", DataType::Utf8, false),
        ]);

        let cmd = CommandGetDbSchemas::default();
        let any_msg =
            Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
        let mut ticket_bytes = Vec::new();
        any_msg
            .encode(&mut ticket_bytes)
            .map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;

        let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
        let flight_info = FlightInfo::new()
            .try_with_schema(&schema)
            .map_err(|e| Status::internal(format!("Schema error: {}", e)))?
            .with_endpoint(endpoint)
            .with_descriptor(request.into_inner());

        Ok(Response::new(flight_info))
    }

    /// Return the single schema this service exposes:
    /// catalog "lancedb" / schema "default".
    async fn do_get_schemas(
        &self,
        _query: CommandGetDbSchemas,
        _request: Request<Ticket>,
    ) -> Result<
        Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
        Status,
    > {
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("catalog_name", DataType::Utf8, true),
            Field::new("db_schema_name", DataType::Utf8, false),
        ]));
        let catalogs: ArrayRef = Arc::new(StringArray::from(vec![Some("lancedb")]));
        let schemas: ArrayRef = Arc::new(StringArray::from(vec!["default"]));
        let batch = RecordBatch::try_new(schema, vec![catalogs, schemas])
            .map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;

        Ok(Response::new(Self::batch_to_flight_stream(batch)))
    }

    /// No-op: this service does not record SqlInfo metadata.
    async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float32Array, Int32Array};
    use arrow_flight::sql::client::FlightSqlServiceClient;
    use futures::TryStreamExt;
    use lance_arrow::FixedSizeListArrayExt;
    use std::time::Duration;

    /// Build an in-memory LanceDB connection containing one table,
    /// "test_table", with 3 rows: an `id`, a `text` column (FTS-indexed),
    /// and a 4-dimensional float vector.
    async fn create_test_db() -> Connection {
        let db = crate::connect("memory://flight_test")
            .execute()
            .await
            .unwrap();

        let dim = 4i32;
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
                true,
            ),
        ]));

        let ids = Int32Array::from(vec![1, 2, 3]);
        let texts = StringArray::from(vec!["hello world", "foo bar", "baz qux"]);
        // 3 rows x 4 dims, row-major: unit vectors along the first 3 axes.
        let flat_values = Float32Array::from(vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
        ]);
        let vectors =
            arrow_array::FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();

        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(ids), Arc::new(texts), Arc::new(vectors)],
        )
        .unwrap();

        let table = db
            .create_table("test_table", batch)
            .execute()
            .await
            .unwrap();

        // Create FTS index
        table
            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
            .execute()
            .await
            .unwrap();

        db
    }

    /// End-to-end test: spin up a real Flight SQL server on a random local
    /// port, run a SELECT through the Flight SQL client, and verify row
    /// count and result schema.
    #[tokio::test]
    async fn test_flight_sql_basic_query() {
        let db = create_test_db().await;
        let service = LanceFlightSqlService::try_new(db).await.unwrap();
        let flight_svc = FlightServiceServer::new(service);

        // Start server on a random port
        // (binding happens here, before spawning, so the port is known).
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server_handle = tokio::spawn(async move {
            Server::builder()
                .add_service(flight_svc)
                .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
                .await
                .unwrap();
        });

        // Give server a moment to start
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Connect client
        let channel = tonic::transport::Channel::from_shared(format!("http://{}", addr))
            .unwrap()
            .connect()
            .await
            .unwrap();
        let mut client = FlightSqlServiceClient::new(channel);

        // Execute SQL query
        let flight_info = client
            .execute("SELECT id, text FROM test_table".to_string(), None)
            .await
            .unwrap();

        // Fetch results using the FlightSql client's do_get
        let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
        let flight_stream = client.do_get(ticket).await.unwrap();

        let batches: Vec<RecordBatch> = flight_stream.try_collect().await.unwrap();

        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3, "Should return all 3 rows");

        // Verify schema
        if let Some(batch) = batches.first() {
            assert!(batch.schema().column_with_name("id").is_some());
            assert!(batch.schema().column_with_name("text").is_some());
        }

        // Tear down the background server task.
        server_handle.abort();
    }

    /// The service should discover exactly the tables present in the
    /// connection at construction time.
    #[tokio::test]
    async fn test_flight_sql_table_listing() {
        let db = create_test_db().await;
        let service = LanceFlightSqlService::try_new(db).await.unwrap();

        assert!(service.tables.contains_key("test_table"));
        assert_eq!(service.tables.len(), 1);
    }

    /// The DataFusion session context built by the service should be able
    /// to query registered tables directly (no Flight transport involved).
    #[tokio::test]
    async fn test_flight_sql_session_context() {
        let db = create_test_db().await;
        let service = LanceFlightSqlService::try_new(db).await.unwrap();
        let ctx = service.create_session_context();

        // Test that we can execute a simple query
        let df = ctx.sql("SELECT * FROM test_table LIMIT 1").await.unwrap();
        let results = df.collect().await.unwrap();
        assert_eq!(results[0].num_rows(), 1);
    }
}
|
||||
@@ -170,8 +170,6 @@ pub mod dataloader;
|
||||
pub mod embeddings;
|
||||
pub mod error;
|
||||
pub mod expr;
|
||||
#[cfg(feature = "flight")]
|
||||
pub mod flight;
|
||||
pub mod index;
|
||||
pub mod io;
|
||||
pub mod ipc;
|
||||
|
||||
@@ -52,13 +52,6 @@ pub struct ClientConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
/// Provider for custom headers to be added to each request
|
||||
pub header_provider: Option<Arc<dyn HeaderProvider>>,
|
||||
/// User identifier for tracking purposes.
|
||||
///
|
||||
/// This is sent as the `x-lancedb-user-id` header in requests to LanceDB Cloud/Enterprise.
|
||||
/// It can be set directly, or via the `LANCEDB_USER_ID` environment variable.
|
||||
/// Alternatively, set `LANCEDB_USER_ID_ENV_KEY` to specify another environment
|
||||
/// variable that contains the user ID value.
|
||||
pub user_id: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ClientConfig {
|
||||
@@ -74,7 +67,6 @@ impl std::fmt::Debug for ClientConfig {
|
||||
"header_provider",
|
||||
&self.header_provider.as_ref().map(|_| "Some(...)"),
|
||||
)
|
||||
.field("user_id", &self.user_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -89,41 +81,10 @@ impl Default for ClientConfig {
|
||||
id_delimiter: None,
|
||||
tls_config: None,
|
||||
header_provider: None,
|
||||
user_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientConfig {
|
||||
/// Resolve the user ID from the config or environment variables.
|
||||
///
|
||||
/// Resolution order:
|
||||
/// 1. If `user_id` is set in the config, use that value
|
||||
/// 2. If `LANCEDB_USER_ID` environment variable is set, use that value
|
||||
/// 3. If `LANCEDB_USER_ID_ENV_KEY` is set, read the env var it points to
|
||||
/// 4. Otherwise, return None
|
||||
pub fn resolve_user_id(&self) -> Option<String> {
|
||||
if self.user_id.is_some() {
|
||||
return self.user_id.clone();
|
||||
}
|
||||
|
||||
if let Ok(user_id) = std::env::var("LANCEDB_USER_ID")
|
||||
&& !user_id.is_empty()
|
||||
{
|
||||
return Some(user_id);
|
||||
}
|
||||
|
||||
if let Ok(env_key) = std::env::var("LANCEDB_USER_ID_ENV_KEY")
|
||||
&& let Ok(user_id) = std::env::var(&env_key)
|
||||
&& !user_id.is_empty()
|
||||
{
|
||||
return Some(user_id);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// How to handle timeouts for HTTP requests.
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct TimeoutConfig {
|
||||
@@ -229,6 +190,7 @@ pub struct RetryConfig {
|
||||
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
|
||||
client: reqwest::Client,
|
||||
host: String,
|
||||
wal_host: String,
|
||||
pub(crate) retry_config: ResolvedRetryConfig,
|
||||
pub(crate) sender: S,
|
||||
pub(crate) id_delimiter: String,
|
||||
@@ -239,6 +201,7 @@ impl<S: HttpSend> std::fmt::Debug for RestfulLanceDbClient<S> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RestfulLanceDbClient")
|
||||
.field("host", &self.host)
|
||||
.field("wal_host", &self.wal_host)
|
||||
.field("retry_config", &self.retry_config)
|
||||
.field("sender", &self.sender)
|
||||
.field("id_delimiter", &self.id_delimiter)
|
||||
@@ -324,6 +287,7 @@ impl RestfulLanceDbClient<Sender> {
|
||||
parsed_url: &ParsedDbUrl,
|
||||
region: &str,
|
||||
host_override: Option<String>,
|
||||
wal_host_override: Option<String>,
|
||||
default_headers: HeaderMap,
|
||||
client_config: ClientConfig,
|
||||
) -> Result<Self> {
|
||||
@@ -411,11 +375,16 @@ impl RestfulLanceDbClient<Sender> {
|
||||
Some(host_override) => host_override,
|
||||
None => format!("https://{}.{}.api.lancedb.com", parsed_url.db_name, region),
|
||||
};
|
||||
debug!("Created client for host: {}", host);
|
||||
let wal_host = match wal_host_override {
|
||||
Some(wal_host_override) => wal_host_override,
|
||||
None => format!("https://{}.{}.wal.lancedb.com", parsed_url.db_name, region),
|
||||
};
|
||||
debug!("Created client for host: {}, wal_host: {}", host, wal_host);
|
||||
let retry_config = client_config.retry_config.clone().try_into()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
host,
|
||||
wal_host,
|
||||
retry_config,
|
||||
sender: Sender,
|
||||
id_delimiter: client_config
|
||||
@@ -503,15 +472,6 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(user_id) = config.resolve_user_id() {
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-lancedb-user-id"),
|
||||
HeaderValue::from_str(&user_id).map_err(|_| Error::InvalidInput {
|
||||
message: format!("non-ascii user_id '{}' provided", user_id),
|
||||
})?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
@@ -527,6 +487,12 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
self.add_id_delimiter_query_param(builder)
|
||||
}
|
||||
|
||||
pub fn post_wal(&self, uri: &str) -> RequestBuilder {
|
||||
let full_uri = format!("{}{}", self.wal_host, uri);
|
||||
let builder = self.client.post(full_uri);
|
||||
self.add_id_delimiter_query_param(builder)
|
||||
}
|
||||
|
||||
fn add_id_delimiter_query_param(&self, req: RequestBuilder) -> RequestBuilder {
|
||||
if self.id_delimiter != "$" {
|
||||
req.query(&[("delimiter", self.id_delimiter.clone())])
|
||||
@@ -839,6 +805,7 @@ pub mod test_utils {
|
||||
RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "http://localhost".to_string(),
|
||||
wal_host: "http://localhost-wal".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: MockSender {
|
||||
f: Arc::new(wrapper),
|
||||
@@ -863,6 +830,7 @@ pub mod test_utils {
|
||||
RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "http://localhost".to_string(),
|
||||
wal_host: "http://localhost-wal".to_string(),
|
||||
retry_config: config.retry_config.try_into().unwrap(),
|
||||
sender: MockSender {
|
||||
f: Arc::new(wrapper),
|
||||
@@ -1030,6 +998,7 @@ mod tests {
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
wal_host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
@@ -1065,6 +1034,7 @@ mod tests {
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
wal_host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
@@ -1102,6 +1072,7 @@ mod tests {
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
wal_host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
@@ -1120,91 +1091,4 @@ mod tests {
|
||||
_ => panic!("Expected Runtime error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_user_id_direct_value() {
|
||||
let config = ClientConfig {
|
||||
user_id: Some("direct-user-id".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(config.resolve_user_id(), Some("direct-user-id".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_user_id_none() {
|
||||
let config = ClientConfig::default();
|
||||
// Clear env vars that might be set from other tests
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::remove_var("LANCEDB_USER_ID");
|
||||
std::env::remove_var("LANCEDB_USER_ID_ENV_KEY");
|
||||
}
|
||||
assert_eq!(config.resolve_user_id(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_user_id_from_env() {
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::set_var("LANCEDB_USER_ID", "env-user-id");
|
||||
}
|
||||
let config = ClientConfig::default();
|
||||
assert_eq!(config.resolve_user_id(), Some("env-user-id".to_string()));
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::remove_var("LANCEDB_USER_ID");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_user_id_from_env_key() {
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::remove_var("LANCEDB_USER_ID");
|
||||
std::env::set_var("LANCEDB_USER_ID_ENV_KEY", "MY_CUSTOM_USER_ID");
|
||||
std::env::set_var("MY_CUSTOM_USER_ID", "custom-env-user-id");
|
||||
}
|
||||
let config = ClientConfig::default();
|
||||
assert_eq!(
|
||||
config.resolve_user_id(),
|
||||
Some("custom-env-user-id".to_string())
|
||||
);
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::remove_var("LANCEDB_USER_ID_ENV_KEY");
|
||||
std::env::remove_var("MY_CUSTOM_USER_ID");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_user_id_direct_takes_precedence() {
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::set_var("LANCEDB_USER_ID", "env-user-id");
|
||||
}
|
||||
let config = ClientConfig {
|
||||
user_id: Some("direct-user-id".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(config.resolve_user_id(), Some("direct-user-id".to_string()));
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::remove_var("LANCEDB_USER_ID");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_user_id_empty_env_ignored() {
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::set_var("LANCEDB_USER_ID", "");
|
||||
std::env::remove_var("LANCEDB_USER_ID_ENV_KEY");
|
||||
}
|
||||
let config = ClientConfig::default();
|
||||
assert_eq!(config.resolve_user_id(), None);
|
||||
// SAFETY: This is only called in tests
|
||||
unsafe {
|
||||
std::env::remove_var("LANCEDB_USER_ID");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +82,7 @@ pub const OPT_REMOTE_PREFIX: &str = "remote_database_";
|
||||
pub const OPT_REMOTE_API_KEY: &str = "remote_database_api_key";
|
||||
pub const OPT_REMOTE_REGION: &str = "remote_database_region";
|
||||
pub const OPT_REMOTE_HOST_OVERRIDE: &str = "remote_database_host_override";
|
||||
pub const OPT_REMOTE_WAL_HOST_OVERRIDE: &str = "remote_database_wal_host_override";
|
||||
// TODO: add support for configuring client config via key/value options
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
@@ -95,6 +96,11 @@ pub struct RemoteDatabaseOptions {
|
||||
/// This is required when connecting to LanceDB Enterprise and should be
|
||||
/// provided if using an on-premises LanceDB Enterprise instance.
|
||||
pub host_override: Option<String>,
|
||||
/// The WAL host override
|
||||
///
|
||||
/// When set, merge_insert operations using WAL routing will be sent to
|
||||
/// this host instead of the auto-derived WAL host.
|
||||
pub wal_host_override: Option<String>,
|
||||
/// Storage options configure the storage layer (e.g. S3, GCS, Azure, etc.)
|
||||
///
|
||||
/// See available options at <https://lancedb.com/docs/storage/>
|
||||
@@ -113,6 +119,7 @@ impl RemoteDatabaseOptions {
|
||||
let api_key = map.get(OPT_REMOTE_API_KEY).cloned();
|
||||
let region = map.get(OPT_REMOTE_REGION).cloned();
|
||||
let host_override = map.get(OPT_REMOTE_HOST_OVERRIDE).cloned();
|
||||
let wal_host_override = map.get(OPT_REMOTE_WAL_HOST_OVERRIDE).cloned();
|
||||
let storage_options = map
|
||||
.iter()
|
||||
.filter(|(key, _)| !key.starts_with(OPT_REMOTE_PREFIX))
|
||||
@@ -122,6 +129,7 @@ impl RemoteDatabaseOptions {
|
||||
api_key,
|
||||
region,
|
||||
host_override,
|
||||
wal_host_override,
|
||||
storage_options,
|
||||
})
|
||||
}
|
||||
@@ -141,6 +149,12 @@ impl DatabaseOptions for RemoteDatabaseOptions {
|
||||
if let Some(host_override) = &self.host_override {
|
||||
map.insert(OPT_REMOTE_HOST_OVERRIDE.to_string(), host_override.clone());
|
||||
}
|
||||
if let Some(wal_host_override) = &self.wal_host_override {
|
||||
map.insert(
|
||||
OPT_REMOTE_WAL_HOST_OVERRIDE.to_string(),
|
||||
wal_host_override.clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +199,19 @@ impl RemoteDatabaseOptionsBuilder {
|
||||
self.options.host_override = Some(host_override);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the WAL host override
|
||||
///
|
||||
/// When set, merge_insert operations using WAL routing will be sent to
|
||||
/// this host instead of the auto-derived WAL host.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `wal_host_override` - The WAL host override
|
||||
pub fn wal_host_override(mut self, wal_host_override: String) -> Self {
|
||||
self.options.wal_host_override = Some(wal_host_override);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -204,6 +231,7 @@ impl RemoteDatabase {
|
||||
api_key: &str,
|
||||
region: &str,
|
||||
host_override: Option<String>,
|
||||
wal_host_override: Option<String>,
|
||||
client_config: ClientConfig,
|
||||
options: RemoteOptions,
|
||||
) -> Result<Self> {
|
||||
@@ -231,6 +259,7 @@ impl RemoteDatabase {
|
||||
&parsed,
|
||||
region,
|
||||
host_override,
|
||||
wal_host_override,
|
||||
header_map,
|
||||
client_config.clone(),
|
||||
)?;
|
||||
@@ -777,32 +806,6 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
let namespace = builder.build();
|
||||
Ok(Arc::new(namespace) as Arc<dyn lance_namespace::LanceNamespace>)
|
||||
}
|
||||
|
||||
async fn namespace_client_config(&self) -> Result<(String, HashMap<String, String>)> {
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert("uri".to_string(), self.client.host().to_string());
|
||||
properties.insert("delimiter".to_string(), self.client.id_delimiter.clone());
|
||||
for (key, value) in &self.namespace_headers {
|
||||
properties.insert(format!("header.{}", key), value.clone());
|
||||
}
|
||||
// Add TLS configuration if present
|
||||
if let Some(tls_config) = &self.tls_config {
|
||||
if let Some(cert_file) = &tls_config.cert_file {
|
||||
properties.insert("tls.cert_file".to_string(), cert_file.clone());
|
||||
}
|
||||
if let Some(key_file) = &tls_config.key_file {
|
||||
properties.insert("tls.key_file".to_string(), key_file.clone());
|
||||
}
|
||||
if let Some(ssl_ca_cert) = &tls_config.ssl_ca_cert {
|
||||
properties.insert("tls.ssl_ca_cert".to_string(), ssl_ca_cert.clone());
|
||||
}
|
||||
properties.insert(
|
||||
"tls.assert_hostname".to_string(),
|
||||
tls_config.assert_hostname.to_string(),
|
||||
);
|
||||
}
|
||||
Ok(("rest".to_string(), properties))
|
||||
}
|
||||
}
|
||||
|
||||
/// RemoteOptions contains a subset of StorageOptions that are compatible with Remote LanceDB connections
|
||||
|
||||
@@ -1610,13 +1610,17 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
self.check_mutable().await?;
|
||||
|
||||
let timeout = params.timeout;
|
||||
let use_wal = params.use_wal;
|
||||
|
||||
let query = MergeInsertRequest::try_from(params)?;
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/merge_insert/", self.identifier))
|
||||
.query(&query)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
|
||||
let path = format!("/v1/table/{}/merge_insert/", self.identifier);
|
||||
let mut request = if use_wal {
|
||||
self.client.post_wal(&path)
|
||||
} else {
|
||||
self.client.post(&path)
|
||||
}
|
||||
.query(&query)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
|
||||
|
||||
if let Some(timeout) = timeout {
|
||||
// (If it doesn't fit into u64, it's not worth sending anyways.)
|
||||
@@ -2705,6 +2709,43 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_insert_use_wal() {
|
||||
let batch = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let data: Box<dyn RecordBatchReader + Send> = Box::new(RecordBatchIterator::new(
|
||||
[Ok(batch.clone())],
|
||||
batch.schema(),
|
||||
));
|
||||
|
||||
let table = Table::new_with_handler("my_table", move |request| {
|
||||
if request.url().path() == "/v1/table/my_table/merge_insert/" {
|
||||
// Verify the request was sent to the WAL host
|
||||
assert_eq!(
|
||||
request.url().host_str().unwrap(),
|
||||
"localhost-wal",
|
||||
"merge_insert with use_wal should route to WAL host"
|
||||
);
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"version": 1, "num_deleted_rows": 0, "num_inserted_rows": 3, "num_updated_rows": 0}"#)
|
||||
.unwrap()
|
||||
} else {
|
||||
panic!("Unexpected request path: {}", request.url().path());
|
||||
}
|
||||
});
|
||||
|
||||
let mut builder = table.merge_insert(&["some_col"]);
|
||||
builder.use_wal(true);
|
||||
let result = builder.execute(data).await.unwrap();
|
||||
|
||||
assert_eq!(result.num_inserted_rows, 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_merge_insert_retries_on_409() {
|
||||
let batch = RecordBatch::try_new(
|
||||
|
||||
@@ -26,10 +26,9 @@ use lance::dataset::{WriteMode, WriteParams};
|
||||
|
||||
use super::{AnyQuery, BaseTable};
|
||||
use crate::{
|
||||
DistanceType, Result,
|
||||
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest},
|
||||
Result,
|
||||
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select},
|
||||
};
|
||||
use arrow_array::Array;
|
||||
use arrow_schema::{DataType, Field};
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
@@ -142,31 +141,11 @@ impl ExecutionPlan for MetadataEraserExec {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters for a vector search query, used by vector_search and hybrid_search UDTFs.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VectorSearchParams {
|
||||
/// The query vector to search for
|
||||
pub query_vector: Arc<dyn Array>,
|
||||
/// The column to search on (None for auto-detection)
|
||||
pub column: Option<String>,
|
||||
/// Number of results to return
|
||||
pub top_k: usize,
|
||||
/// Distance metric to use
|
||||
pub distance_type: Option<DistanceType>,
|
||||
/// Number of IVF partitions to search
|
||||
pub nprobes: Option<usize>,
|
||||
/// HNSW search parameter
|
||||
pub ef: Option<usize>,
|
||||
/// Refine factor for improving recall
|
||||
pub refine_factor: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BaseTableAdapter {
|
||||
table: Arc<dyn BaseTable>,
|
||||
schema: Arc<ArrowSchema>,
|
||||
fts_query: Option<FullTextSearchQuery>,
|
||||
vector_query: Option<VectorSearchParams>,
|
||||
}
|
||||
|
||||
impl BaseTableAdapter {
|
||||
@@ -182,7 +161,6 @@ impl BaseTableAdapter {
|
||||
table,
|
||||
schema: Arc::new(schema),
|
||||
fts_query: None,
|
||||
vector_query: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -198,49 +176,6 @@ impl BaseTableAdapter {
|
||||
table: self.table.clone(),
|
||||
schema,
|
||||
fts_query: Some(fts_query),
|
||||
vector_query: self.vector_query.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new adapter with a vector search query applied.
|
||||
pub fn with_vector_query(&self, vector_query: VectorSearchParams) -> Self {
|
||||
// Add _distance column to the schema
|
||||
let distance_field = Field::new("_distance", DataType::Float32, true);
|
||||
let mut fields = self.schema.fields().to_vec();
|
||||
fields.push(Arc::new(distance_field));
|
||||
let schema = Arc::new(ArrowSchema::new(fields));
|
||||
|
||||
Self {
|
||||
table: self.table.clone(),
|
||||
schema,
|
||||
fts_query: self.fts_query.clone(),
|
||||
vector_query: Some(vector_query),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new adapter with both FTS and vector search queries (hybrid search).
|
||||
///
|
||||
/// Uses vector search as the primary retrieval method, with FTS applied as a
|
||||
/// pre-filter to restrict the candidate set. Both `_distance` and `_score`
|
||||
/// columns are added to results.
|
||||
pub fn with_hybrid_query(
|
||||
&self,
|
||||
fts_query: FullTextSearchQuery,
|
||||
vector_query: VectorSearchParams,
|
||||
) -> Self {
|
||||
// Add _distance column (vector search is primary)
|
||||
let mut fields = self.schema.fields().to_vec();
|
||||
fields.push(Arc::new(Field::new("_distance", DataType::Float32, true)));
|
||||
let schema = Arc::new(ArrowSchema::new(fields));
|
||||
|
||||
// Store FTS as a filter concept, but vector search drives the query.
|
||||
// The FTS query is applied via the base QueryRequest's full_text_search
|
||||
// field, which acts as a pre-filter for vector search.
|
||||
Self {
|
||||
table: self.table.clone(),
|
||||
schema,
|
||||
fts_query: Some(fts_query),
|
||||
vector_query: Some(vector_query),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -266,20 +201,11 @@ impl TableProvider for BaseTableAdapter {
|
||||
filters: &[Expr],
|
||||
limit: Option<usize>,
|
||||
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
|
||||
let has_scoring = self.fts_query.is_some() || self.vector_query.is_some();
|
||||
let disable_scoring = has_scoring && projection.is_some();
|
||||
// For FTS queries, disable auto-projection of _score to match DataFusion expectations
|
||||
let disable_scoring = self.fts_query.is_some() && projection.is_some();
|
||||
|
||||
// When doing vector search, FTS cannot be combined in the same scanner
|
||||
// (Lance doesn't support both nearest + full_text_search simultaneously).
|
||||
// FTS is only set when there's no vector query.
|
||||
let fts_for_query = if self.vector_query.is_some() {
|
||||
None
|
||||
} else {
|
||||
self.fts_query.clone()
|
||||
};
|
||||
|
||||
let mut base_query = QueryRequest {
|
||||
full_text_search: fts_for_query,
|
||||
let mut query = QueryRequest {
|
||||
full_text_search: self.fts_query.clone(),
|
||||
disable_scoring_autoprojection: disable_scoring,
|
||||
..Default::default()
|
||||
};
|
||||
@@ -289,20 +215,20 @@ impl TableProvider for BaseTableAdapter {
|
||||
.iter()
|
||||
.map(|i| self.schema.field(*i).name().clone())
|
||||
.collect();
|
||||
base_query.select = Select::Columns(field_names);
|
||||
query.select = Select::Columns(field_names);
|
||||
}
|
||||
if !filters.is_empty() {
|
||||
let first = filters.first().unwrap().clone();
|
||||
let filter = filters[1..]
|
||||
.iter()
|
||||
.fold(first, |acc, expr| acc.and(expr.clone()));
|
||||
base_query.filter = Some(QueryFilter::Datafusion(filter));
|
||||
query.filter = Some(QueryFilter::Datafusion(filter));
|
||||
}
|
||||
if let Some(limit) = limit {
|
||||
base_query.limit = Some(limit);
|
||||
} else if self.vector_query.is_none() {
|
||||
// Need to override the default of 10 for non-vector queries
|
||||
base_query.limit = None;
|
||||
query.limit = Some(limit);
|
||||
} else {
|
||||
// Need to override the default of 10
|
||||
query.limit = None;
|
||||
}
|
||||
|
||||
let options = QueryExecutionOptions {
|
||||
@@ -310,33 +236,9 @@ impl TableProvider for BaseTableAdapter {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Build the appropriate query type
|
||||
let any_query = if let Some(ref vq) = self.vector_query {
|
||||
let vector_query = VectorQueryRequest {
|
||||
base: base_query,
|
||||
column: vq.column.clone(),
|
||||
query_vector: vec![vq.query_vector.clone()],
|
||||
minimum_nprobes: vq.nprobes.unwrap_or(20),
|
||||
maximum_nprobes: vq.nprobes,
|
||||
ef: vq.ef,
|
||||
refine_factor: vq.refine_factor,
|
||||
distance_type: vq.distance_type,
|
||||
use_index: true,
|
||||
..Default::default()
|
||||
};
|
||||
// For vector queries, use top_k as the limit if no explicit limit set
|
||||
let mut vq_req = vector_query;
|
||||
if limit.is_none() {
|
||||
vq_req.base.limit = Some(vq.top_k);
|
||||
}
|
||||
AnyQuery::VectorQuery(vq_req)
|
||||
} else {
|
||||
AnyQuery::Query(base_query)
|
||||
};
|
||||
|
||||
let plan = self
|
||||
.table
|
||||
.create_plan(&any_query, options)
|
||||
.create_plan(&AnyQuery::Query(query), options)
|
||||
.map_err(|err| DataFusionError::External(err.into()))
|
||||
.await?;
|
||||
Ok(Arc::new(MetadataEraserExec::new(plan)))
|
||||
|
||||
@@ -2,75 +2,5 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! User-Defined Table Functions (UDTFs) for DataFusion integration
|
||||
//!
|
||||
//! This module provides SQL table functions for LanceDB search capabilities:
|
||||
//! - `fts(table_name, query_json)` — full-text search
|
||||
//! - `vector_search(table_name, query_vector_json, top_k)` — vector similarity search
|
||||
//! - `hybrid_search(table_name, query_vector_json, fts_query_json, top_k)` — combined search
|
||||
|
||||
pub mod fts;
|
||||
pub mod hybrid_search;
|
||||
pub mod vector_search;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue};
|
||||
use datafusion_expr::Expr;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
use super::VectorSearchParams;
|
||||
|
||||
/// Describes the type of search to apply when resolving a table.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SearchQuery {
|
||||
/// Full-text search only
|
||||
Fts(FullTextSearchQuery),
|
||||
/// Vector similarity search only
|
||||
Vector(VectorSearchParams),
|
||||
/// Hybrid search combining FTS and vector search
|
||||
Hybrid {
|
||||
fts: FullTextSearchQuery,
|
||||
vector: VectorSearchParams,
|
||||
},
|
||||
}
|
||||
|
||||
/// Trait for resolving table names to TableProvider instances, optionally with a search query.
|
||||
pub trait TableResolver: std::fmt::Debug + Send + Sync {
|
||||
/// Resolve a table name to a TableProvider, optionally applying a search query.
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
search: Option<SearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>>;
|
||||
}
|
||||
|
||||
/// Extract a string literal from a DataFusion expression.
|
||||
pub(crate) fn extract_string_literal(expr: &Expr, param_name: &str) -> DataFusionResult<String> {
|
||||
match expr {
|
||||
Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Ok(s.clone()),
|
||||
Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => Ok(s.clone()),
|
||||
_ => Err(DataFusionError::Plan(format!(
|
||||
"Parameter '{}' must be a string literal, got: {:?}",
|
||||
param_name, expr
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract an integer literal from a DataFusion expression.
|
||||
pub(crate) fn extract_int_literal(expr: &Expr, param_name: &str) -> DataFusionResult<usize> {
|
||||
match expr {
|
||||
Expr::Literal(ScalarValue::Int8(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::Int16(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::Int32(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::Int64(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::UInt8(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::UInt16(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::UInt32(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::UInt64(Some(v)), _) => Ok(*v as usize),
|
||||
_ => Err(DataFusionError::Plan(format!(
|
||||
"Parameter '{}' must be an integer literal, got: {:?}",
|
||||
param_name, expr
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,23 +1,29 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Full-Text Search (FTS) table function for DataFusion SQL integration.
|
||||
//! User-Defined Table Functions (UDTFs) for LanceDB
|
||||
//!
|
||||
//! Usage: `SELECT * FROM fts('table_name', '{"match": {"column": "text", "terms": "query"}}')`
|
||||
//! This module provides table-level UDTFs that integrate with DataFusion's SQL engine.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion::catalog::TableFunctionImpl;
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue, plan_err};
|
||||
use datafusion_expr::Expr;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
use super::{SearchQuery, TableResolver, extract_string_literal};
|
||||
/// Trait for resolving table names to TableProvider instances.
|
||||
pub trait TableResolver: std::fmt::Debug + Send + Sync {
|
||||
/// Resolve a table name to a TableProvider, optionally with an FTS query applied.
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
fts_query: Option<FullTextSearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>>;
|
||||
}
|
||||
|
||||
/// Full-Text Search table function that operates on LanceDB tables.
|
||||
///
|
||||
/// Accepts 2 parameters: `fts(table_name, fts_query_json)`
|
||||
/// Full-Text Search table function that operates on LanceDB tables
|
||||
#[derive(Debug)]
|
||||
pub struct FtsTableFunction {
|
||||
resolver: Arc<dyn TableResolver>,
|
||||
@@ -39,8 +45,20 @@ impl TableFunctionImpl for FtsTableFunction {
|
||||
let query_json = extract_string_literal(&exprs[1], "fts_query")?;
|
||||
let fts_query = parse_fts_query(&query_json)?;
|
||||
|
||||
self.resolver
|
||||
.resolve_table(&table_name, Some(SearchQuery::Fts(fts_query)))
|
||||
// Resolver returns a ready-to-use TableProvider with FTS applied
|
||||
self.resolver.resolve_table(&table_name, Some(fts_query))
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_string_literal(expr: &Expr, param_name: &str) -> DataFusionResult<String> {
|
||||
match expr {
|
||||
Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Ok(s.clone()),
|
||||
Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => Ok(s.clone()),
|
||||
_ => plan_err!(
|
||||
"Parameter '{}' must be a string literal, got: {:?}",
|
||||
param_name,
|
||||
expr
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +91,6 @@ pub fn from_json(json: &str) -> crate::Result<lance_index::scalar::inverted::que
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::{SearchQuery, TableResolver};
|
||||
use super::*;
|
||||
use crate::{
|
||||
Connection, Table,
|
||||
@@ -83,7 +100,6 @@ mod tests {
|
||||
use arrow_array::{Int32Array, RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
|
||||
use datafusion::prelude::SessionContext;
|
||||
use datafusion_common::DataFusionError;
|
||||
|
||||
/// Resolver that looks up tables in a HashMap
|
||||
#[derive(Debug)]
|
||||
@@ -107,7 +123,7 @@ mod tests {
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
search: Option<SearchQuery>,
|
||||
fts_query: Option<FullTextSearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
let table_provider = self
|
||||
.tables
|
||||
@@ -115,10 +131,12 @@ mod tests {
|
||||
.cloned()
|
||||
.ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
|
||||
|
||||
let Some(search) = search else {
|
||||
// If no FTS query, return as-is
|
||||
let Some(fts_query) = fts_query else {
|
||||
return Ok(table_provider);
|
||||
};
|
||||
|
||||
// Downcast to BaseTableAdapter and apply FTS query
|
||||
let base_adapter = table_provider
|
||||
.as_any()
|
||||
.downcast_ref::<BaseTableAdapter>()
|
||||
@@ -128,15 +146,7 @@ mod tests {
|
||||
)
|
||||
})?;
|
||||
|
||||
match search {
|
||||
SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
|
||||
SearchQuery::Vector(vector_query) => {
|
||||
Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
|
||||
}
|
||||
SearchQuery::Hybrid { fts, vector } => {
|
||||
Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
|
||||
}
|
||||
}
|
||||
Ok(Arc::new(base_adapter.with_fts_query(fts_query)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Hybrid search table function for DataFusion SQL integration.
|
||||
//!
|
||||
//! Combines vector similarity search with full-text search:
|
||||
//! ```sql
|
||||
//! SELECT * FROM hybrid_search(
|
||||
//! 'my_table',
|
||||
//! '[0.1, 0.2, 0.3]',
|
||||
//! '{"match": {"column": "text", "terms": "search query"}}',
|
||||
//! 10
|
||||
//! )
|
||||
//! ```
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Array;
|
||||
use datafusion::catalog::TableFunctionImpl;
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
|
||||
use datafusion_expr::Expr;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
use super::fts::from_json as fts_from_json;
|
||||
use super::{SearchQuery, TableResolver, extract_int_literal, extract_string_literal};
|
||||
use crate::table::datafusion::VectorSearchParams;
|
||||
|
||||
/// Default number of results for hybrid search when top_k is not specified.
|
||||
const DEFAULT_TOP_K: usize = 10;
|
||||
|
||||
/// Hybrid search table function combining vector and full-text search.
|
||||
///
|
||||
/// Accepts 3-4 parameters: `hybrid_search(table_name, query_vector_json, fts_query_json [, top_k])`
|
||||
///
|
||||
/// - `table_name`: Name of the table to search
|
||||
/// - `query_vector_json`: JSON array of float values, e.g. `'[0.1, 0.2, 0.3]'`
|
||||
/// - `fts_query_json`: FTS query as JSON, e.g. `'{"match": {"column": "text", "terms": "query"}}'`
|
||||
/// - `top_k` (optional): Number of results to return (default: 10)
|
||||
#[derive(Debug)]
|
||||
pub struct HybridSearchTableFunction {
|
||||
resolver: Arc<dyn TableResolver>,
|
||||
}
|
||||
|
||||
impl HybridSearchTableFunction {
|
||||
pub fn new(resolver: Arc<dyn TableResolver>) -> Self {
|
||||
Self { resolver }
|
||||
}
|
||||
}
|
||||
|
||||
impl TableFunctionImpl for HybridSearchTableFunction {
|
||||
fn call(&self, exprs: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
if exprs.len() < 3 || exprs.len() > 4 {
|
||||
return plan_err!(
|
||||
"hybrid_search() requires 3-4 parameters: hybrid_search(table_name, query_vector_json, fts_query_json [, top_k])"
|
||||
);
|
||||
}
|
||||
|
||||
let table_name = extract_string_literal(&exprs[0], "table_name")?;
|
||||
let vector_json = extract_string_literal(&exprs[1], "query_vector_json")?;
|
||||
let fts_json = extract_string_literal(&exprs[2], "fts_query_json")?;
|
||||
|
||||
let top_k = if exprs.len() == 4 {
|
||||
extract_int_literal(&exprs[3], "top_k")?
|
||||
} else {
|
||||
DEFAULT_TOP_K
|
||||
};
|
||||
|
||||
let query_vector = parse_vector_json(&vector_json)?;
|
||||
let fts_query = parse_fts_query(&fts_json)?;
|
||||
|
||||
let vector_params = VectorSearchParams {
|
||||
query_vector: query_vector as Arc<dyn Array>,
|
||||
column: None,
|
||||
top_k,
|
||||
distance_type: None,
|
||||
nprobes: None,
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
};
|
||||
|
||||
self.resolver.resolve_table(
|
||||
&table_name,
|
||||
Some(SearchQuery::Hybrid {
|
||||
fts: fts_query,
|
||||
vector: vector_params,
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_vector_json(json: &str) -> DataFusionResult<Arc<arrow_array::Float32Array>> {
|
||||
super::vector_search::parse_vector_json(json)
|
||||
}
|
||||
|
||||
fn parse_fts_query(json: &str) -> DataFusionResult<FullTextSearchQuery> {
|
||||
let query = fts_from_json(json).map_err(|e| {
|
||||
DataFusionError::Plan(format!(
|
||||
"Invalid FTS query JSON: {}. Expected format: {{\"match\": {{\"column\": \"text\", \"terms\": \"query\"}} }}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
Ok(FullTextSearchQuery::new_query(query))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::fts::to_json;
|
||||
use super::super::{SearchQuery, TableResolver};
|
||||
use super::*;
|
||||
use crate::{index::Index, table::datafusion::BaseTableAdapter};
|
||||
use arrow_array::FixedSizeListArray;
|
||||
use arrow_array::{Float32Array, Int32Array, RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
|
||||
use datafusion::prelude::SessionContext;
|
||||
#[allow(unused_imports)]
|
||||
use lance_arrow::FixedSizeListArrayExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct HashMapTableResolver {
|
||||
tables: std::collections::HashMap<String, Arc<dyn TableProvider>>,
|
||||
}
|
||||
|
||||
impl HashMapTableResolver {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
tables: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn register(&mut self, name: String, table: Arc<dyn TableProvider>) {
|
||||
self.tables.insert(name, table);
|
||||
}
|
||||
}
|
||||
|
||||
impl TableResolver for HashMapTableResolver {
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
search: Option<SearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
let table_provider = self
|
||||
.tables
|
||||
.get(name)
|
||||
.cloned()
|
||||
.ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
|
||||
|
||||
let Some(search) = search else {
|
||||
return Ok(table_provider);
|
||||
};
|
||||
|
||||
let base_adapter = table_provider
|
||||
.as_any()
|
||||
.downcast_ref::<BaseTableAdapter>()
|
||||
.ok_or_else(|| {
|
||||
DataFusionError::Internal(
|
||||
"Expected BaseTableAdapter but got different type".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
match search {
|
||||
SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
|
||||
SearchQuery::Vector(vector_query) => {
|
||||
Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
|
||||
}
|
||||
SearchQuery::Hybrid { fts, vector } => {
|
||||
Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hybrid_search_udtf() {
|
||||
let dim = 4i32;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
|
||||
true,
|
||||
),
|
||||
]));
|
||||
|
||||
let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
|
||||
let texts = StringArray::from(vec![
|
||||
"the quick brown fox",
|
||||
"jumps over the lazy dog",
|
||||
"a quick red fox runs",
|
||||
"the dog sleeps all day",
|
||||
"a brown fox and a quick dog",
|
||||
]);
|
||||
let flat_values = Float32Array::from(vec![
|
||||
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.5,
|
||||
0.5, 0.0, 0.0,
|
||||
]);
|
||||
let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(ids), Arc::new(texts), Arc::new(vector_array)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let db = crate::connect("memory://test_hybrid")
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let table = db.create_table("docs", batch).execute().await.unwrap();
|
||||
|
||||
// Create FTS index on text column
|
||||
table
|
||||
.create_index(&["text"], Index::FTS(Default::default()))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ctx = SessionContext::new();
|
||||
let mut resolver = HashMapTableResolver::new();
|
||||
let adapter = BaseTableAdapter::try_new(table.base_table().clone())
|
||||
.await
|
||||
.unwrap();
|
||||
resolver.register("docs".to_string(), Arc::new(adapter));
|
||||
|
||||
let resolver = Arc::new(resolver);
|
||||
ctx.register_udtf(
|
||||
"hybrid_search",
|
||||
Arc::new(HybridSearchTableFunction::new(resolver.clone())),
|
||||
);
|
||||
|
||||
// Run hybrid search: vector close to [1,0,0,0] AND FTS for "fox"
|
||||
use lance_index::scalar::inverted::query::*;
|
||||
let fts_query_struct = FtsQuery::Match(
|
||||
MatchQuery::new("fox".to_string()).with_column(Some("text".to_string())),
|
||||
);
|
||||
let fts_json = to_json(&fts_query_struct).unwrap();
|
||||
|
||||
let query = format!(
|
||||
"SELECT * FROM hybrid_search('docs', '[1.0, 0.0, 0.0, 0.0]', '{}', 5)",
|
||||
fts_json
|
||||
);
|
||||
|
||||
let df = ctx.sql(&query).await.unwrap();
|
||||
let results = df.collect().await.unwrap();
|
||||
|
||||
assert!(!results.is_empty());
|
||||
let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
|
||||
assert!(total_rows > 0, "Should have at least one result");
|
||||
|
||||
// Check schema has the expected columns
|
||||
let result_schema = results[0].schema();
|
||||
assert!(result_schema.column_with_name("id").is_some());
|
||||
assert!(result_schema.column_with_name("text").is_some());
|
||||
assert!(result_schema.column_with_name("vector").is_some());
|
||||
}
|
||||
}
|
||||
@@ -1,278 +0,0 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Vector search table function for DataFusion SQL integration.
|
||||
//!
|
||||
//! Enables vector similarity search via SQL:
|
||||
//! ```sql
|
||||
//! SELECT * FROM vector_search('my_table', '[0.1, 0.2, 0.3, ...]', 10)
|
||||
//! ```
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{Array, Float32Array};
|
||||
use datafusion::catalog::TableFunctionImpl;
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
|
||||
use datafusion_expr::Expr;
|
||||
|
||||
use super::{SearchQuery, TableResolver, extract_int_literal, extract_string_literal};
|
||||
use crate::table::datafusion::VectorSearchParams;
|
||||
|
||||
/// Default number of results for vector search when top_k is not specified.
|
||||
const DEFAULT_TOP_K: usize = 10;
|
||||
|
||||
/// Vector search table function for LanceDB tables.
|
||||
///
|
||||
/// Accepts 2-3 parameters: `vector_search(table_name, query_vector_json [, top_k])`
|
||||
///
|
||||
/// - `table_name`: Name of the table to search
|
||||
/// - `query_vector_json`: JSON array of float values, e.g. `'[0.1, 0.2, 0.3]'`
|
||||
/// - `top_k` (optional): Number of results to return (default: 10)
|
||||
#[derive(Debug)]
|
||||
pub struct VectorSearchTableFunction {
|
||||
resolver: Arc<dyn TableResolver>,
|
||||
}
|
||||
|
||||
impl VectorSearchTableFunction {
|
||||
pub fn new(resolver: Arc<dyn TableResolver>) -> Self {
|
||||
Self { resolver }
|
||||
}
|
||||
}
|
||||
|
||||
impl TableFunctionImpl for VectorSearchTableFunction {
|
||||
fn call(&self, exprs: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
if exprs.len() < 2 || exprs.len() > 3 {
|
||||
return plan_err!(
|
||||
"vector_search() requires 2-3 parameters: vector_search(table_name, query_vector_json [, top_k])"
|
||||
);
|
||||
}
|
||||
|
||||
let table_name = extract_string_literal(&exprs[0], "table_name")?;
|
||||
let vector_json = extract_string_literal(&exprs[1], "query_vector_json")?;
|
||||
|
||||
let top_k = if exprs.len() == 3 {
|
||||
extract_int_literal(&exprs[2], "top_k")?
|
||||
} else {
|
||||
DEFAULT_TOP_K
|
||||
};
|
||||
|
||||
let query_vector = parse_vector_json(&vector_json)?;
|
||||
|
||||
let params = VectorSearchParams {
|
||||
query_vector: query_vector as Arc<dyn Array>,
|
||||
column: None,
|
||||
top_k,
|
||||
distance_type: None,
|
||||
nprobes: None,
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
};
|
||||
|
||||
self.resolver
|
||||
.resolve_table(&table_name, Some(SearchQuery::Vector(params)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a JSON array of floats into an Arrow Float32Array for vector search.
|
||||
///
|
||||
/// Input format: `"[0.1, 0.2, 0.3, ...]"`
|
||||
///
|
||||
/// Returns a Float32Array whose length equals the vector dimension.
|
||||
/// This is the format expected by LanceDB's vector search internals.
|
||||
pub(crate) fn parse_vector_json(json: &str) -> DataFusionResult<Arc<Float32Array>> {
|
||||
let values: Vec<f32> = serde_json::from_str(json).map_err(|e| {
|
||||
DataFusionError::Plan(format!(
|
||||
"Invalid vector JSON: {}. Expected format: [0.1, 0.2, 0.3, ...]",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
if values.is_empty() {
|
||||
return Err(DataFusionError::Plan(
|
||||
"Vector must not be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Arc::new(Float32Array::from(values)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::{SearchQuery, TableResolver};
|
||||
use super::*;
|
||||
use crate::table::datafusion::BaseTableAdapter;
|
||||
use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, RecordBatch};
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
|
||||
use datafusion::prelude::SessionContext;
|
||||
#[allow(unused_imports)]
|
||||
use lance_arrow::FixedSizeListArrayExt;
|
||||
|
||||
/// Resolver that looks up tables in a HashMap and applies search queries
|
||||
#[derive(Debug)]
|
||||
struct HashMapTableResolver {
|
||||
tables: std::collections::HashMap<String, Arc<dyn TableProvider>>,
|
||||
}
|
||||
|
||||
impl HashMapTableResolver {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
tables: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn register(&mut self, name: String, table: Arc<dyn TableProvider>) {
|
||||
self.tables.insert(name, table);
|
||||
}
|
||||
}
|
||||
|
||||
impl TableResolver for HashMapTableResolver {
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
search: Option<SearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
let table_provider = self
|
||||
.tables
|
||||
.get(name)
|
||||
.cloned()
|
||||
.ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
|
||||
|
||||
let Some(search) = search else {
|
||||
return Ok(table_provider);
|
||||
};
|
||||
|
||||
let base_adapter = table_provider
|
||||
.as_any()
|
||||
.downcast_ref::<BaseTableAdapter>()
|
||||
.ok_or_else(|| {
|
||||
DataFusionError::Internal(
|
||||
"Expected BaseTableAdapter but got different type".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
match search {
|
||||
SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
|
||||
SearchQuery::Vector(vector_query) => {
|
||||
Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
|
||||
}
|
||||
SearchQuery::Hybrid { fts, vector } => {
|
||||
Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn make_test_data() -> (Arc<ArrowSchema>, RecordBatch) {
|
||||
let dim = 4;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
|
||||
true,
|
||||
),
|
||||
]));
|
||||
|
||||
let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
|
||||
// Create vectors: [1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1], [1,1,0,0]
|
||||
let flat_values = Float32Array::from(vec![
|
||||
1.0, 0.0, 0.0, 0.0, // vec 1
|
||||
0.0, 1.0, 0.0, 0.0, // vec 2
|
||||
0.0, 0.0, 1.0, 0.0, // vec 3
|
||||
0.0, 0.0, 0.0, 1.0, // vec 4
|
||||
1.0, 1.0, 0.0, 0.0, // vec 5
|
||||
]);
|
||||
let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();
|
||||
|
||||
let batch =
|
||||
RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vector_array)])
|
||||
.unwrap();
|
||||
|
||||
(schema, batch)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_udtf() {
|
||||
let (_schema, batch) = make_test_data();
|
||||
|
||||
let db = crate::connect("memory://test_vec").execute().await.unwrap();
|
||||
let table = db.create_table("vectors", batch).execute().await.unwrap();
|
||||
|
||||
// No index needed — vector search works with brute-force scan on small tables
|
||||
|
||||
// Setup DataFusion context
|
||||
let ctx = SessionContext::new();
|
||||
let mut resolver = HashMapTableResolver::new();
|
||||
let adapter = BaseTableAdapter::try_new(table.base_table().clone())
|
||||
.await
|
||||
.unwrap();
|
||||
resolver.register("vectors".to_string(), Arc::new(adapter));
|
||||
|
||||
let udtf = VectorSearchTableFunction::new(Arc::new(resolver));
|
||||
ctx.register_udtf("vector_search", Arc::new(udtf));
|
||||
|
||||
// Search for vectors close to [1, 0, 0, 0]
|
||||
let query = "SELECT * FROM vector_search('vectors', '[1.0, 0.0, 0.0, 0.0]', 3)";
|
||||
let df = ctx.sql(query).await.unwrap();
|
||||
let results = df.collect().await.unwrap();
|
||||
|
||||
assert!(!results.is_empty());
|
||||
let batch = &results[0];
|
||||
|
||||
// Should have id, vector, _distance columns
|
||||
assert!(batch.schema().column_with_name("id").is_some());
|
||||
assert!(batch.schema().column_with_name("vector").is_some());
|
||||
assert!(
|
||||
batch.schema().column_with_name("_distance").is_some(),
|
||||
"_distance column should be present"
|
||||
);
|
||||
|
||||
// Should return at most 3 results
|
||||
let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
|
||||
assert!(total_rows <= 3);
|
||||
assert!(total_rows > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_vector_search_default_top_k() {
|
||||
let (_, batch) = make_test_data();
|
||||
|
||||
let db = crate::connect("memory://test_vec_default")
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let table = db.create_table("vectors", batch).execute().await.unwrap();
|
||||
|
||||
let ctx = SessionContext::new();
|
||||
let mut resolver = HashMapTableResolver::new();
|
||||
let adapter = BaseTableAdapter::try_new(table.base_table().clone())
|
||||
.await
|
||||
.unwrap();
|
||||
resolver.register("vectors".to_string(), Arc::new(adapter));
|
||||
|
||||
let udtf = VectorSearchTableFunction::new(Arc::new(resolver));
|
||||
ctx.register_udtf("vector_search", Arc::new(udtf));
|
||||
|
||||
// No top_k parameter — should default to 10
|
||||
let query = "SELECT * FROM vector_search('vectors', '[1.0, 0.0, 0.0, 0.0]')";
|
||||
let df = ctx.sql(query).await.unwrap();
|
||||
let results = df.collect().await.unwrap();
|
||||
|
||||
let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
|
||||
// We only have 5 rows, so we should get all 5 back
|
||||
assert_eq!(total_rows, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_vector_json() {
|
||||
let result = parse_vector_json("[1.0, 2.0, 3.0]").unwrap();
|
||||
assert_eq!(result.len(), 3); // 3-dimensional vector
|
||||
|
||||
// Empty vector should fail
|
||||
assert!(parse_vector_json("[]").is_err());
|
||||
|
||||
// Invalid JSON should fail
|
||||
assert!(parse_vector_json("not json").is_err());
|
||||
}
|
||||
}
|
||||
@@ -55,6 +55,7 @@ pub struct MergeInsertBuilder {
|
||||
pub(crate) when_not_matched_by_source_delete_filt: Option<String>,
|
||||
pub(crate) timeout: Option<Duration>,
|
||||
pub(crate) use_index: bool,
|
||||
pub(crate) use_wal: bool,
|
||||
}
|
||||
|
||||
impl MergeInsertBuilder {
|
||||
@@ -69,6 +70,7 @@ impl MergeInsertBuilder {
|
||||
when_not_matched_by_source_delete_filt: None,
|
||||
timeout: None,
|
||||
use_index: true,
|
||||
use_wal: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -148,6 +150,18 @@ impl MergeInsertBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Controls whether to route the merge insert operation through the WAL host.
|
||||
///
|
||||
/// When set to `true`, the operation will be sent to the WAL host instead of
|
||||
/// the main API host. The WAL host is auto-derived from the database connection
|
||||
/// or can be explicitly set via [`crate::connection::ConnectBuilder::wal_host_override`].
|
||||
///
|
||||
/// Defaults to `false`.
|
||||
pub fn use_wal(&mut self, use_wal: bool) -> &mut Self {
|
||||
self.use_wal = use_wal;
|
||||
self
|
||||
}
|
||||
|
||||
/// Executes the merge insert operation
|
||||
///
|
||||
/// Returns version and statistics about the merge operation including the number of rows
|
||||
|
||||
Reference in New Issue
Block a user