mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-07 16:30:41 +00:00
Compare commits
1 Commits
main
...
justin/oss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8d654db37 |
165
Cargo.lock
generated
165
Cargo.lock
generated
@@ -290,6 +290,34 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arrow-flight"
|
||||
version = "57.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "58c5b083668e6230eae3eab2fc4b5fb989974c845d0aa538dde61a4327c78675"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
"arrow-cast",
|
||||
"arrow-data",
|
||||
"arrow-ipc",
|
||||
"arrow-ord",
|
||||
"arrow-row",
|
||||
"arrow-schema",
|
||||
"arrow-select",
|
||||
"arrow-string",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"futures",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"prost",
|
||||
"prost-types",
|
||||
"tonic",
|
||||
"tonic-prost",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "arrow-ipc"
|
||||
version = "57.3.0"
|
||||
@@ -1069,7 +1097,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"axum-core 0.4.5",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.4.0",
|
||||
@@ -1078,7 +1106,7 @@ dependencies = [
|
||||
"hyper 1.8.1",
|
||||
"hyper-util",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"matchit 0.7.3",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
@@ -1096,6 +1124,31 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
||||
dependencies = [
|
||||
"axum-core 0.5.6",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"itoa",
|
||||
"matchit 0.8.4",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"serde_core",
|
||||
"sync_wrapper",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.4.5"
|
||||
@@ -1117,6 +1170,24 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
version = "0.4.0"
|
||||
@@ -1294,7 +1365,7 @@ version = "3.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c"
|
||||
dependencies = [
|
||||
"darling 0.20.11",
|
||||
"darling 0.23.0",
|
||||
"ident_case",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
@@ -2682,7 +2753,7 @@ dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2876,7 +2947,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3719,6 +3790,19 @@ dependencies = [
|
||||
"webpki-roots 1.0.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-timeout"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0"
|
||||
dependencies = [
|
||||
"hyper 1.8.1",
|
||||
"hyper-util",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.20"
|
||||
@@ -4048,7 +4132,7 @@ dependencies = [
|
||||
"portable-atomic",
|
||||
"portable-atomic-util",
|
||||
"serde_core",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4539,7 +4623,7 @@ dependencies = [
|
||||
"arrow-ipc",
|
||||
"arrow-schema",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"axum 0.7.9",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"futures",
|
||||
@@ -4638,6 +4722,7 @@ dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-cast",
|
||||
"arrow-data",
|
||||
"arrow-flight",
|
||||
"arrow-ipc",
|
||||
"arrow-ord",
|
||||
"arrow-schema",
|
||||
@@ -4691,6 +4776,7 @@ dependencies = [
|
||||
"pin-project",
|
||||
"polars",
|
||||
"polars-arrow",
|
||||
"prost",
|
||||
"rand 0.9.2",
|
||||
"random_word 0.4.3",
|
||||
"regex",
|
||||
@@ -4705,6 +4791,8 @@ dependencies = [
|
||||
"test-log",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tonic",
|
||||
"url",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
@@ -5009,6 +5097,12 @@ version = "0.7.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.10"
|
||||
@@ -5322,7 +5416,7 @@ version = "0.50.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6302,7 +6396,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
|
||||
dependencies = [
|
||||
"heck 0.5.0",
|
||||
"itertools 0.11.0",
|
||||
"itertools 0.14.0",
|
||||
"log",
|
||||
"multimap",
|
||||
"petgraph",
|
||||
@@ -6321,7 +6415,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.11.0",
|
||||
"itertools 0.14.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.117",
|
||||
@@ -6527,7 +6621,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"socket2 0.6.3",
|
||||
"tracing",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7069,7 +7163,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.12.1",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8088,7 +8182,7 @@ dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"once_cell",
|
||||
"rustix 1.1.4",
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -8370,6 +8464,46 @@ dependencies = [
|
||||
"winnow",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tonic"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum 0.8.8",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"h2 0.4.13",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body-util",
|
||||
"hyper 1.8.1",
|
||||
"hyper-timeout",
|
||||
"hyper-util",
|
||||
"percent-encoding",
|
||||
"pin-project",
|
||||
"socket2 0.6.3",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tonic-prost"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"prost",
|
||||
"tonic",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower"
|
||||
version = "0.5.3"
|
||||
@@ -8378,9 +8512,12 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"indexmap 2.13.0",
|
||||
"pin-project-lite",
|
||||
"slab",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
@@ -8893,7 +9030,7 @@ version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -39,6 +39,8 @@ arrow-ord = "57.2"
|
||||
arrow-schema = "57.2"
|
||||
arrow-select = "57.2"
|
||||
arrow-cast = "57.2"
|
||||
arrow-flight = { version = "57.2", features = ["flight-sql-experimental"] }
|
||||
tonic = "0.14"
|
||||
async-trait = "0"
|
||||
datafusion = { version = "52.1", default-features = false }
|
||||
datafusion-catalog = "52.1"
|
||||
|
||||
@@ -88,6 +88,10 @@ candle-transformers = { version = "0.9.1", optional = true }
|
||||
candle-nn = { version = "0.9.1", optional = true }
|
||||
tokenizers = { version = "0.19.1", optional = true }
|
||||
semver = { workspace = true }
|
||||
# For flight feature (Arrow Flight SQL server)
|
||||
arrow-flight = { workspace = true, optional = true }
|
||||
tonic = { workspace = true, optional = true }
|
||||
prost = { version = "0.14", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1"
|
||||
@@ -104,6 +108,7 @@ datafusion.workspace = true
|
||||
http-body = "1" # Matching reqwest
|
||||
rstest = "0.23.0"
|
||||
test-log = "0.2"
|
||||
tokio-stream = "0.1"
|
||||
|
||||
|
||||
[features]
|
||||
@@ -131,6 +136,7 @@ sentence-transformers = [
|
||||
"dep:candle-nn",
|
||||
"dep:tokenizers",
|
||||
]
|
||||
flight = ["dep:arrow-flight", "dep:tonic", "dep:prost"]
|
||||
|
||||
[[example]]
|
||||
name = "openai"
|
||||
@@ -157,5 +163,9 @@ name = "ivf_pq"
|
||||
name = "hybrid_search"
|
||||
required-features = ["sentence-transformers"]
|
||||
|
||||
[[example]]
|
||||
name = "flight_sql"
|
||||
required-features = ["flight"]
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
|
||||
99
rust/lancedb/examples/flight_sql.rs
Normal file
99
rust/lancedb/examples/flight_sql.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Example: LanceDB Arrow Flight SQL Server
|
||||
//!
|
||||
//! This example demonstrates how to:
|
||||
//! 1. Create a LanceDB database with sample data
|
||||
//! 2. Start an Arrow Flight SQL server
|
||||
//! 3. Connect with a Flight SQL client
|
||||
//! 4. Run SQL queries including vector_search and fts table functions
|
||||
//!
|
||||
//! Run with: `cargo run --features flight --example flight_sql`
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance_arrow::FixedSizeListArrayExt;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Initialize logging if env_logger is available
|
||||
let _ = std::env::var("RUST_LOG").ok();
|
||||
|
||||
// 1. Create an in-memory LanceDB database
|
||||
let db = lancedb::connect("memory://flight_sql_demo")
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
// 2. Create a table with text and vector data
|
||||
let dim = 4i32;
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
|
||||
true,
|
||||
),
|
||||
]));
|
||||
|
||||
let ids = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let texts = StringArray::from(vec![
|
||||
"the quick brown fox jumps over the lazy dog",
|
||||
"a fast red fox leaps across the sleeping hound",
|
||||
"machine learning models process natural language",
|
||||
"neural networks learn from training data",
|
||||
"the brown dog chases the red fox through the forest",
|
||||
"deep learning algorithms improve with more data",
|
||||
"a lazy cat sleeps on the warm windowsill",
|
||||
"vector databases enable fast similarity search",
|
||||
]);
|
||||
let flat_values = Float32Array::from(vec![
|
||||
1.0, 0.0, 0.0, 0.0, // fox-like
|
||||
0.9, 0.1, 0.0, 0.0, // similar to fox
|
||||
0.0, 1.0, 0.0, 0.0, // ML-like
|
||||
0.0, 0.9, 0.1, 0.0, // similar to ML
|
||||
0.7, 0.3, 0.0, 0.0, // fox+dog mix
|
||||
0.0, 0.8, 0.2, 0.0, // ML-like
|
||||
0.1, 0.0, 0.0, 0.9, // cat-like
|
||||
0.0, 0.5, 0.5, 0.0, // tech-like
|
||||
]);
|
||||
let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim)?;
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(ids), Arc::new(texts), Arc::new(vector_array)],
|
||||
)?;
|
||||
|
||||
let table = db.create_table("documents", batch).execute().await?;
|
||||
|
||||
// 3. Create indices
|
||||
println!("Creating FTS index on 'text' column...");
|
||||
table
|
||||
.create_index(&["text"], lancedb::index::Index::FTS(Default::default()))
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
println!("Database ready with {} rows", table.count_rows(None).await?);
|
||||
|
||||
// 4. Start Flight SQL server
|
||||
let addr = "0.0.0.0:50051".parse()?;
|
||||
println!("Starting Arrow Flight SQL server on {}...", addr);
|
||||
println!();
|
||||
println!("Connect with any ADBC Flight SQL client:");
|
||||
println!(" URI: grpc://localhost:50051");
|
||||
println!();
|
||||
println!("Example SQL queries:");
|
||||
println!(" SELECT * FROM documents LIMIT 5;");
|
||||
println!(" SELECT * FROM vector_search('documents', '[1.0, 0.0, 0.0, 0.0]', 3);");
|
||||
println!(
|
||||
" SELECT * FROM fts('documents', '{{\"match\": {{\"column\": \"text\", \"terms\": \"fox\"}}}}');"
|
||||
);
|
||||
println!();
|
||||
|
||||
lancedb::flight::serve(db, addr).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
606
rust/lancedb/src/flight.rs
Normal file
606
rust/lancedb/src/flight.rs
Normal file
@@ -0,0 +1,606 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Arrow Flight SQL server for LanceDB.
|
||||
//!
|
||||
//! This module provides an Arrow Flight SQL server that exposes a LanceDB
|
||||
//! [`Connection`] over the Flight SQL protocol. Any ADBC, ODBC (via bridge),
|
||||
//! or JDBC Flight SQL client can connect and run SQL queries — including
|
||||
//! LanceDB's search table functions (`vector_search`, `fts`, `hybrid_search`).
|
||||
//!
|
||||
//! # Quick Start
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
|
||||
//! let db = lancedb::connect("data/my-db").execute().await?;
|
||||
//! let addr = "0.0.0.0:50051".parse()?;
|
||||
//! lancedb::flight::serve(db, addr).await?;
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{ArrayRef, RecordBatch, StringArray};
|
||||
use arrow_flight::encode::FlightDataEncoderBuilder;
|
||||
use arrow_flight::error::FlightError;
|
||||
use arrow_flight::flight_service_server::FlightServiceServer;
|
||||
use arrow_flight::sql::server::FlightSqlService;
|
||||
use arrow_flight::sql::{
|
||||
Any, CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables,
|
||||
CommandStatementQuery, SqlInfo, TicketStatementQuery,
|
||||
};
|
||||
use arrow_flight::{FlightDescriptor, FlightEndpoint, FlightInfo, Ticket};
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
|
||||
use datafusion::prelude::SessionContext;
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
||||
use futures::StreamExt;
|
||||
use futures::stream;
|
||||
use log;
|
||||
use prost::Message;
|
||||
use tonic::transport::Server;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use crate::connection::Connection;
|
||||
use crate::table::datafusion::BaseTableAdapter;
|
||||
use crate::table::datafusion::udtf::fts::FtsTableFunction;
|
||||
use crate::table::datafusion::udtf::hybrid_search::HybridSearchTableFunction;
|
||||
use crate::table::datafusion::udtf::vector_search::VectorSearchTableFunction;
|
||||
use crate::table::datafusion::udtf::{SearchQuery, TableResolver};
|
||||
|
||||
/// Start an Arrow Flight SQL server exposing the given LanceDB connection.
|
||||
///
|
||||
/// This is a convenience function that creates a server and starts listening.
|
||||
/// It blocks until the server is shut down.
|
||||
pub async fn serve(connection: Connection, addr: SocketAddr) -> crate::Result<()> {
|
||||
let service = LanceFlightSqlService::try_new(connection).await?;
|
||||
let flight_svc = FlightServiceServer::new(service);
|
||||
|
||||
Server::builder()
|
||||
.add_service(flight_svc)
|
||||
.serve(addr)
|
||||
.await
|
||||
.map_err(|e| crate::Error::Runtime {
|
||||
message: format!("Flight SQL server error: {}", e),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A table resolver that looks up tables from a pre-built HashMap.
|
||||
#[derive(Debug)]
|
||||
struct ConnectionTableResolver {
|
||||
tables: HashMap<String, Arc<BaseTableAdapter>>,
|
||||
}
|
||||
|
||||
impl TableResolver for ConnectionTableResolver {
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
search: Option<SearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
let adapter = self
|
||||
.tables
|
||||
.get(name)
|
||||
.ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
|
||||
|
||||
match search {
|
||||
None => Ok(adapter.clone() as Arc<dyn TableProvider>),
|
||||
Some(SearchQuery::Fts(fts)) => Ok(Arc::new(adapter.with_fts_query(fts))),
|
||||
Some(SearchQuery::Vector(vq)) => Ok(Arc::new(adapter.with_vector_query(vq))),
|
||||
Some(SearchQuery::Hybrid { fts, vector }) => {
|
||||
Ok(Arc::new(adapter.with_hybrid_query(fts, vector)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Arrow Flight SQL service backed by a LanceDB connection.
|
||||
struct LanceFlightSqlService {
|
||||
/// Kept for future use (e.g., refreshing table list).
|
||||
_connection: Connection,
|
||||
/// Pre-built table adapters (refreshed on creation)
|
||||
tables: HashMap<String, Arc<BaseTableAdapter>>,
|
||||
}
|
||||
|
||||
impl LanceFlightSqlService {
|
||||
async fn try_new(connection: Connection) -> crate::Result<Self> {
|
||||
let table_names = connection.table_names().execute().await?;
|
||||
let mut tables = HashMap::new();
|
||||
|
||||
for name in &table_names {
|
||||
let table = connection.open_table(name).execute().await?;
|
||||
let adapter = BaseTableAdapter::try_new(table.base_table().clone()).await?;
|
||||
tables.insert(name.clone(), Arc::new(adapter));
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"LanceDB Flight SQL service initialized with {} tables",
|
||||
tables.len()
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
_connection: connection,
|
||||
tables,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a DataFusion SessionContext with all tables and UDTFs registered.
|
||||
fn create_session_context(&self) -> SessionContext {
|
||||
let ctx = SessionContext::new();
|
||||
|
||||
// Register all tables
|
||||
for (name, adapter) in &self.tables {
|
||||
if let Err(e) = ctx.register_table(name, adapter.clone() as Arc<dyn TableProvider>) {
|
||||
log::warn!("Failed to register table '{}': {}", name, e);
|
||||
}
|
||||
}
|
||||
|
||||
// Create resolver for UDTFs
|
||||
let resolver = Arc::new(ConnectionTableResolver {
|
||||
tables: self.tables.clone(),
|
||||
});
|
||||
|
||||
// Register search UDTFs
|
||||
ctx.register_udtf("fts", Arc::new(FtsTableFunction::new(resolver.clone())));
|
||||
ctx.register_udtf(
|
||||
"vector_search",
|
||||
Arc::new(VectorSearchTableFunction::new(resolver.clone())),
|
||||
);
|
||||
ctx.register_udtf(
|
||||
"hybrid_search",
|
||||
Arc::new(HybridSearchTableFunction::new(resolver)),
|
||||
);
|
||||
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Execute a SQL query and return the results as a stream of FlightData.
|
||||
async fn execute_sql(
|
||||
&self,
|
||||
sql: &str,
|
||||
) -> Result<
|
||||
Pin<Box<dyn futures::Stream<Item = Result<arrow_flight::FlightData, Status>> + Send>>,
|
||||
Status,
|
||||
> {
|
||||
let ctx = self.create_session_context();
|
||||
|
||||
let df = ctx
|
||||
.sql(sql)
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("SQL planning error: {}", e)))?;
|
||||
|
||||
let schema: SchemaRef = df.schema().inner().clone();
|
||||
let stream = df
|
||||
.execute_stream()
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("SQL execution error: {}", e)))?;
|
||||
|
||||
// Use FlightDataEncoderBuilder to properly encode batches with schema
|
||||
let batch_stream = stream.map(|r| r.map_err(|e| FlightError::ExternalError(Box::new(e))));
|
||||
let flight_data_stream = FlightDataEncoderBuilder::new()
|
||||
.with_schema(schema)
|
||||
.build(batch_stream)
|
||||
.map(|result| result.map_err(|e| Status::internal(format!("Encoding error: {}", e))));
|
||||
|
||||
Ok(Box::pin(flight_data_stream))
|
||||
}
|
||||
|
||||
/// Encode a single RecordBatch into a FlightData stream (schema + data).
|
||||
fn batch_to_flight_stream(
|
||||
batch: RecordBatch,
|
||||
) -> Pin<Box<dyn futures::Stream<Item = Result<arrow_flight::FlightData, Status>> + Send>> {
|
||||
let schema = batch.schema();
|
||||
let stream = FlightDataEncoderBuilder::new()
|
||||
.with_schema(schema)
|
||||
.build(stream::once(async move { Ok(batch) }))
|
||||
.map(|result| result.map_err(|e| Status::internal(format!("Encoding error: {}", e))));
|
||||
Box::pin(stream)
|
||||
}
|
||||
|
||||
/// Get the schema for a SQL query without executing it.
|
||||
async fn get_sql_schema(&self, sql: &str) -> Result<ArrowSchema, Status> {
|
||||
let ctx = self.create_session_context();
|
||||
let df = ctx
|
||||
.sql(sql)
|
||||
.await
|
||||
.map_err(|e| Status::internal(format!("SQL planning error: {}", e)))?;
|
||||
Ok(df.schema().inner().as_ref().clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl FlightSqlService for LanceFlightSqlService {
|
||||
type FlightService = LanceFlightSqlService;
|
||||
|
||||
/// Handle SQL query: return FlightInfo with schema and ticket.
|
||||
async fn get_flight_info_statement(
|
||||
&self,
|
||||
query: CommandStatementQuery,
|
||||
request: Request<FlightDescriptor>,
|
||||
) -> Result<Response<FlightInfo>, Status> {
|
||||
let sql = query.query;
|
||||
log::info!("get_flight_info_statement: {}", sql);
|
||||
|
||||
let schema = self.get_sql_schema(&sql).await?;
|
||||
|
||||
// Encode the query as an Any-wrapped TicketStatementQuery for do_get_statement
|
||||
let ticket = TicketStatementQuery {
|
||||
statement_handle: sql.into_bytes().into(),
|
||||
};
|
||||
let any_msg = Any::pack(&ticket)
|
||||
.map_err(|e| Status::internal(format!("Ticket encoding error: {}", e)))?;
|
||||
let mut ticket_bytes = Vec::new();
|
||||
any_msg
|
||||
.encode(&mut ticket_bytes)
|
||||
.map_err(|e| Status::internal(format!("Ticket encoding error: {}", e)))?;
|
||||
|
||||
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
|
||||
let flight_info = FlightInfo::new()
|
||||
.try_with_schema(&schema)
|
||||
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
|
||||
.with_endpoint(endpoint)
|
||||
.with_descriptor(request.into_inner());
|
||||
|
||||
Ok(Response::new(flight_info))
|
||||
}
|
||||
|
||||
/// Execute a SQL query and stream results.
|
||||
async fn do_get_statement(
|
||||
&self,
|
||||
ticket: TicketStatementQuery,
|
||||
_request: Request<Ticket>,
|
||||
) -> Result<
|
||||
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
|
||||
Status,
|
||||
> {
|
||||
let sql = String::from_utf8(ticket.statement_handle.to_vec())
|
||||
.map_err(|e| Status::internal(format!("Invalid ticket: {}", e)))?;
|
||||
log::info!("do_get_statement: {}", sql);
|
||||
|
||||
let stream = self.execute_sql(&sql).await?;
|
||||
Ok(Response::new(stream))
|
||||
}
|
||||
|
||||
/// List tables in the database.
|
||||
async fn get_flight_info_tables(
|
||||
&self,
|
||||
_query: CommandGetTables,
|
||||
request: Request<FlightDescriptor>,
|
||||
) -> Result<Response<FlightInfo>, Status> {
|
||||
let schema = ArrowSchema::new(vec![
|
||||
Field::new("catalog_name", DataType::Utf8, true),
|
||||
Field::new("db_schema_name", DataType::Utf8, true),
|
||||
Field::new("table_name", DataType::Utf8, false),
|
||||
Field::new("table_type", DataType::Utf8, false),
|
||||
]);
|
||||
|
||||
let cmd = CommandGetTables::default();
|
||||
let any_msg =
|
||||
Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
|
||||
let mut ticket_bytes = Vec::new();
|
||||
any_msg
|
||||
.encode(&mut ticket_bytes)
|
||||
.map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
|
||||
|
||||
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
|
||||
let flight_info = FlightInfo::new()
|
||||
.try_with_schema(&schema)
|
||||
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
|
||||
.with_endpoint(endpoint)
|
||||
.with_descriptor(request.into_inner());
|
||||
|
||||
Ok(Response::new(flight_info))
|
||||
}
|
||||
|
||||
async fn do_get_tables(
|
||||
&self,
|
||||
_query: CommandGetTables,
|
||||
_request: Request<Ticket>,
|
||||
) -> Result<
|
||||
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
|
||||
Status,
|
||||
> {
|
||||
let table_names: Vec<&str> = self.tables.keys().map(|s| s.as_str()).collect();
|
||||
let num_tables = table_names.len();
|
||||
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
Field::new("catalog_name", DataType::Utf8, true),
|
||||
Field::new("db_schema_name", DataType::Utf8, true),
|
||||
Field::new("table_name", DataType::Utf8, false),
|
||||
Field::new("table_type", DataType::Utf8, false),
|
||||
]));
|
||||
|
||||
let catalog_names: ArrayRef =
|
||||
Arc::new(StringArray::from(vec![Some("lancedb"); num_tables]));
|
||||
let schema_names: ArrayRef = Arc::new(StringArray::from(vec![Some("default"); num_tables]));
|
||||
let table_name_array: ArrayRef = Arc::new(StringArray::from(table_names));
|
||||
let table_types: ArrayRef = Arc::new(StringArray::from(vec!["TABLE"; num_tables]));
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema,
|
||||
vec![catalog_names, schema_names, table_name_array, table_types],
|
||||
)
|
||||
.map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;
|
||||
|
||||
Ok(Response::new(Self::batch_to_flight_stream(batch)))
|
||||
}
|
||||
|
||||
/// List table types.
|
||||
async fn get_flight_info_table_types(
|
||||
&self,
|
||||
_query: CommandGetTableTypes,
|
||||
request: Request<FlightDescriptor>,
|
||||
) -> Result<Response<FlightInfo>, Status> {
|
||||
let schema = ArrowSchema::new(vec![Field::new("table_type", DataType::Utf8, false)]);
|
||||
|
||||
let cmd = CommandGetTableTypes::default();
|
||||
let any_msg =
|
||||
Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
|
||||
let mut ticket_bytes = Vec::new();
|
||||
any_msg
|
||||
.encode(&mut ticket_bytes)
|
||||
.map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
|
||||
|
||||
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
|
||||
let flight_info = FlightInfo::new()
|
||||
.try_with_schema(&schema)
|
||||
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
|
||||
.with_endpoint(endpoint)
|
||||
.with_descriptor(request.into_inner());
|
||||
|
||||
Ok(Response::new(flight_info))
|
||||
}
|
||||
|
||||
async fn do_get_table_types(
|
||||
&self,
|
||||
_query: CommandGetTableTypes,
|
||||
_request: Request<Ticket>,
|
||||
) -> Result<
|
||||
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
|
||||
Status,
|
||||
> {
|
||||
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
|
||||
"table_type",
|
||||
DataType::Utf8,
|
||||
false,
|
||||
)]));
|
||||
let table_types: ArrayRef = Arc::new(StringArray::from(vec!["TABLE"]));
|
||||
let batch = RecordBatch::try_new(schema, vec![table_types])
|
||||
.map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;
|
||||
|
||||
Ok(Response::new(Self::batch_to_flight_stream(batch)))
|
||||
}
|
||||
|
||||
/// List catalogs.
|
||||
async fn get_flight_info_catalogs(
|
||||
&self,
|
||||
_query: CommandGetCatalogs,
|
||||
request: Request<FlightDescriptor>,
|
||||
) -> Result<Response<FlightInfo>, Status> {
|
||||
let schema = ArrowSchema::new(vec![Field::new("catalog_name", DataType::Utf8, false)]);
|
||||
|
||||
let cmd = CommandGetCatalogs::default();
|
||||
let any_msg =
|
||||
Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
|
||||
let mut ticket_bytes = Vec::new();
|
||||
any_msg
|
||||
.encode(&mut ticket_bytes)
|
||||
.map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
|
||||
|
||||
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
|
||||
let flight_info = FlightInfo::new()
|
||||
.try_with_schema(&schema)
|
||||
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
|
||||
.with_endpoint(endpoint)
|
||||
.with_descriptor(request.into_inner());
|
||||
|
||||
Ok(Response::new(flight_info))
|
||||
}
|
||||
|
||||
async fn do_get_catalogs(
|
||||
&self,
|
||||
_query: CommandGetCatalogs,
|
||||
_request: Request<Ticket>,
|
||||
) -> Result<
|
||||
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
|
||||
Status,
|
||||
> {
|
||||
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
|
||||
"catalog_name",
|
||||
DataType::Utf8,
|
||||
false,
|
||||
)]));
|
||||
let catalogs: ArrayRef = Arc::new(StringArray::from(vec!["lancedb"]));
|
||||
let batch = RecordBatch::try_new(schema, vec![catalogs])
|
||||
.map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;
|
||||
|
||||
Ok(Response::new(Self::batch_to_flight_stream(batch)))
|
||||
}
|
||||
|
||||
/// List schemas.
|
||||
async fn get_flight_info_schemas(
|
||||
&self,
|
||||
_query: CommandGetDbSchemas,
|
||||
request: Request<FlightDescriptor>,
|
||||
) -> Result<Response<FlightInfo>, Status> {
|
||||
let schema = ArrowSchema::new(vec![
|
||||
Field::new("catalog_name", DataType::Utf8, true),
|
||||
Field::new("db_schema_name", DataType::Utf8, false),
|
||||
]);
|
||||
|
||||
let cmd = CommandGetDbSchemas::default();
|
||||
let any_msg =
|
||||
Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
|
||||
let mut ticket_bytes = Vec::new();
|
||||
any_msg
|
||||
.encode(&mut ticket_bytes)
|
||||
.map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
|
||||
|
||||
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
|
||||
let flight_info = FlightInfo::new()
|
||||
.try_with_schema(&schema)
|
||||
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
|
||||
.with_endpoint(endpoint)
|
||||
.with_descriptor(request.into_inner());
|
||||
|
||||
Ok(Response::new(flight_info))
|
||||
}
|
||||
|
||||
async fn do_get_schemas(
|
||||
&self,
|
||||
_query: CommandGetDbSchemas,
|
||||
_request: Request<Ticket>,
|
||||
) -> Result<
|
||||
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
|
||||
Status,
|
||||
> {
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
Field::new("catalog_name", DataType::Utf8, true),
|
||||
Field::new("db_schema_name", DataType::Utf8, false),
|
||||
]));
|
||||
let catalogs: ArrayRef = Arc::new(StringArray::from(vec![Some("lancedb")]));
|
||||
let schemas: ArrayRef = Arc::new(StringArray::from(vec!["default"]));
|
||||
let batch = RecordBatch::try_new(schema, vec![catalogs, schemas])
|
||||
.map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;
|
||||
|
||||
Ok(Response::new(Self::batch_to_flight_stream(batch)))
|
||||
}
|
||||
|
||||
async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use arrow_array::{Float32Array, Int32Array};
|
||||
use arrow_flight::sql::client::FlightSqlServiceClient;
|
||||
use futures::TryStreamExt;
|
||||
use lance_arrow::FixedSizeListArrayExt;
|
||||
use std::time::Duration;
|
||||
|
||||
async fn create_test_db() -> Connection {
|
||||
let db = crate::connect("memory://flight_test")
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let dim = 4i32;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
|
||||
true,
|
||||
),
|
||||
]));
|
||||
|
||||
let ids = Int32Array::from(vec![1, 2, 3]);
|
||||
let texts = StringArray::from(vec!["hello world", "foo bar", "baz qux"]);
|
||||
let flat_values = Float32Array::from(vec![
|
||||
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
|
||||
]);
|
||||
let vectors =
|
||||
arrow_array::FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema,
|
||||
vec![Arc::new(ids), Arc::new(texts), Arc::new(vectors)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let table = db
|
||||
.create_table("test_table", batch)
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Create FTS index
|
||||
table
|
||||
.create_index(&["text"], crate::index::Index::FTS(Default::default()))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
db
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_flight_sql_basic_query() {
|
||||
let db = create_test_db().await;
|
||||
let service = LanceFlightSqlService::try_new(db).await.unwrap();
|
||||
let flight_svc = FlightServiceServer::new(service);
|
||||
|
||||
// Start server on a random port
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
Server::builder()
|
||||
.add_service(flight_svc)
|
||||
.serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
// Give server a moment to start
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
// Connect client
|
||||
let channel = tonic::transport::Channel::from_shared(format!("http://{}", addr))
|
||||
.unwrap()
|
||||
.connect()
|
||||
.await
|
||||
.unwrap();
|
||||
let mut client = FlightSqlServiceClient::new(channel);
|
||||
|
||||
// Execute SQL query
|
||||
let flight_info = client
|
||||
.execute("SELECT id, text FROM test_table".to_string(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Fetch results using the FlightSql client's do_get
|
||||
let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
|
||||
let flight_stream = client.do_get(ticket).await.unwrap();
|
||||
|
||||
let batches: Vec<RecordBatch> = flight_stream.try_collect().await.unwrap();
|
||||
|
||||
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
|
||||
assert_eq!(total_rows, 3, "Should return all 3 rows");
|
||||
|
||||
// Verify schema
|
||||
if let Some(batch) = batches.first() {
|
||||
assert!(batch.schema().column_with_name("id").is_some());
|
||||
assert!(batch.schema().column_with_name("text").is_some());
|
||||
}
|
||||
|
||||
server_handle.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_flight_sql_table_listing() {
|
||||
let db = create_test_db().await;
|
||||
let service = LanceFlightSqlService::try_new(db).await.unwrap();
|
||||
|
||||
assert!(service.tables.contains_key("test_table"));
|
||||
assert_eq!(service.tables.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_flight_sql_session_context() {
|
||||
let db = create_test_db().await;
|
||||
let service = LanceFlightSqlService::try_new(db).await.unwrap();
|
||||
let ctx = service.create_session_context();
|
||||
|
||||
// Test that we can execute a simple query
|
||||
let df = ctx.sql("SELECT * FROM test_table LIMIT 1").await.unwrap();
|
||||
let results = df.collect().await.unwrap();
|
||||
assert_eq!(results[0].num_rows(), 1);
|
||||
}
|
||||
}
|
||||
@@ -170,6 +170,8 @@ pub mod dataloader;
|
||||
pub mod embeddings;
|
||||
pub mod error;
|
||||
pub mod expr;
|
||||
#[cfg(feature = "flight")]
|
||||
pub mod flight;
|
||||
pub mod index;
|
||||
pub mod io;
|
||||
pub mod ipc;
|
||||
|
||||
@@ -26,9 +26,10 @@ use lance::dataset::{WriteMode, WriteParams};
|
||||
|
||||
use super::{AnyQuery, BaseTable};
|
||||
use crate::{
|
||||
Result,
|
||||
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select},
|
||||
DistanceType, Result,
|
||||
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest},
|
||||
};
|
||||
use arrow_array::Array;
|
||||
use arrow_schema::{DataType, Field};
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
@@ -141,11 +142,31 @@ impl ExecutionPlan for MetadataEraserExec {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters for a vector search query, used by vector_search and hybrid_search UDTFs.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VectorSearchParams {
|
||||
/// The query vector to search for
|
||||
pub query_vector: Arc<dyn Array>,
|
||||
/// The column to search on (None for auto-detection)
|
||||
pub column: Option<String>,
|
||||
/// Number of results to return
|
||||
pub top_k: usize,
|
||||
/// Distance metric to use
|
||||
pub distance_type: Option<DistanceType>,
|
||||
/// Number of IVF partitions to search
|
||||
pub nprobes: Option<usize>,
|
||||
/// HNSW search parameter
|
||||
pub ef: Option<usize>,
|
||||
/// Refine factor for improving recall
|
||||
pub refine_factor: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BaseTableAdapter {
|
||||
table: Arc<dyn BaseTable>,
|
||||
schema: Arc<ArrowSchema>,
|
||||
fts_query: Option<FullTextSearchQuery>,
|
||||
vector_query: Option<VectorSearchParams>,
|
||||
}
|
||||
|
||||
impl BaseTableAdapter {
|
||||
@@ -161,6 +182,7 @@ impl BaseTableAdapter {
|
||||
table,
|
||||
schema: Arc::new(schema),
|
||||
fts_query: None,
|
||||
vector_query: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -176,6 +198,49 @@ impl BaseTableAdapter {
|
||||
table: self.table.clone(),
|
||||
schema,
|
||||
fts_query: Some(fts_query),
|
||||
vector_query: self.vector_query.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new adapter with a vector search query applied.
|
||||
pub fn with_vector_query(&self, vector_query: VectorSearchParams) -> Self {
|
||||
// Add _distance column to the schema
|
||||
let distance_field = Field::new("_distance", DataType::Float32, true);
|
||||
let mut fields = self.schema.fields().to_vec();
|
||||
fields.push(Arc::new(distance_field));
|
||||
let schema = Arc::new(ArrowSchema::new(fields));
|
||||
|
||||
Self {
|
||||
table: self.table.clone(),
|
||||
schema,
|
||||
fts_query: self.fts_query.clone(),
|
||||
vector_query: Some(vector_query),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new adapter with both FTS and vector search queries (hybrid search).
|
||||
///
|
||||
/// Uses vector search as the primary retrieval method, with FTS applied as a
|
||||
/// pre-filter to restrict the candidate set. Both `_distance` and `_score`
|
||||
/// columns are added to results.
|
||||
pub fn with_hybrid_query(
|
||||
&self,
|
||||
fts_query: FullTextSearchQuery,
|
||||
vector_query: VectorSearchParams,
|
||||
) -> Self {
|
||||
// Add _distance column (vector search is primary)
|
||||
let mut fields = self.schema.fields().to_vec();
|
||||
fields.push(Arc::new(Field::new("_distance", DataType::Float32, true)));
|
||||
let schema = Arc::new(ArrowSchema::new(fields));
|
||||
|
||||
// Store FTS as a filter concept, but vector search drives the query.
|
||||
// The FTS query is applied via the base QueryRequest's full_text_search
|
||||
// field, which acts as a pre-filter for vector search.
|
||||
Self {
|
||||
table: self.table.clone(),
|
||||
schema,
|
||||
fts_query: Some(fts_query),
|
||||
vector_query: Some(vector_query),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -201,11 +266,20 @@ impl TableProvider for BaseTableAdapter {
|
||||
filters: &[Expr],
|
||||
limit: Option<usize>,
|
||||
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
|
||||
// For FTS queries, disable auto-projection of _score to match DataFusion expectations
|
||||
let disable_scoring = self.fts_query.is_some() && projection.is_some();
|
||||
let has_scoring = self.fts_query.is_some() || self.vector_query.is_some();
|
||||
let disable_scoring = has_scoring && projection.is_some();
|
||||
|
||||
let mut query = QueryRequest {
|
||||
full_text_search: self.fts_query.clone(),
|
||||
// When doing vector search, FTS cannot be combined in the same scanner
|
||||
// (Lance doesn't support both nearest + full_text_search simultaneously).
|
||||
// FTS is only set when there's no vector query.
|
||||
let fts_for_query = if self.vector_query.is_some() {
|
||||
None
|
||||
} else {
|
||||
self.fts_query.clone()
|
||||
};
|
||||
|
||||
let mut base_query = QueryRequest {
|
||||
full_text_search: fts_for_query,
|
||||
disable_scoring_autoprojection: disable_scoring,
|
||||
..Default::default()
|
||||
};
|
||||
@@ -215,20 +289,20 @@ impl TableProvider for BaseTableAdapter {
|
||||
.iter()
|
||||
.map(|i| self.schema.field(*i).name().clone())
|
||||
.collect();
|
||||
query.select = Select::Columns(field_names);
|
||||
base_query.select = Select::Columns(field_names);
|
||||
}
|
||||
if !filters.is_empty() {
|
||||
let first = filters.first().unwrap().clone();
|
||||
let filter = filters[1..]
|
||||
.iter()
|
||||
.fold(first, |acc, expr| acc.and(expr.clone()));
|
||||
query.filter = Some(QueryFilter::Datafusion(filter));
|
||||
base_query.filter = Some(QueryFilter::Datafusion(filter));
|
||||
}
|
||||
if let Some(limit) = limit {
|
||||
query.limit = Some(limit);
|
||||
} else {
|
||||
// Need to override the default of 10
|
||||
query.limit = None;
|
||||
base_query.limit = Some(limit);
|
||||
} else if self.vector_query.is_none() {
|
||||
// Need to override the default of 10 for non-vector queries
|
||||
base_query.limit = None;
|
||||
}
|
||||
|
||||
let options = QueryExecutionOptions {
|
||||
@@ -236,9 +310,33 @@ impl TableProvider for BaseTableAdapter {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Build the appropriate query type
|
||||
let any_query = if let Some(ref vq) = self.vector_query {
|
||||
let vector_query = VectorQueryRequest {
|
||||
base: base_query,
|
||||
column: vq.column.clone(),
|
||||
query_vector: vec![vq.query_vector.clone()],
|
||||
minimum_nprobes: vq.nprobes.unwrap_or(20),
|
||||
maximum_nprobes: vq.nprobes,
|
||||
ef: vq.ef,
|
||||
refine_factor: vq.refine_factor,
|
||||
distance_type: vq.distance_type,
|
||||
use_index: true,
|
||||
..Default::default()
|
||||
};
|
||||
// For vector queries, use top_k as the limit if no explicit limit set
|
||||
let mut vq_req = vector_query;
|
||||
if limit.is_none() {
|
||||
vq_req.base.limit = Some(vq.top_k);
|
||||
}
|
||||
AnyQuery::VectorQuery(vq_req)
|
||||
} else {
|
||||
AnyQuery::Query(base_query)
|
||||
};
|
||||
|
||||
let plan = self
|
||||
.table
|
||||
.create_plan(&AnyQuery::Query(query), options)
|
||||
.create_plan(&any_query, options)
|
||||
.map_err(|err| DataFusionError::External(err.into()))
|
||||
.await?;
|
||||
Ok(Arc::new(MetadataEraserExec::new(plan)))
|
||||
|
||||
@@ -2,5 +2,75 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! User-Defined Table Functions (UDTFs) for DataFusion integration
|
||||
//!
|
||||
//! This module provides SQL table functions for LanceDB search capabilities:
|
||||
//! - `fts(table_name, query_json)` — full-text search
|
||||
//! - `vector_search(table_name, query_vector_json, top_k)` — vector similarity search
|
||||
//! - `hybrid_search(table_name, query_vector_json, fts_query_json, top_k)` — combined search
|
||||
|
||||
pub mod fts;
|
||||
pub mod hybrid_search;
|
||||
pub mod vector_search;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue};
|
||||
use datafusion_expr::Expr;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
use super::VectorSearchParams;
|
||||
|
||||
/// Describes the type of search to apply when resolving a table.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SearchQuery {
|
||||
/// Full-text search only
|
||||
Fts(FullTextSearchQuery),
|
||||
/// Vector similarity search only
|
||||
Vector(VectorSearchParams),
|
||||
/// Hybrid search combining FTS and vector search
|
||||
Hybrid {
|
||||
fts: FullTextSearchQuery,
|
||||
vector: VectorSearchParams,
|
||||
},
|
||||
}
|
||||
|
||||
/// Trait for resolving table names to TableProvider instances, optionally with a search query.
|
||||
pub trait TableResolver: std::fmt::Debug + Send + Sync {
|
||||
/// Resolve a table name to a TableProvider, optionally applying a search query.
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
search: Option<SearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>>;
|
||||
}
|
||||
|
||||
/// Extract a string literal from a DataFusion expression.
|
||||
pub(crate) fn extract_string_literal(expr: &Expr, param_name: &str) -> DataFusionResult<String> {
|
||||
match expr {
|
||||
Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Ok(s.clone()),
|
||||
Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => Ok(s.clone()),
|
||||
_ => Err(DataFusionError::Plan(format!(
|
||||
"Parameter '{}' must be a string literal, got: {:?}",
|
||||
param_name, expr
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract an integer literal from a DataFusion expression.
|
||||
pub(crate) fn extract_int_literal(expr: &Expr, param_name: &str) -> DataFusionResult<usize> {
|
||||
match expr {
|
||||
Expr::Literal(ScalarValue::Int8(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::Int16(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::Int32(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::Int64(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::UInt8(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::UInt16(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::UInt32(Some(v)), _) => Ok(*v as usize),
|
||||
Expr::Literal(ScalarValue::UInt64(Some(v)), _) => Ok(*v as usize),
|
||||
_ => Err(DataFusionError::Plan(format!(
|
||||
"Parameter '{}' must be an integer literal, got: {:?}",
|
||||
param_name, expr
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! User-Defined Table Functions (UDTFs) for LanceDB
|
||||
//! Full-Text Search (FTS) table function for DataFusion SQL integration.
|
||||
//!
|
||||
//! This module provides table-level UDTFs that integrate with DataFusion's SQL engine.
|
||||
//! Usage: `SELECT * FROM fts('table_name', '{"match": {"column": "text", "terms": "query"}}')`
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion::catalog::TableFunctionImpl;
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue, plan_err};
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
|
||||
use datafusion_expr::Expr;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
/// Trait for resolving table names to TableProvider instances.
|
||||
pub trait TableResolver: std::fmt::Debug + Send + Sync {
|
||||
/// Resolve a table name to a TableProvider, optionally with an FTS query applied.
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
fts_query: Option<FullTextSearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>>;
|
||||
}
|
||||
use super::{SearchQuery, TableResolver, extract_string_literal};
|
||||
|
||||
/// Full-Text Search table function that operates on LanceDB tables
|
||||
/// Full-Text Search table function that operates on LanceDB tables.
|
||||
///
|
||||
/// Accepts 2 parameters: `fts(table_name, fts_query_json)`
|
||||
#[derive(Debug)]
|
||||
pub struct FtsTableFunction {
|
||||
resolver: Arc<dyn TableResolver>,
|
||||
@@ -45,20 +39,8 @@ impl TableFunctionImpl for FtsTableFunction {
|
||||
let query_json = extract_string_literal(&exprs[1], "fts_query")?;
|
||||
let fts_query = parse_fts_query(&query_json)?;
|
||||
|
||||
// Resolver returns a ready-to-use TableProvider with FTS applied
|
||||
self.resolver.resolve_table(&table_name, Some(fts_query))
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_string_literal(expr: &Expr, param_name: &str) -> DataFusionResult<String> {
|
||||
match expr {
|
||||
Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Ok(s.clone()),
|
||||
Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => Ok(s.clone()),
|
||||
_ => plan_err!(
|
||||
"Parameter '{}' must be a string literal, got: {:?}",
|
||||
param_name,
|
||||
expr
|
||||
),
|
||||
self.resolver
|
||||
.resolve_table(&table_name, Some(SearchQuery::Fts(fts_query)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +73,7 @@ pub fn from_json(json: &str) -> crate::Result<lance_index::scalar::inverted::que
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::{SearchQuery, TableResolver};
|
||||
use super::*;
|
||||
use crate::{
|
||||
Connection, Table,
|
||||
@@ -100,6 +83,7 @@ mod tests {
|
||||
use arrow_array::{Int32Array, RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
|
||||
use datafusion::prelude::SessionContext;
|
||||
use datafusion_common::DataFusionError;
|
||||
|
||||
/// Resolver that looks up tables in a HashMap
|
||||
#[derive(Debug)]
|
||||
@@ -123,7 +107,7 @@ mod tests {
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
fts_query: Option<FullTextSearchQuery>,
|
||||
search: Option<SearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
let table_provider = self
|
||||
.tables
|
||||
@@ -131,12 +115,10 @@ mod tests {
|
||||
.cloned()
|
||||
.ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
|
||||
|
||||
// If no FTS query, return as-is
|
||||
let Some(fts_query) = fts_query else {
|
||||
let Some(search) = search else {
|
||||
return Ok(table_provider);
|
||||
};
|
||||
|
||||
// Downcast to BaseTableAdapter and apply FTS query
|
||||
let base_adapter = table_provider
|
||||
.as_any()
|
||||
.downcast_ref::<BaseTableAdapter>()
|
||||
@@ -146,7 +128,15 @@ mod tests {
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Arc::new(base_adapter.with_fts_query(fts_query)))
|
||||
match search {
|
||||
SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
|
||||
SearchQuery::Vector(vector_query) => {
|
||||
Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
|
||||
}
|
||||
SearchQuery::Hybrid { fts, vector } => {
|
||||
Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
257
rust/lancedb/src/table/datafusion/udtf/hybrid_search.rs
Normal file
257
rust/lancedb/src/table/datafusion/udtf/hybrid_search.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Hybrid search table function for DataFusion SQL integration.
|
||||
//!
|
||||
//! Combines vector similarity search with full-text search:
|
||||
//! ```sql
|
||||
//! SELECT * FROM hybrid_search(
|
||||
//! 'my_table',
|
||||
//! '[0.1, 0.2, 0.3]',
|
||||
//! '{"match": {"column": "text", "terms": "search query"}}',
|
||||
//! 10
|
||||
//! )
|
||||
//! ```
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Array;
|
||||
use datafusion::catalog::TableFunctionImpl;
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
|
||||
use datafusion_expr::Expr;
|
||||
use lance_index::scalar::FullTextSearchQuery;
|
||||
|
||||
use super::fts::from_json as fts_from_json;
|
||||
use super::{SearchQuery, TableResolver, extract_int_literal, extract_string_literal};
|
||||
use crate::table::datafusion::VectorSearchParams;
|
||||
|
||||
/// Default number of results for hybrid search when top_k is not specified.
|
||||
const DEFAULT_TOP_K: usize = 10;
|
||||
|
||||
/// Hybrid search table function combining vector and full-text search.
|
||||
///
|
||||
/// Accepts 3-4 parameters: `hybrid_search(table_name, query_vector_json, fts_query_json [, top_k])`
|
||||
///
|
||||
/// - `table_name`: Name of the table to search
|
||||
/// - `query_vector_json`: JSON array of float values, e.g. `'[0.1, 0.2, 0.3]'`
|
||||
/// - `fts_query_json`: FTS query as JSON, e.g. `'{"match": {"column": "text", "terms": "query"}}'`
|
||||
/// - `top_k` (optional): Number of results to return (default: 10)
|
||||
#[derive(Debug)]
|
||||
pub struct HybridSearchTableFunction {
|
||||
resolver: Arc<dyn TableResolver>,
|
||||
}
|
||||
|
||||
impl HybridSearchTableFunction {
|
||||
pub fn new(resolver: Arc<dyn TableResolver>) -> Self {
|
||||
Self { resolver }
|
||||
}
|
||||
}
|
||||
|
||||
impl TableFunctionImpl for HybridSearchTableFunction {
|
||||
fn call(&self, exprs: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
if exprs.len() < 3 || exprs.len() > 4 {
|
||||
return plan_err!(
|
||||
"hybrid_search() requires 3-4 parameters: hybrid_search(table_name, query_vector_json, fts_query_json [, top_k])"
|
||||
);
|
||||
}
|
||||
|
||||
let table_name = extract_string_literal(&exprs[0], "table_name")?;
|
||||
let vector_json = extract_string_literal(&exprs[1], "query_vector_json")?;
|
||||
let fts_json = extract_string_literal(&exprs[2], "fts_query_json")?;
|
||||
|
||||
let top_k = if exprs.len() == 4 {
|
||||
extract_int_literal(&exprs[3], "top_k")?
|
||||
} else {
|
||||
DEFAULT_TOP_K
|
||||
};
|
||||
|
||||
let query_vector = parse_vector_json(&vector_json)?;
|
||||
let fts_query = parse_fts_query(&fts_json)?;
|
||||
|
||||
let vector_params = VectorSearchParams {
|
||||
query_vector: query_vector as Arc<dyn Array>,
|
||||
column: None,
|
||||
top_k,
|
||||
distance_type: None,
|
||||
nprobes: None,
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
};
|
||||
|
||||
self.resolver.resolve_table(
|
||||
&table_name,
|
||||
Some(SearchQuery::Hybrid {
|
||||
fts: fts_query,
|
||||
vector: vector_params,
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_vector_json(json: &str) -> DataFusionResult<Arc<arrow_array::Float32Array>> {
|
||||
super::vector_search::parse_vector_json(json)
|
||||
}
|
||||
|
||||
fn parse_fts_query(json: &str) -> DataFusionResult<FullTextSearchQuery> {
|
||||
let query = fts_from_json(json).map_err(|e| {
|
||||
DataFusionError::Plan(format!(
|
||||
"Invalid FTS query JSON: {}. Expected format: {{\"match\": {{\"column\": \"text\", \"terms\": \"query\"}} }}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
Ok(FullTextSearchQuery::new_query(query))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::super::fts::to_json;
|
||||
use super::super::{SearchQuery, TableResolver};
|
||||
use super::*;
|
||||
use crate::{index::Index, table::datafusion::BaseTableAdapter};
|
||||
use arrow_array::FixedSizeListArray;
|
||||
use arrow_array::{Float32Array, Int32Array, RecordBatch, StringArray};
|
||||
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
|
||||
use datafusion::prelude::SessionContext;
|
||||
#[allow(unused_imports)]
|
||||
use lance_arrow::FixedSizeListArrayExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct HashMapTableResolver {
|
||||
tables: std::collections::HashMap<String, Arc<dyn TableProvider>>,
|
||||
}
|
||||
|
||||
impl HashMapTableResolver {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
tables: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn register(&mut self, name: String, table: Arc<dyn TableProvider>) {
|
||||
self.tables.insert(name, table);
|
||||
}
|
||||
}
|
||||
|
||||
impl TableResolver for HashMapTableResolver {
|
||||
fn resolve_table(
|
||||
&self,
|
||||
name: &str,
|
||||
search: Option<SearchQuery>,
|
||||
) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
let table_provider = self
|
||||
.tables
|
||||
.get(name)
|
||||
.cloned()
|
||||
.ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
|
||||
|
||||
let Some(search) = search else {
|
||||
return Ok(table_provider);
|
||||
};
|
||||
|
||||
let base_adapter = table_provider
|
||||
.as_any()
|
||||
.downcast_ref::<BaseTableAdapter>()
|
||||
.ok_or_else(|| {
|
||||
DataFusionError::Internal(
|
||||
"Expected BaseTableAdapter but got different type".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
match search {
|
||||
SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
|
||||
SearchQuery::Vector(vector_query) => {
|
||||
Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
|
||||
}
|
||||
SearchQuery::Hybrid { fts, vector } => {
|
||||
Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hybrid_search_udtf() {
|
||||
let dim = 4i32;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("text", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
|
||||
true,
|
||||
),
|
||||
]));
|
||||
|
||||
let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
|
||||
let texts = StringArray::from(vec![
|
||||
"the quick brown fox",
|
||||
"jumps over the lazy dog",
|
||||
"a quick red fox runs",
|
||||
"the dog sleeps all day",
|
||||
"a brown fox and a quick dog",
|
||||
]);
|
||||
let flat_values = Float32Array::from(vec![
|
||||
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.5,
|
||||
0.5, 0.0, 0.0,
|
||||
]);
|
||||
let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(ids), Arc::new(texts), Arc::new(vector_array)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let db = crate::connect("memory://test_hybrid")
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
let table = db.create_table("docs", batch).execute().await.unwrap();
|
||||
|
||||
// Create FTS index on text column
|
||||
table
|
||||
.create_index(&["text"], Index::FTS(Default::default()))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ctx = SessionContext::new();
|
||||
let mut resolver = HashMapTableResolver::new();
|
||||
let adapter = BaseTableAdapter::try_new(table.base_table().clone())
|
||||
.await
|
||||
.unwrap();
|
||||
resolver.register("docs".to_string(), Arc::new(adapter));
|
||||
|
||||
let resolver = Arc::new(resolver);
|
||||
ctx.register_udtf(
|
||||
"hybrid_search",
|
||||
Arc::new(HybridSearchTableFunction::new(resolver.clone())),
|
||||
);
|
||||
|
||||
// Run hybrid search: vector close to [1,0,0,0] AND FTS for "fox"
|
||||
use lance_index::scalar::inverted::query::*;
|
||||
let fts_query_struct = FtsQuery::Match(
|
||||
MatchQuery::new("fox".to_string()).with_column(Some("text".to_string())),
|
||||
);
|
||||
let fts_json = to_json(&fts_query_struct).unwrap();
|
||||
|
||||
let query = format!(
|
||||
"SELECT * FROM hybrid_search('docs', '[1.0, 0.0, 0.0, 0.0]', '{}', 5)",
|
||||
fts_json
|
||||
);
|
||||
|
||||
let df = ctx.sql(&query).await.unwrap();
|
||||
let results = df.collect().await.unwrap();
|
||||
|
||||
assert!(!results.is_empty());
|
||||
let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
|
||||
assert!(total_rows > 0, "Should have at least one result");
|
||||
|
||||
// Check schema has the expected columns
|
||||
let result_schema = results[0].schema();
|
||||
assert!(result_schema.column_with_name("id").is_some());
|
||||
assert!(result_schema.column_with_name("text").is_some());
|
||||
assert!(result_schema.column_with_name("vector").is_some());
|
||||
}
|
||||
}
|
||||
278
rust/lancedb/src/table/datafusion/udtf/vector_search.rs
Normal file
278
rust/lancedb/src/table/datafusion/udtf/vector_search.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! Vector search table function for DataFusion SQL integration.
|
||||
//!
|
||||
//! Enables vector similarity search via SQL:
|
||||
//! ```sql
|
||||
//! SELECT * FROM vector_search('my_table', '[0.1, 0.2, 0.3, ...]', 10)
|
||||
//! ```
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{Array, Float32Array};
|
||||
use datafusion::catalog::TableFunctionImpl;
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
|
||||
use datafusion_expr::Expr;
|
||||
|
||||
use super::{SearchQuery, TableResolver, extract_int_literal, extract_string_literal};
|
||||
use crate::table::datafusion::VectorSearchParams;
|
||||
|
||||
/// Default number of results for vector search when top_k is not specified.
|
||||
const DEFAULT_TOP_K: usize = 10;
|
||||
|
||||
/// Vector search table function for LanceDB tables.
|
||||
///
|
||||
/// Accepts 2-3 parameters: `vector_search(table_name, query_vector_json [, top_k])`
|
||||
///
|
||||
/// - `table_name`: Name of the table to search
|
||||
/// - `query_vector_json`: JSON array of float values, e.g. `'[0.1, 0.2, 0.3]'`
|
||||
/// - `top_k` (optional): Number of results to return (default: 10)
|
||||
#[derive(Debug)]
|
||||
pub struct VectorSearchTableFunction {
|
||||
resolver: Arc<dyn TableResolver>,
|
||||
}
|
||||
|
||||
impl VectorSearchTableFunction {
|
||||
pub fn new(resolver: Arc<dyn TableResolver>) -> Self {
|
||||
Self { resolver }
|
||||
}
|
||||
}
|
||||
|
||||
impl TableFunctionImpl for VectorSearchTableFunction {
|
||||
fn call(&self, exprs: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
|
||||
if exprs.len() < 2 || exprs.len() > 3 {
|
||||
return plan_err!(
|
||||
"vector_search() requires 2-3 parameters: vector_search(table_name, query_vector_json [, top_k])"
|
||||
);
|
||||
}
|
||||
|
||||
let table_name = extract_string_literal(&exprs[0], "table_name")?;
|
||||
let vector_json = extract_string_literal(&exprs[1], "query_vector_json")?;
|
||||
|
||||
let top_k = if exprs.len() == 3 {
|
||||
extract_int_literal(&exprs[2], "top_k")?
|
||||
} else {
|
||||
DEFAULT_TOP_K
|
||||
};
|
||||
|
||||
let query_vector = parse_vector_json(&vector_json)?;
|
||||
|
||||
let params = VectorSearchParams {
|
||||
query_vector: query_vector as Arc<dyn Array>,
|
||||
column: None,
|
||||
top_k,
|
||||
distance_type: None,
|
||||
nprobes: None,
|
||||
ef: None,
|
||||
refine_factor: None,
|
||||
};
|
||||
|
||||
self.resolver
|
||||
.resolve_table(&table_name, Some(SearchQuery::Vector(params)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a JSON array of floats into an Arrow Float32Array for vector search.
|
||||
///
|
||||
/// Input format: `"[0.1, 0.2, 0.3, ...]"`
|
||||
///
|
||||
/// Returns a Float32Array whose length equals the vector dimension.
|
||||
/// This is the format expected by LanceDB's vector search internals.
|
||||
pub(crate) fn parse_vector_json(json: &str) -> DataFusionResult<Arc<Float32Array>> {
|
||||
let values: Vec<f32> = serde_json::from_str(json).map_err(|e| {
|
||||
DataFusionError::Plan(format!(
|
||||
"Invalid vector JSON: {}. Expected format: [0.1, 0.2, 0.3, ...]",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
if values.is_empty() {
|
||||
return Err(DataFusionError::Plan(
|
||||
"Vector must not be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Arc::new(Float32Array::from(values)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Integration tests for the `vector_search` UDTF, running end-to-end
    //! through an in-memory LanceDB table and a DataFusion `SessionContext`.

    use super::super::{SearchQuery, TableResolver};
    use super::*;
    use crate::table::datafusion::BaseTableAdapter;
    use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use datafusion::prelude::SessionContext;
    #[allow(unused_imports)]
    use lance_arrow::FixedSizeListArrayExt;

    /// Resolver that looks up tables in a HashMap and applies search queries
    #[derive(Debug)]
    struct HashMapTableResolver {
        // name -> provider; providers are expected to be BaseTableAdapter
        // instances so search queries can be attached (see resolve_table).
        tables: std::collections::HashMap<String, Arc<dyn TableProvider>>,
    }

    impl HashMapTableResolver {
        /// Create an empty resolver with no registered tables.
        fn new() -> Self {
            Self {
                tables: std::collections::HashMap::new(),
            }
        }

        /// Register a table provider under the given name, replacing any
        /// previous entry with that name.
        fn register(&mut self, name: String, table: Arc<dyn TableProvider>) {
            self.tables.insert(name, table);
        }
    }

    impl TableResolver for HashMapTableResolver {
        /// Look up `name` and, when a search query is supplied, wrap the
        /// provider so the query is applied at scan time.
        fn resolve_table(
            &self,
            name: &str,
            search: Option<SearchQuery>,
        ) -> DataFusionResult<Arc<dyn TableProvider>> {
            let table_provider = self
                .tables
                .get(name)
                .cloned()
                .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;

            // Plain lookup (no search) returns the provider unchanged.
            let Some(search) = search else {
                return Ok(table_provider);
            };

            // Search queries can only be attached to a BaseTableAdapter;
            // anything else is a programming error in the test setup.
            let base_adapter = table_provider
                .as_any()
                .downcast_ref::<BaseTableAdapter>()
                .ok_or_else(|| {
                    DataFusionError::Internal(
                        "Expected BaseTableAdapter but got different type".to_string(),
                    )
                })?;

            match search {
                SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
                SearchQuery::Vector(vector_query) => {
                    Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
                }
                SearchQuery::Hybrid { fts, vector } => {
                    Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
                }
            }
        }
    }

    /// Build a 5-row test batch with an `id` column and a 4-dimensional
    /// `vector` column (FixedSizeList<Float32, 4>).
    fn make_test_data() -> (Arc<ArrowSchema>, RecordBatch) {
        let dim = 4;
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
                true,
            ),
        ]));

        let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
        // Create vectors: [1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1], [1,1,0,0]
        let flat_values = Float32Array::from(vec![
            1.0, 0.0, 0.0, 0.0, // vec 1
            0.0, 1.0, 0.0, 0.0, // vec 2
            0.0, 0.0, 1.0, 0.0, // vec 3
            0.0, 0.0, 0.0, 1.0, // vec 4
            1.0, 1.0, 0.0, 0.0, // vec 5
        ]);
        let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();

        let batch =
            RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vector_array)])
                .unwrap();

        (schema, batch)
    }

    /// End-to-end: register the UDTF, run a SQL vector search with an explicit
    /// top_k, and check the result schema (including `_distance`) and row count.
    #[tokio::test]
    async fn test_vector_search_udtf() {
        let (_schema, batch) = make_test_data();

        let db = crate::connect("memory://test_vec").execute().await.unwrap();
        let table = db.create_table("vectors", batch).execute().await.unwrap();

        // No index needed — vector search works with brute-force scan on small tables

        // Setup DataFusion context
        let ctx = SessionContext::new();
        let mut resolver = HashMapTableResolver::new();
        let adapter = BaseTableAdapter::try_new(table.base_table().clone())
            .await
            .unwrap();
        resolver.register("vectors".to_string(), Arc::new(adapter));

        let udtf = VectorSearchTableFunction::new(Arc::new(resolver));
        ctx.register_udtf("vector_search", Arc::new(udtf));

        // Search for vectors close to [1, 0, 0, 0]
        let query = "SELECT * FROM vector_search('vectors', '[1.0, 0.0, 0.0, 0.0]', 3)";
        let df = ctx.sql(query).await.unwrap();
        let results = df.collect().await.unwrap();

        assert!(!results.is_empty());
        let batch = &results[0];

        // Should have id, vector, _distance columns
        assert!(batch.schema().column_with_name("id").is_some());
        assert!(batch.schema().column_with_name("vector").is_some());
        assert!(
            batch.schema().column_with_name("_distance").is_some(),
            "_distance column should be present"
        );

        // Should return at most 3 results
        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        assert!(total_rows <= 3);
        assert!(total_rows > 0);
    }

    /// Omitting top_k should default to DEFAULT_TOP_K (10); with only 5 rows
    /// in the table, all 5 come back.
    #[tokio::test]
    async fn test_vector_search_default_top_k() {
        let (_, batch) = make_test_data();

        let db = crate::connect("memory://test_vec_default")
            .execute()
            .await
            .unwrap();
        let table = db.create_table("vectors", batch).execute().await.unwrap();

        let ctx = SessionContext::new();
        let mut resolver = HashMapTableResolver::new();
        let adapter = BaseTableAdapter::try_new(table.base_table().clone())
            .await
            .unwrap();
        resolver.register("vectors".to_string(), Arc::new(adapter));

        let udtf = VectorSearchTableFunction::new(Arc::new(resolver));
        ctx.register_udtf("vector_search", Arc::new(udtf));

        // No top_k parameter — should default to 10
        let query = "SELECT * FROM vector_search('vectors', '[1.0, 0.0, 0.0, 0.0]')";
        let df = ctx.sql(query).await.unwrap();
        let results = df.collect().await.unwrap();

        let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
        // We only have 5 rows, so we should get all 5 back
        assert_eq!(total_rows, 5);
    }

    /// Unit test for the JSON-to-Float32Array parser: valid input, empty
    /// array rejection, and malformed JSON rejection.
    #[test]
    fn test_parse_vector_json() {
        let result = parse_vector_json("[1.0, 2.0, 3.0]").unwrap();
        assert_eq!(result.len(), 3); // 3-dimensional vector

        // Empty vector should fail
        assert!(parse_vector_json("[]").is_err());

        // Invalid JSON should fail
        assert!(parse_vector_json("not json").is_err());
    }
}
|
||||
Reference in New Issue
Block a user