From c8d654db37039533861fff2434881816c170578b Mon Sep 17 00:00:00 2001 From: Justin Miller Date: Tue, 7 Apr 2026 01:48:35 -0700 Subject: [PATCH] feat: add ADBC/Flight SQL interface with search table functions (OSS-738) Add Arrow Flight SQL server and SQL table functions for vector/hybrid search, enabling standard SQL workbench tools to connect to LanceDB via ADBC. Phase 1 - Search Table Functions: - vector_search(table, vector_json, top_k) UDTF for vector similarity search - hybrid_search(table, vector_json, fts_json, top_k) UDTF combining both - Generalized TableResolver with SearchQuery enum (Fts/Vector/Hybrid) - Extended BaseTableAdapter with VectorSearchParams support Phase 2 - Arrow Flight SQL Server: - FlightSqlService implementation with SQL query execution - Catalog introspection (tables, schemas, catalogs, table types) - All three UDTFs auto-registered in session context - Feature-gated behind "flight" flag (arrow-flight + tonic + prost) Phase 3 - Example & Tests: - flight_sql example: creates sample DB, starts Flight SQL server - Tests for vector_search and hybrid_search UDTFs - Flight SQL integration tests (basic query, table listing, session) Co-Authored-By: Claude Opus 4.6 (1M context) --- Cargo.lock | 165 ++++- Cargo.toml | 2 + rust/lancedb/Cargo.toml | 10 + rust/lancedb/examples/flight_sql.rs | 99 +++ rust/lancedb/src/flight.rs | 606 ++++++++++++++++++ rust/lancedb/src/lib.rs | 2 + rust/lancedb/src/table/datafusion.rs | 124 +++- rust/lancedb/src/table/datafusion/udtf.rs | 70 ++ rust/lancedb/src/table/datafusion/udtf/fts.rs | 54 +- .../table/datafusion/udtf/hybrid_search.rs | 257 ++++++++ .../table/datafusion/udtf/vector_search.rs | 278 ++++++++ 11 files changed, 1608 insertions(+), 59 deletions(-) create mode 100644 rust/lancedb/examples/flight_sql.rs create mode 100644 rust/lancedb/src/flight.rs create mode 100644 rust/lancedb/src/table/datafusion/udtf/hybrid_search.rs create mode 100644 rust/lancedb/src/table/datafusion/udtf/vector_search.rs diff --git a/Cargo.lock b/Cargo.lock index 8ed07c153..14a6689a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -290,6 +290,34 @@ dependencies = [ "num-traits", ] +[[package]] +name = "arrow-flight" +version = "57.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58c5b083668e6230eae3eab2fc4b5fb989974c845d0aa538dde61a4327c78675" +dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", + "base64 0.22.1", + "bytes", + "futures", + "once_cell", + "paste", + "prost", + "prost-types", + "tonic", + "tonic-prost", +] + [[package]] name = "arrow-ipc" version = "57.3.0" @@ -1069,7 +1097,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", - "axum-core", + "axum-core 0.4.5", "bytes", "futures-util", "http 1.4.0", @@ -1078,7 +1106,7 @@ dependencies = [ "hyper 1.8.1", "hyper-util", "itoa", - "matchit", + "matchit 0.7.3", "memchr", "mime", "percent-encoding", @@ -1096,6 +1124,31 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core 0.5.6", + "bytes", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "axum-core" version = "0.4.5" @@ -1117,6 +1170,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", +] + [[package]] name = "backoff" version = "0.4.0" @@ -1294,7 +1365,7 @@ version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" dependencies = [ - "darling 0.20.11", + "darling 0.23.0", "ident_case", "prettyplease", "proc-macro2", @@ -2682,7 +2753,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2876,7 +2947,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3719,6 +3790,19 @@ dependencies = [ "webpki-roots 1.0.6", ] +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper 1.8.1", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.20" @@ -4048,7 +4132,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde_core", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -4539,7 +4623,7 @@ dependencies = [ "arrow-ipc", "arrow-schema", "async-trait", - "axum", + "axum 0.7.9", "bytes", "chrono", "futures", @@ -4638,6 +4722,7 @@ dependencies = [ "arrow-array", "arrow-cast", "arrow-data", + "arrow-flight", "arrow-ipc", "arrow-ord", "arrow-schema", @@ -4691,6 +4776,7 @@ dependencies = [ "pin-project", "polars", "polars-arrow", + "prost", "rand 0.9.2", "random_word 0.4.3", "regex", @@ -4705,6 +4791,8 @@ dependencies = [ "test-log", "tokenizers", "tokio", + "tokio-stream", + "tonic", "url", "uuid", "walkdir", @@ -5009,6 +5097,12 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -5322,7 +5416,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -6302,7 +6396,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" dependencies = [ "heck 0.5.0", - "itertools 0.11.0", + "itertools 0.14.0", "log", "multimap", "petgraph", @@ -6321,7 +6415,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.117", @@ -6527,7 +6621,7 @@ dependencies = [ "once_cell", "socket2 0.6.3", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -7069,7 +7163,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.12.1", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -8088,7 +8182,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix 1.1.4", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -8370,6 +8464,46 @@ dependencies = [ "winnow", ] +[[package]] +name = "tonic" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" +dependencies = [ + "async-trait", + "axum 0.8.8", + "base64 0.22.1", + "bytes", + "h2 0.4.13", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.8.1", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "socket2 0.6.3", + "sync_wrapper", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic-prost" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" +dependencies = [ + "bytes", + "prost", + "tonic", +] + [[package]] name = "tower" version = "0.5.3" @@ -8378,9 +8512,12 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" dependencies = [ "futures-core", "futures-util", + "indexmap 2.13.0", "pin-project-lite", + "slab", "sync_wrapper", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -8893,7 +9030,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 9bb9ab8a4..bc41287e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,8 @@ arrow-ord = "57.2" arrow-schema = "57.2" arrow-select = "57.2" arrow-cast = "57.2" +arrow-flight = { version = "57.2", features = ["flight-sql-experimental"] } +tonic = "0.14" async-trait = "0" datafusion = { version = "52.1", default-features = false } datafusion-catalog = "52.1" diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 6b6f41104..346dc96a2 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -88,6 +88,10 @@ candle-transformers = { version = "0.9.1", optional = true } candle-nn = { version = "0.9.1", optional = true } tokenizers = { version = "0.19.1", optional = true } semver = { workspace = true } +# For flight feature (Arrow Flight SQL server) +arrow-flight = { workspace = true, optional = true } +tonic = { workspace = true, optional = true } +prost = { version = "0.14", optional = true } [dev-dependencies] anyhow = "1" @@ -104,6 +108,7 @@ datafusion.workspace = true http-body = "1" # Matching reqwest rstest = "0.23.0" test-log = "0.2" +tokio-stream = "0.1" [features] @@ -131,6 +136,7 @@ sentence-transformers = [ "dep:candle-nn", "dep:tokenizers", ] +flight = ["dep:arrow-flight", "dep:tonic", "dep:prost"] [[example]] name = "openai" @@ -157,5 +163,9 @@ name = "ivf_pq" name = "hybrid_search" required-features = ["sentence-transformers"] +[[example]] +name = "flight_sql" +required-features = ["flight"] + [package.metadata.docs.rs] all-features = true diff --git a/rust/lancedb/examples/flight_sql.rs b/rust/lancedb/examples/flight_sql.rs new file mode 100644 index 000000000..559b0be55 --- /dev/null +++ b/rust/lancedb/examples/flight_sql.rs @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! Example: LanceDB Arrow Flight SQL Server +//! +//! This example demonstrates how to: +//! 1. Create a LanceDB database with sample data +//! 2. Start an Arrow Flight SQL server +//! 3. Connect with a Flight SQL client +//! 4. Run SQL queries including vector_search and fts table functions +//! +//! Run with: `cargo run --features flight --example flight_sql` + +use std::sync::Arc; + +use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use lance_arrow::FixedSizeListArrayExt; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize logging if env_logger is available + let _ = std::env::var("RUST_LOG").ok(); + + // 1. Create an in-memory LanceDB database + let db = lancedb::connect("memory://flight_sql_demo") + .execute() + .await?; + + // 2. Create a table with text and vector data + let dim = 4i32; + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("text", DataType::Utf8, false), + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim), + true, + ), + ])); + + let ids = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]); + let texts = StringArray::from(vec![ + "the quick brown fox jumps over the lazy dog", + "a fast red fox leaps across the sleeping hound", + "machine learning models process natural language", + "neural networks learn from training data", + "the brown dog chases the red fox through the forest", + "deep learning algorithms improve with more data", + "a lazy cat sleeps on the warm windowsill", + "vector databases enable fast similarity search", + ]); + let flat_values = Float32Array::from(vec![ + 1.0, 0.0, 0.0, 0.0, // fox-like + 0.9, 0.1, 0.0, 0.0, // similar to fox + 0.0, 1.0, 0.0, 0.0, // ML-like + 0.0, 0.9, 0.1, 0.0, // similar to ML + 0.7, 0.3, 0.0, 0.0, // fox+dog mix + 0.0, 0.8, 0.2, 0.0, // ML-like + 0.1, 0.0, 0.0, 0.9, // cat-like + 0.0, 0.5, 0.5, 0.0, // tech-like + ]); + let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim)?; + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(ids), Arc::new(texts), Arc::new(vector_array)], + )?; + + let table = db.create_table("documents", batch).execute().await?; + + // 3. Create indices + println!("Creating FTS index on 'text' column..."); + table + .create_index(&["text"], lancedb::index::Index::FTS(Default::default())) + .execute() + .await?; + + println!("Database ready with {} rows", table.count_rows(None).await?); + + // 4. Start Flight SQL server + let addr = "0.0.0.0:50051".parse()?; + println!("Starting Arrow Flight SQL server on {}...", addr); + println!(); + println!("Connect with any ADBC Flight SQL client:"); + println!(" URI: grpc://localhost:50051"); + println!(); + println!("Example SQL queries:"); + println!(" SELECT * FROM documents LIMIT 5;"); + println!(" SELECT * FROM vector_search('documents', '[1.0, 0.0, 0.0, 0.0]', 3);"); + println!( + " SELECT * FROM fts('documents', '{{\"match\": {{\"column\": \"text\", \"terms\": \"fox\"}}}}');" + ); + println!(); + + lancedb::flight::serve(db, addr).await?; + + Ok(()) +} diff --git a/rust/lancedb/src/flight.rs b/rust/lancedb/src/flight.rs new file mode 100644 index 000000000..ac2647daa --- /dev/null +++ b/rust/lancedb/src/flight.rs @@ -0,0 +1,606 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! Arrow Flight SQL server for LanceDB. +//! +//! This module provides an Arrow Flight SQL server that exposes a LanceDB +//! [`Connection`] over the Flight SQL protocol. Any ADBC, ODBC (via bridge), +//! or JDBC Flight SQL client can connect and run SQL queries — including +//! LanceDB's search table functions (`vector_search`, `fts`, `hybrid_search`). +//! +//! # Quick Start +//! +//! ```no_run +//! # async fn example() -> Result<(), Box> { +//! let db = lancedb::connect("data/my-db").execute().await?; +//! let addr = "0.0.0.0:50051".parse()?; +//! lancedb::flight::serve(db, addr).await?; +//! # Ok(()) +//! # } +//! ``` + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; + +use arrow_array::{ArrayRef, RecordBatch, StringArray}; +use arrow_flight::encode::FlightDataEncoderBuilder; +use arrow_flight::error::FlightError; +use arrow_flight::flight_service_server::FlightServiceServer; +use arrow_flight::sql::server::FlightSqlService; +use arrow_flight::sql::{ + Any, CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables, + CommandStatementQuery, SqlInfo, TicketStatementQuery, +}; +use arrow_flight::{FlightDescriptor, FlightEndpoint, FlightInfo, Ticket}; +use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef}; +use datafusion::prelude::SessionContext; +use datafusion_catalog::TableProvider; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use futures::StreamExt; +use futures::stream; +use log; +use prost::Message; +use tonic::transport::Server; +use tonic::{Request, Response, Status}; + +use crate::connection::Connection; +use crate::table::datafusion::BaseTableAdapter; +use crate::table::datafusion::udtf::fts::FtsTableFunction; +use crate::table::datafusion::udtf::hybrid_search::HybridSearchTableFunction; +use crate::table::datafusion::udtf::vector_search::VectorSearchTableFunction; +use crate::table::datafusion::udtf::{SearchQuery, TableResolver}; + +/// Start an Arrow Flight SQL server exposing the given LanceDB connection. +/// +/// This is a convenience function that creates a server and starts listening. +/// It blocks until the server is shut down. +pub async fn serve(connection: Connection, addr: SocketAddr) -> crate::Result<()> { + let service = LanceFlightSqlService::try_new(connection).await?; + let flight_svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(flight_svc) + .serve(addr) + .await + .map_err(|e| crate::Error::Runtime { + message: format!("Flight SQL server error: {}", e), + })?; + + Ok(()) +} + +/// A table resolver that looks up tables from a pre-built HashMap. +#[derive(Debug)] +struct ConnectionTableResolver { + tables: HashMap>, +} + +impl TableResolver for ConnectionTableResolver { + fn resolve_table( + &self, + name: &str, + search: Option, + ) -> DataFusionResult> { + let adapter = self + .tables + .get(name) + .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?; + + match search { + None => Ok(adapter.clone() as Arc), + Some(SearchQuery::Fts(fts)) => Ok(Arc::new(adapter.with_fts_query(fts))), + Some(SearchQuery::Vector(vq)) => Ok(Arc::new(adapter.with_vector_query(vq))), + Some(SearchQuery::Hybrid { fts, vector }) => { + Ok(Arc::new(adapter.with_hybrid_query(fts, vector))) + } + } + } +} + +/// Arrow Flight SQL service backed by a LanceDB connection. +struct LanceFlightSqlService { + /// Kept for future use (e.g., refreshing table list). + _connection: Connection, + /// Pre-built table adapters (refreshed on creation) + tables: HashMap>, +} + +impl LanceFlightSqlService { + async fn try_new(connection: Connection) -> crate::Result { + let table_names = connection.table_names().execute().await?; + let mut tables = HashMap::new(); + + for name in &table_names { + let table = connection.open_table(name).execute().await?; + let adapter = BaseTableAdapter::try_new(table.base_table().clone()).await?; + tables.insert(name.clone(), Arc::new(adapter)); + } + + log::info!( + "LanceDB Flight SQL service initialized with {} tables", + tables.len() + ); + + Ok(Self { + _connection: connection, + tables, + }) + } + + /// Create a DataFusion SessionContext with all tables and UDTFs registered. + fn create_session_context(&self) -> SessionContext { + let ctx = SessionContext::new(); + + // Register all tables + for (name, adapter) in &self.tables { + if let Err(e) = ctx.register_table(name, adapter.clone() as Arc) { + log::warn!("Failed to register table '{}': {}", name, e); + } + } + + // Create resolver for UDTFs + let resolver = Arc::new(ConnectionTableResolver { + tables: self.tables.clone(), + }); + + // Register search UDTFs + ctx.register_udtf("fts", Arc::new(FtsTableFunction::new(resolver.clone()))); + ctx.register_udtf( + "vector_search", + Arc::new(VectorSearchTableFunction::new(resolver.clone())), + ); + ctx.register_udtf( + "hybrid_search", + Arc::new(HybridSearchTableFunction::new(resolver)), + ); + + ctx + } + + /// Execute a SQL query and return the results as a stream of FlightData. + async fn execute_sql( + &self, + sql: &str, + ) -> Result< + Pin> + Send>>, + Status, + > { + let ctx = self.create_session_context(); + + let df = ctx + .sql(sql) + .await + .map_err(|e| Status::internal(format!("SQL planning error: {}", e)))?; + + let schema: SchemaRef = df.schema().inner().clone(); + let stream = df + .execute_stream() + .await + .map_err(|e| Status::internal(format!("SQL execution error: {}", e)))?; + + // Use FlightDataEncoderBuilder to properly encode batches with schema + let batch_stream = stream.map(|r| r.map_err(|e| FlightError::ExternalError(Box::new(e)))); + let flight_data_stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(batch_stream) + .map(|result| result.map_err(|e| Status::internal(format!("Encoding error: {}", e)))); + + Ok(Box::pin(flight_data_stream)) + } + + /// Encode a single RecordBatch into a FlightData stream (schema + data). + fn batch_to_flight_stream( + batch: RecordBatch, + ) -> Pin> + Send>> { + let schema = batch.schema(); + let stream = FlightDataEncoderBuilder::new() + .with_schema(schema) + .build(stream::once(async move { Ok(batch) })) + .map(|result| result.map_err(|e| Status::internal(format!("Encoding error: {}", e)))); + Box::pin(stream) + } + + /// Get the schema for a SQL query without executing it. + async fn get_sql_schema(&self, sql: &str) -> Result { + let ctx = self.create_session_context(); + let df = ctx + .sql(sql) + .await + .map_err(|e| Status::internal(format!("SQL planning error: {}", e)))?; + Ok(df.schema().inner().as_ref().clone()) + } +} + +#[tonic::async_trait] +impl FlightSqlService for LanceFlightSqlService { + type FlightService = LanceFlightSqlService; + + /// Handle SQL query: return FlightInfo with schema and ticket. + async fn get_flight_info_statement( + &self, + query: CommandStatementQuery, + request: Request, + ) -> Result, Status> { + let sql = query.query; + log::info!("get_flight_info_statement: {}", sql); + + let schema = self.get_sql_schema(&sql).await?; + + // Encode the query as an Any-wrapped TicketStatementQuery for do_get_statement + let ticket = TicketStatementQuery { + statement_handle: sql.into_bytes().into(), + }; + let any_msg = Any::pack(&ticket) + .map_err(|e| Status::internal(format!("Ticket encoding error: {}", e)))?; + let mut ticket_bytes = Vec::new(); + any_msg + .encode(&mut ticket_bytes) + .map_err(|e| Status::internal(format!("Ticket encoding error: {}", e)))?; + + let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes)); + let flight_info = FlightInfo::new() + .try_with_schema(&schema) + .map_err(|e| Status::internal(format!("Schema error: {}", e)))? + .with_endpoint(endpoint) + .with_descriptor(request.into_inner()); + + Ok(Response::new(flight_info)) + } + + /// Execute a SQL query and stream results. + async fn do_get_statement( + &self, + ticket: TicketStatementQuery, + _request: Request, + ) -> Result< + Response<::DoGetStream>, + Status, + > { + let sql = String::from_utf8(ticket.statement_handle.to_vec()) + .map_err(|e| Status::internal(format!("Invalid ticket: {}", e)))?; + log::info!("do_get_statement: {}", sql); + + let stream = self.execute_sql(&sql).await?; + Ok(Response::new(stream)) + } + + /// List tables in the database. + async fn get_flight_info_tables( + &self, + _query: CommandGetTables, + request: Request, + ) -> Result, Status> { + let schema = ArrowSchema::new(vec![ + Field::new("catalog_name", DataType::Utf8, true), + Field::new("db_schema_name", DataType::Utf8, true), + Field::new("table_name", DataType::Utf8, false), + Field::new("table_type", DataType::Utf8, false), + ]); + + let cmd = CommandGetTables::default(); + let any_msg = + Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?; + let mut ticket_bytes = Vec::new(); + any_msg + .encode(&mut ticket_bytes) + .map_err(|e| Status::internal(format!("Encoding error: {}", e)))?; + + let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes)); + let flight_info = FlightInfo::new() + .try_with_schema(&schema) + .map_err(|e| Status::internal(format!("Schema error: {}", e)))? + .with_endpoint(endpoint) + .with_descriptor(request.into_inner()); + + Ok(Response::new(flight_info)) + } + + async fn do_get_tables( + &self, + _query: CommandGetTables, + _request: Request, + ) -> Result< + Response<::DoGetStream>, + Status, + > { + let table_names: Vec<&str> = self.tables.keys().map(|s| s.as_str()).collect(); + let num_tables = table_names.len(); + + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("catalog_name", DataType::Utf8, true), + Field::new("db_schema_name", DataType::Utf8, true), + Field::new("table_name", DataType::Utf8, false), + Field::new("table_type", DataType::Utf8, false), + ])); + + let catalog_names: ArrayRef = + Arc::new(StringArray::from(vec![Some("lancedb"); num_tables])); + let schema_names: ArrayRef = Arc::new(StringArray::from(vec![Some("default"); num_tables])); + let table_name_array: ArrayRef = Arc::new(StringArray::from(table_names)); + let table_types: ArrayRef = Arc::new(StringArray::from(vec!["TABLE"; num_tables])); + + let batch = RecordBatch::try_new( + schema, + vec![catalog_names, schema_names, table_name_array, table_types], + ) + .map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?; + + Ok(Response::new(Self::batch_to_flight_stream(batch))) + } + + /// List table types. + async fn get_flight_info_table_types( + &self, + _query: CommandGetTableTypes, + request: Request, + ) -> Result, Status> { + let schema = ArrowSchema::new(vec![Field::new("table_type", DataType::Utf8, false)]); + + let cmd = CommandGetTableTypes::default(); + let any_msg = + Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?; + let mut ticket_bytes = Vec::new(); + any_msg + .encode(&mut ticket_bytes) + .map_err(|e| Status::internal(format!("Encoding error: {}", e)))?; + + let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes)); + let flight_info = FlightInfo::new() + .try_with_schema(&schema) + .map_err(|e| Status::internal(format!("Schema error: {}", e)))? + .with_endpoint(endpoint) + .with_descriptor(request.into_inner()); + + Ok(Response::new(flight_info)) + } + + async fn do_get_table_types( + &self, + _query: CommandGetTableTypes, + _request: Request, + ) -> Result< + Response<::DoGetStream>, + Status, + > { + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "table_type", + DataType::Utf8, + false, + )])); + let table_types: ArrayRef = Arc::new(StringArray::from(vec!["TABLE"])); + let batch = RecordBatch::try_new(schema, vec![table_types]) + .map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?; + + Ok(Response::new(Self::batch_to_flight_stream(batch))) + } + + /// List catalogs. + async fn get_flight_info_catalogs( + &self, + _query: CommandGetCatalogs, + request: Request, + ) -> Result, Status> { + let schema = ArrowSchema::new(vec![Field::new("catalog_name", DataType::Utf8, false)]); + + let cmd = CommandGetCatalogs::default(); + let any_msg = + Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?; + let mut ticket_bytes = Vec::new(); + any_msg + .encode(&mut ticket_bytes) + .map_err(|e| Status::internal(format!("Encoding error: {}", e)))?; + + let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes)); + let flight_info = FlightInfo::new() + .try_with_schema(&schema) + .map_err(|e| Status::internal(format!("Schema error: {}", e)))? + .with_endpoint(endpoint) + .with_descriptor(request.into_inner()); + + Ok(Response::new(flight_info)) + } + + async fn do_get_catalogs( + &self, + _query: CommandGetCatalogs, + _request: Request, + ) -> Result< + Response<::DoGetStream>, + Status, + > { + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "catalog_name", + DataType::Utf8, + false, + )])); + let catalogs: ArrayRef = Arc::new(StringArray::from(vec!["lancedb"])); + let batch = RecordBatch::try_new(schema, vec![catalogs]) + .map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?; + + Ok(Response::new(Self::batch_to_flight_stream(batch))) + } + + /// List schemas. + async fn get_flight_info_schemas( + &self, + _query: CommandGetDbSchemas, + request: Request, + ) -> Result, Status> { + let schema = ArrowSchema::new(vec![ + Field::new("catalog_name", DataType::Utf8, true), + Field::new("db_schema_name", DataType::Utf8, false), + ]); + + let cmd = CommandGetDbSchemas::default(); + let any_msg = + Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?; + let mut ticket_bytes = Vec::new(); + any_msg + .encode(&mut ticket_bytes) + .map_err(|e| Status::internal(format!("Encoding error: {}", e)))?; + + let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes)); + let flight_info = FlightInfo::new() + .try_with_schema(&schema) + .map_err(|e| Status::internal(format!("Schema error: {}", e)))? + .with_endpoint(endpoint) + .with_descriptor(request.into_inner()); + + Ok(Response::new(flight_info)) + } + + async fn do_get_schemas( + &self, + _query: CommandGetDbSchemas, + _request: Request, + ) -> Result< + Response<::DoGetStream>, + Status, + > { + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("catalog_name", DataType::Utf8, true), + Field::new("db_schema_name", DataType::Utf8, false), + ])); + let catalogs: ArrayRef = Arc::new(StringArray::from(vec![Some("lancedb")])); + let schemas: ArrayRef = Arc::new(StringArray::from(vec!["default"])); + let batch = RecordBatch::try_new(schema, vec![catalogs, schemas]) + .map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?; + + Ok(Response::new(Self::batch_to_flight_stream(batch))) + } + + async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{Float32Array, Int32Array}; + use arrow_flight::sql::client::FlightSqlServiceClient; + use futures::TryStreamExt; + use lance_arrow::FixedSizeListArrayExt; + use std::time::Duration; + + async fn create_test_db() -> Connection { + let db = crate::connect("memory://flight_test") + .execute() + .await + .unwrap(); + + let dim = 4i32; + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("text", DataType::Utf8, false), + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim), + true, + ), + ])); + + let ids = Int32Array::from(vec![1, 2, 3]); + let texts = StringArray::from(vec!["hello world", "foo bar", "baz qux"]); + let flat_values = Float32Array::from(vec![ + 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, + ]); + let vectors = + arrow_array::FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap(); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(ids), Arc::new(texts), Arc::new(vectors)], + ) + .unwrap(); + + let table = db + .create_table("test_table", batch) + .execute() + .await + .unwrap(); + + // Create FTS index + table + .create_index(&["text"], crate::index::Index::FTS(Default::default())) + .execute() + .await + .unwrap(); + + db + } + + #[tokio::test] + async fn test_flight_sql_basic_query() { + let db = create_test_db().await; + let service = LanceFlightSqlService::try_new(db).await.unwrap(); + let flight_svc = FlightServiceServer::new(service); + + // Start server on a random port + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let server_handle = tokio::spawn(async move { + Server::builder() + .add_service(flight_svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + // Give server a moment to start + tokio::time::sleep(Duration::from_millis(100)).await; + + // Connect client + let channel = tonic::transport::Channel::from_shared(format!("http://{}", addr)) + .unwrap() + .connect() + .await + .unwrap(); + let mut client = FlightSqlServiceClient::new(channel); + + // Execute SQL query + let flight_info = client + .execute("SELECT id, text FROM test_table".to_string(), None) + .await + .unwrap(); + + // Fetch results using the FlightSql client's do_get + let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone(); + let flight_stream = client.do_get(ticket).await.unwrap(); + + let batches: Vec = flight_stream.try_collect().await.unwrap(); + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3, "Should return all 3 rows"); + + // Verify schema + if let Some(batch) = batches.first() { + assert!(batch.schema().column_with_name("id").is_some()); + assert!(batch.schema().column_with_name("text").is_some()); + } + + server_handle.abort(); + } + + #[tokio::test] + async fn test_flight_sql_table_listing() { + let db = create_test_db().await; + let service = LanceFlightSqlService::try_new(db).await.unwrap(); + + assert!(service.tables.contains_key("test_table")); + assert_eq!(service.tables.len(), 1); + } + + #[tokio::test] + async fn test_flight_sql_session_context() { + let db = create_test_db().await; + let service = LanceFlightSqlService::try_new(db).await.unwrap(); + let ctx = service.create_session_context(); + + // Test that we can execute a simple query + let df = ctx.sql("SELECT * FROM test_table LIMIT 1").await.unwrap(); + let results = df.collect().await.unwrap(); + assert_eq!(results[0].num_rows(), 1); + } +} diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 1ae7c48e2..25211b307 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -170,6 +170,8 @@ pub mod dataloader; pub mod embeddings; pub mod error; pub mod expr; +#[cfg(feature = "flight")] +pub mod flight; pub mod index; pub mod io; pub mod ipc; diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs index bd93dd05d..37b720174 100644 --- a/rust/lancedb/src/table/datafusion.rs +++ b/rust/lancedb/src/table/datafusion.rs @@ -26,9 +26,10 @@ use lance::dataset::{WriteMode, WriteParams}; use super::{AnyQuery, BaseTable}; use crate::{ - Result, - query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select}, + DistanceType, Result, + query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest}, }; +use arrow_array::Array; use arrow_schema::{DataType, Field}; use lance_index::scalar::FullTextSearchQuery; @@ -141,11 +142,31 @@ impl ExecutionPlan for MetadataEraserExec { } } +/// Parameters for a vector search query, used by vector_search and hybrid_search UDTFs. +#[derive(Debug, Clone)] +pub struct VectorSearchParams { + /// The query vector to search for + pub query_vector: Arc, + /// The column to search on (None for auto-detection) + pub column: Option, + /// Number of results to return + pub top_k: usize, + /// Distance metric to use + pub distance_type: Option, + /// Number of IVF partitions to search + pub nprobes: Option, + /// HNSW search parameter + pub ef: Option, + /// Refine factor for improving recall + pub refine_factor: Option, +} + #[derive(Debug)] pub struct BaseTableAdapter { table: Arc, schema: Arc, fts_query: Option, + vector_query: Option, } impl BaseTableAdapter { @@ -161,6 +182,7 @@ impl BaseTableAdapter { table, schema: Arc::new(schema), fts_query: None, + vector_query: None, }) } @@ -176,6 +198,49 @@ impl BaseTableAdapter { table: self.table.clone(), schema, fts_query: Some(fts_query), + vector_query: self.vector_query.clone(), + } + } + + /// Create a new adapter with a vector search query applied. + pub fn with_vector_query(&self, vector_query: VectorSearchParams) -> Self { + // Add _distance column to the schema + let distance_field = Field::new("_distance", DataType::Float32, true); + let mut fields = self.schema.fields().to_vec(); + fields.push(Arc::new(distance_field)); + let schema = Arc::new(ArrowSchema::new(fields)); + + Self { + table: self.table.clone(), + schema, + fts_query: self.fts_query.clone(), + vector_query: Some(vector_query), + } + } + + /// Create a new adapter with both FTS and vector search queries (hybrid search). + /// + /// Uses vector search as the primary retrieval method, with FTS applied as a + /// pre-filter to restrict the candidate set. Both `_distance` and `_score` + /// columns are added to results. + pub fn with_hybrid_query( + &self, + fts_query: FullTextSearchQuery, + vector_query: VectorSearchParams, + ) -> Self { + // Add _distance column (vector search is primary) + let mut fields = self.schema.fields().to_vec(); + fields.push(Arc::new(Field::new("_distance", DataType::Float32, true))); + let schema = Arc::new(ArrowSchema::new(fields)); + + // Store FTS as a filter concept, but vector search drives the query. + // The FTS query is applied via the base QueryRequest's full_text_search + // field, which acts as a pre-filter for vector search. + Self { + table: self.table.clone(), + schema, + fts_query: Some(fts_query), + vector_query: Some(vector_query), } } } @@ -201,11 +266,20 @@ impl TableProvider for BaseTableAdapter { filters: &[Expr], limit: Option, ) -> DataFusionResult> { - // For FTS queries, disable auto-projection of _score to match DataFusion expectations - let disable_scoring = self.fts_query.is_some() && projection.is_some(); + let has_scoring = self.fts_query.is_some() || self.vector_query.is_some(); + let disable_scoring = has_scoring && projection.is_some(); - let mut query = QueryRequest { - full_text_search: self.fts_query.clone(), + // When doing vector search, FTS cannot be combined in the same scanner + // (Lance doesn't support both nearest + full_text_search simultaneously). + // FTS is only set when there's no vector query. + let fts_for_query = if self.vector_query.is_some() { + None + } else { + self.fts_query.clone() + }; + + let mut base_query = QueryRequest { + full_text_search: fts_for_query, disable_scoring_autoprojection: disable_scoring, ..Default::default() }; @@ -215,20 +289,20 @@ impl TableProvider for BaseTableAdapter { .iter() .map(|i| self.schema.field(*i).name().clone()) .collect(); - query.select = Select::Columns(field_names); + base_query.select = Select::Columns(field_names); } if !filters.is_empty() { let first = filters.first().unwrap().clone(); let filter = filters[1..] .iter() .fold(first, |acc, expr| acc.and(expr.clone())); - query.filter = Some(QueryFilter::Datafusion(filter)); + base_query.filter = Some(QueryFilter::Datafusion(filter)); } if let Some(limit) = limit { - query.limit = Some(limit); - } else { - // Need to override the default of 10 - query.limit = None; + base_query.limit = Some(limit); + } else if self.vector_query.is_none() { + // Need to override the default of 10 for non-vector queries + base_query.limit = None; } let options = QueryExecutionOptions { @@ -236,9 +310,33 @@ impl TableProvider for BaseTableAdapter { ..Default::default() }; + // Build the appropriate query type + let any_query = if let Some(ref vq) = self.vector_query { + let vector_query = VectorQueryRequest { + base: base_query, + column: vq.column.clone(), + query_vector: vec![vq.query_vector.clone()], + minimum_nprobes: vq.nprobes.unwrap_or(20), + maximum_nprobes: vq.nprobes, + ef: vq.ef, + refine_factor: vq.refine_factor, + distance_type: vq.distance_type, + use_index: true, + ..Default::default() + }; + // For vector queries, use top_k as the limit if no explicit limit set + let mut vq_req = vector_query; + if limit.is_none() { + vq_req.base.limit = Some(vq.top_k); + } + AnyQuery::VectorQuery(vq_req) + } else { + AnyQuery::Query(base_query) + }; + let plan = self .table - .create_plan(&AnyQuery::Query(query), options) + .create_plan(&any_query, options) .map_err(|err| DataFusionError::External(err.into())) .await?; Ok(Arc::new(MetadataEraserExec::new(plan))) diff --git a/rust/lancedb/src/table/datafusion/udtf.rs b/rust/lancedb/src/table/datafusion/udtf.rs index 11f0cc61b..477182e65 100644 --- a/rust/lancedb/src/table/datafusion/udtf.rs +++ b/rust/lancedb/src/table/datafusion/udtf.rs @@ -2,5 +2,75 @@ // SPDX-FileCopyrightText: Copyright The LanceDB Authors //! User-Defined Table Functions (UDTFs) for DataFusion integration +//! +//! This module provides SQL table functions for LanceDB search capabilities: +//! - `fts(table_name, query_json)` — full-text search +//! - `vector_search(table_name, query_vector_json, top_k)` — vector similarity search +//! - `hybrid_search(table_name, query_vector_json, fts_query_json, top_k)` — combined search pub mod fts; +pub mod hybrid_search; +pub mod vector_search; + +use std::sync::Arc; + +use datafusion_catalog::TableProvider; +use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue}; +use datafusion_expr::Expr; +use lance_index::scalar::FullTextSearchQuery; + +use super::VectorSearchParams; + +/// Describes the type of search to apply when resolving a table. +#[derive(Debug, Clone)] +pub enum SearchQuery { + /// Full-text search only + Fts(FullTextSearchQuery), + /// Vector similarity search only + Vector(VectorSearchParams), + /// Hybrid search combining FTS and vector search + Hybrid { + fts: FullTextSearchQuery, + vector: VectorSearchParams, + }, +} + +/// Trait for resolving table names to TableProvider instances, optionally with a search query. +pub trait TableResolver: std::fmt::Debug + Send + Sync { + /// Resolve a table name to a TableProvider, optionally applying a search query. + fn resolve_table( + &self, + name: &str, + search: Option, + ) -> DataFusionResult>; +} + +/// Extract a string literal from a DataFusion expression. +pub(crate) fn extract_string_literal(expr: &Expr, param_name: &str) -> DataFusionResult { + match expr { + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Ok(s.clone()), + Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => Ok(s.clone()), + _ => Err(DataFusionError::Plan(format!( + "Parameter '{}' must be a string literal, got: {:?}", + param_name, expr + ))), + } +} + +/// Extract an integer literal from a DataFusion expression. +pub(crate) fn extract_int_literal(expr: &Expr, param_name: &str) -> DataFusionResult { + match expr { + Expr::Literal(ScalarValue::Int8(Some(v)), _) => Ok(*v as usize), + Expr::Literal(ScalarValue::Int16(Some(v)), _) => Ok(*v as usize), + Expr::Literal(ScalarValue::Int32(Some(v)), _) => Ok(*v as usize), + Expr::Literal(ScalarValue::Int64(Some(v)), _) => Ok(*v as usize), + Expr::Literal(ScalarValue::UInt8(Some(v)), _) => Ok(*v as usize), + Expr::Literal(ScalarValue::UInt16(Some(v)), _) => Ok(*v as usize), + Expr::Literal(ScalarValue::UInt32(Some(v)), _) => Ok(*v as usize), + Expr::Literal(ScalarValue::UInt64(Some(v)), _) => Ok(*v as usize), + _ => Err(DataFusionError::Plan(format!( + "Parameter '{}' must be an integer literal, got: {:?}", + param_name, expr + ))), + } +} diff --git a/rust/lancedb/src/table/datafusion/udtf/fts.rs b/rust/lancedb/src/table/datafusion/udtf/fts.rs index 5b50ddfa3..a5cebcd0e 100644 --- a/rust/lancedb/src/table/datafusion/udtf/fts.rs +++ b/rust/lancedb/src/table/datafusion/udtf/fts.rs @@ -1,29 +1,23 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors -//! User-Defined Table Functions (UDTFs) for LanceDB +//! Full-Text Search (FTS) table function for DataFusion SQL integration. //! -//! This module provides table-level UDTFs that integrate with DataFusion's SQL engine. +//! Usage: `SELECT * FROM fts('table_name', '{"match": {"column": "text", "terms": "query"}}')` use std::sync::Arc; use datafusion::catalog::TableFunctionImpl; use datafusion_catalog::TableProvider; -use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue, plan_err}; +use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err}; use datafusion_expr::Expr; use lance_index::scalar::FullTextSearchQuery; -/// Trait for resolving table names to TableProvider instances. -pub trait TableResolver: std::fmt::Debug + Send + Sync { - /// Resolve a table name to a TableProvider, optionally with an FTS query applied. - fn resolve_table( - &self, - name: &str, - fts_query: Option, - ) -> DataFusionResult>; -} +use super::{SearchQuery, TableResolver, extract_string_literal}; -/// Full-Text Search table function that operates on LanceDB tables +/// Full-Text Search table function that operates on LanceDB tables. +/// +/// Accepts 2 parameters: `fts(table_name, fts_query_json)` #[derive(Debug)] pub struct FtsTableFunction { resolver: Arc, @@ -45,20 +39,8 @@ impl TableFunctionImpl for FtsTableFunction { let query_json = extract_string_literal(&exprs[1], "fts_query")?; let fts_query = parse_fts_query(&query_json)?; - // Resolver returns a ready-to-use TableProvider with FTS applied - self.resolver.resolve_table(&table_name, Some(fts_query)) - } -} - -fn extract_string_literal(expr: &Expr, param_name: &str) -> DataFusionResult { - match expr { - Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Ok(s.clone()), - Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => Ok(s.clone()), - _ => plan_err!( - "Parameter '{}' must be a string literal, got: {:?}", - param_name, - expr - ), + self.resolver + .resolve_table(&table_name, Some(SearchQuery::Fts(fts_query))) } } @@ -91,6 +73,7 @@ pub fn from_json(json: &str) -> crate::Result, + search: Option, ) -> DataFusionResult> { let table_provider = self .tables @@ -131,12 +115,10 @@ mod tests { .cloned() .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?; - // If no FTS query, return as-is - let Some(fts_query) = fts_query else { + let Some(search) = search else { return Ok(table_provider); }; - // Downcast to BaseTableAdapter and apply FTS query let base_adapter = table_provider .as_any() .downcast_ref::() @@ -146,7 +128,15 @@ mod tests { ) })?; - Ok(Arc::new(base_adapter.with_fts_query(fts_query))) + match search { + SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))), + SearchQuery::Vector(vector_query) => { + Ok(Arc::new(base_adapter.with_vector_query(vector_query))) + } + SearchQuery::Hybrid { fts, vector } => { + Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector))) + } + } } } diff --git a/rust/lancedb/src/table/datafusion/udtf/hybrid_search.rs b/rust/lancedb/src/table/datafusion/udtf/hybrid_search.rs new file mode 100644 index 000000000..6421adf5c --- /dev/null +++ b/rust/lancedb/src/table/datafusion/udtf/hybrid_search.rs @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! Hybrid search table function for DataFusion SQL integration. +//! +//! Combines vector similarity search with full-text search: +//! ```sql +//! SELECT * FROM hybrid_search( +//! 'my_table', +//! '[0.1, 0.2, 0.3]', +//! '{"match": {"column": "text", "terms": "search query"}}', +//! 10 +//! ) +//! ``` + +use std::sync::Arc; + +use arrow_array::Array; +use datafusion::catalog::TableFunctionImpl; +use datafusion_catalog::TableProvider; +use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err}; +use datafusion_expr::Expr; +use lance_index::scalar::FullTextSearchQuery; + +use super::fts::from_json as fts_from_json; +use super::{SearchQuery, TableResolver, extract_int_literal, extract_string_literal}; +use crate::table::datafusion::VectorSearchParams; + +/// Default number of results for hybrid search when top_k is not specified. +const DEFAULT_TOP_K: usize = 10; + +/// Hybrid search table function combining vector and full-text search. +/// +/// Accepts 3-4 parameters: `hybrid_search(table_name, query_vector_json, fts_query_json [, top_k])` +/// +/// - `table_name`: Name of the table to search +/// - `query_vector_json`: JSON array of float values, e.g. `'[0.1, 0.2, 0.3]'` +/// - `fts_query_json`: FTS query as JSON, e.g. `'{"match": {"column": "text", "terms": "query"}}'` +/// - `top_k` (optional): Number of results to return (default: 10) +#[derive(Debug)] +pub struct HybridSearchTableFunction { + resolver: Arc, +} + +impl HybridSearchTableFunction { + pub fn new(resolver: Arc) -> Self { + Self { resolver } + } +} + +impl TableFunctionImpl for HybridSearchTableFunction { + fn call(&self, exprs: &[Expr]) -> DataFusionResult> { + if exprs.len() < 3 || exprs.len() > 4 { + return plan_err!( + "hybrid_search() requires 3-4 parameters: hybrid_search(table_name, query_vector_json, fts_query_json [, top_k])" + ); + } + + let table_name = extract_string_literal(&exprs[0], "table_name")?; + let vector_json = extract_string_literal(&exprs[1], "query_vector_json")?; + let fts_json = extract_string_literal(&exprs[2], "fts_query_json")?; + + let top_k = if exprs.len() == 4 { + extract_int_literal(&exprs[3], "top_k")? + } else { + DEFAULT_TOP_K + }; + + let query_vector = parse_vector_json(&vector_json)?; + let fts_query = parse_fts_query(&fts_json)?; + + let vector_params = VectorSearchParams { + query_vector: query_vector as Arc, + column: None, + top_k, + distance_type: None, + nprobes: None, + ef: None, + refine_factor: None, + }; + + self.resolver.resolve_table( + &table_name, + Some(SearchQuery::Hybrid { + fts: fts_query, + vector: vector_params, + }), + ) + } +} + +fn parse_vector_json(json: &str) -> DataFusionResult> { + super::vector_search::parse_vector_json(json) +} + +fn parse_fts_query(json: &str) -> DataFusionResult { + let query = fts_from_json(json).map_err(|e| { + DataFusionError::Plan(format!( + "Invalid FTS query JSON: {}. Expected format: {{\"match\": {{\"column\": \"text\", \"terms\": \"query\"}} }}", + e + )) + })?; + Ok(FullTextSearchQuery::new_query(query)) +} + +#[cfg(test)] +mod tests { + use super::super::fts::to_json; + use super::super::{SearchQuery, TableResolver}; + use super::*; + use crate::{index::Index, table::datafusion::BaseTableAdapter}; + use arrow_array::FixedSizeListArray; + use arrow_array::{Float32Array, Int32Array, RecordBatch, StringArray}; + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + use datafusion::prelude::SessionContext; + #[allow(unused_imports)] + use lance_arrow::FixedSizeListArrayExt; + + #[derive(Debug)] + struct HashMapTableResolver { + tables: std::collections::HashMap>, + } + + impl HashMapTableResolver { + fn new() -> Self { + Self { + tables: std::collections::HashMap::new(), + } + } + + fn register(&mut self, name: String, table: Arc) { + self.tables.insert(name, table); + } + } + + impl TableResolver for HashMapTableResolver { + fn resolve_table( + &self, + name: &str, + search: Option, + ) -> DataFusionResult> { + let table_provider = self + .tables + .get(name) + .cloned() + .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?; + + let Some(search) = search else { + return Ok(table_provider); + }; + + let base_adapter = table_provider + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "Expected BaseTableAdapter but got different type".to_string(), + ) + })?; + + match search { + SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))), + SearchQuery::Vector(vector_query) => { + Ok(Arc::new(base_adapter.with_vector_query(vector_query))) + } + SearchQuery::Hybrid { fts, vector } => { + Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector))) + } + } + } + } + + #[tokio::test] + async fn test_hybrid_search_udtf() { + let dim = 4i32; + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("text", DataType::Utf8, false), + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim), + true, + ), + ])); + + let ids = Int32Array::from(vec![1, 2, 3, 4, 5]); + let texts = StringArray::from(vec![ + "the quick brown fox", + "jumps over the lazy dog", + "a quick red fox runs", + "the dog sleeps all day", + "a brown fox and a quick dog", + ]); + let flat_values = Float32Array::from(vec![ + 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.5, + 0.5, 0.0, 0.0, + ]); + let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap(); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(ids), Arc::new(texts), Arc::new(vector_array)], + ) + .unwrap(); + + let db = crate::connect("memory://test_hybrid") + .execute() + .await + .unwrap(); + let table = db.create_table("docs", batch).execute().await.unwrap(); + + // Create FTS index on text column + table + .create_index(&["text"], Index::FTS(Default::default())) + .execute() + .await + .unwrap(); + + let ctx = SessionContext::new(); + let mut resolver = HashMapTableResolver::new(); + let adapter = BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + resolver.register("docs".to_string(), Arc::new(adapter)); + + let resolver = Arc::new(resolver); + ctx.register_udtf( + "hybrid_search", + Arc::new(HybridSearchTableFunction::new(resolver.clone())), + ); + + // Run hybrid search: vector close to [1,0,0,0] AND FTS for "fox" + use lance_index::scalar::inverted::query::*; + let fts_query_struct = FtsQuery::Match( + MatchQuery::new("fox".to_string()).with_column(Some("text".to_string())), + ); + let fts_json = to_json(&fts_query_struct).unwrap(); + + let query = format!( + "SELECT * FROM hybrid_search('docs', '[1.0, 0.0, 0.0, 0.0]', '{}', 5)", + fts_json + ); + + let df = ctx.sql(&query).await.unwrap(); + let results = df.collect().await.unwrap(); + + assert!(!results.is_empty()); + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert!(total_rows > 0, "Should have at least one result"); + + // Check schema has the expected columns + let result_schema = results[0].schema(); + assert!(result_schema.column_with_name("id").is_some()); + assert!(result_schema.column_with_name("text").is_some()); + assert!(result_schema.column_with_name("vector").is_some()); + } +} diff --git a/rust/lancedb/src/table/datafusion/udtf/vector_search.rs b/rust/lancedb/src/table/datafusion/udtf/vector_search.rs new file mode 100644 index 000000000..7e8c84c72 --- /dev/null +++ b/rust/lancedb/src/table/datafusion/udtf/vector_search.rs @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +//! Vector search table function for DataFusion SQL integration. +//! +//! Enables vector similarity search via SQL: +//! ```sql +//! SELECT * FROM vector_search('my_table', '[0.1, 0.2, 0.3, ...]', 10) +//! ``` + +use std::sync::Arc; + +use arrow_array::{Array, Float32Array}; +use datafusion::catalog::TableFunctionImpl; +use datafusion_catalog::TableProvider; +use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err}; +use datafusion_expr::Expr; + +use super::{SearchQuery, TableResolver, extract_int_literal, extract_string_literal}; +use crate::table::datafusion::VectorSearchParams; + +/// Default number of results for vector search when top_k is not specified. +const DEFAULT_TOP_K: usize = 10; + +/// Vector search table function for LanceDB tables. +/// +/// Accepts 2-3 parameters: `vector_search(table_name, query_vector_json [, top_k])` +/// +/// - `table_name`: Name of the table to search +/// - `query_vector_json`: JSON array of float values, e.g. `'[0.1, 0.2, 0.3]'` +/// - `top_k` (optional): Number of results to return (default: 10) +#[derive(Debug)] +pub struct VectorSearchTableFunction { + resolver: Arc, +} + +impl VectorSearchTableFunction { + pub fn new(resolver: Arc) -> Self { + Self { resolver } + } +} + +impl TableFunctionImpl for VectorSearchTableFunction { + fn call(&self, exprs: &[Expr]) -> DataFusionResult> { + if exprs.len() < 2 || exprs.len() > 3 { + return plan_err!( + "vector_search() requires 2-3 parameters: vector_search(table_name, query_vector_json [, top_k])" + ); + } + + let table_name = extract_string_literal(&exprs[0], "table_name")?; + let vector_json = extract_string_literal(&exprs[1], "query_vector_json")?; + + let top_k = if exprs.len() == 3 { + extract_int_literal(&exprs[2], "top_k")? + } else { + DEFAULT_TOP_K + }; + + let query_vector = parse_vector_json(&vector_json)?; + + let params = VectorSearchParams { + query_vector: query_vector as Arc, + column: None, + top_k, + distance_type: None, + nprobes: None, + ef: None, + refine_factor: None, + }; + + self.resolver + .resolve_table(&table_name, Some(SearchQuery::Vector(params))) + } +} + +/// Parse a JSON array of floats into an Arrow Float32Array for vector search. +/// +/// Input format: `"[0.1, 0.2, 0.3, ...]"` +/// +/// Returns a Float32Array whose length equals the vector dimension. +/// This is the format expected by LanceDB's vector search internals. +pub(crate) fn parse_vector_json(json: &str) -> DataFusionResult> { + let values: Vec = serde_json::from_str(json).map_err(|e| { + DataFusionError::Plan(format!( + "Invalid vector JSON: {}. Expected format: [0.1, 0.2, 0.3, ...]", + e + )) + })?; + + if values.is_empty() { + return Err(DataFusionError::Plan( + "Vector must not be empty".to_string(), + )); + } + + Ok(Arc::new(Float32Array::from(values))) +} + +#[cfg(test)] +mod tests { + use super::super::{SearchQuery, TableResolver}; + use super::*; + use crate::table::datafusion::BaseTableAdapter; + use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema as ArrowSchema}; + use datafusion::prelude::SessionContext; + #[allow(unused_imports)] + use lance_arrow::FixedSizeListArrayExt; + + /// Resolver that looks up tables in a HashMap and applies search queries + #[derive(Debug)] + struct HashMapTableResolver { + tables: std::collections::HashMap>, + } + + impl HashMapTableResolver { + fn new() -> Self { + Self { + tables: std::collections::HashMap::new(), + } + } + + fn register(&mut self, name: String, table: Arc) { + self.tables.insert(name, table); + } + } + + impl TableResolver for HashMapTableResolver { + fn resolve_table( + &self, + name: &str, + search: Option, + ) -> DataFusionResult> { + let table_provider = self + .tables + .get(name) + .cloned() + .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?; + + let Some(search) = search else { + return Ok(table_provider); + }; + + let base_adapter = table_provider + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "Expected BaseTableAdapter but got different type".to_string(), + ) + })?; + + match search { + SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))), + SearchQuery::Vector(vector_query) => { + Ok(Arc::new(base_adapter.with_vector_query(vector_query))) + } + SearchQuery::Hybrid { fts, vector } => { + Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector))) + } + } + } + } + + fn make_test_data() -> (Arc, RecordBatch) { + let dim = 4; + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim), + true, + ), + ])); + + let ids = Int32Array::from(vec![1, 2, 3, 4, 5]); + // Create vectors: [1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1], [1,1,0,0] + let flat_values = Float32Array::from(vec![ + 1.0, 0.0, 0.0, 0.0, // vec 1 + 0.0, 1.0, 0.0, 0.0, // vec 2 + 0.0, 0.0, 1.0, 0.0, // vec 3 + 0.0, 0.0, 0.0, 1.0, // vec 4 + 1.0, 1.0, 0.0, 0.0, // vec 5 + ]); + let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap(); + + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vector_array)]) + .unwrap(); + + (schema, batch) + } + + #[tokio::test] + async fn test_vector_search_udtf() { + let (_schema, batch) = make_test_data(); + + let db = crate::connect("memory://test_vec").execute().await.unwrap(); + let table = db.create_table("vectors", batch).execute().await.unwrap(); + + // No index needed — vector search works with brute-force scan on small tables + + // Setup DataFusion context + let ctx = SessionContext::new(); + let mut resolver = HashMapTableResolver::new(); + let adapter = BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + resolver.register("vectors".to_string(), Arc::new(adapter)); + + let udtf = VectorSearchTableFunction::new(Arc::new(resolver)); + ctx.register_udtf("vector_search", Arc::new(udtf)); + + // Search for vectors close to [1, 0, 0, 0] + let query = "SELECT * FROM vector_search('vectors', '[1.0, 0.0, 0.0, 0.0]', 3)"; + let df = ctx.sql(query).await.unwrap(); + let results = df.collect().await.unwrap(); + + assert!(!results.is_empty()); + let batch = &results[0]; + + // Should have id, vector, _distance columns + assert!(batch.schema().column_with_name("id").is_some()); + assert!(batch.schema().column_with_name("vector").is_some()); + assert!( + batch.schema().column_with_name("_distance").is_some(), + "_distance column should be present" + ); + + // Should return at most 3 results + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + assert!(total_rows <= 3); + assert!(total_rows > 0); + } + + #[tokio::test] + async fn test_vector_search_default_top_k() { + let (_, batch) = make_test_data(); + + let db = crate::connect("memory://test_vec_default") + .execute() + .await + .unwrap(); + let table = db.create_table("vectors", batch).execute().await.unwrap(); + + let ctx = SessionContext::new(); + let mut resolver = HashMapTableResolver::new(); + let adapter = BaseTableAdapter::try_new(table.base_table().clone()) + .await + .unwrap(); + resolver.register("vectors".to_string(), Arc::new(adapter)); + + let udtf = VectorSearchTableFunction::new(Arc::new(resolver)); + ctx.register_udtf("vector_search", Arc::new(udtf)); + + // No top_k parameter — should default to 10 + let query = "SELECT * FROM vector_search('vectors', '[1.0, 0.0, 0.0, 0.0]')"; + let df = ctx.sql(query).await.unwrap(); + let results = df.collect().await.unwrap(); + + let total_rows: usize = results.iter().map(|b| b.num_rows()).sum(); + // We only have 5 rows, so we should get all 5 back + assert_eq!(total_rows, 5); + } + + #[test] + fn test_parse_vector_json() { + let result = parse_vector_json("[1.0, 2.0, 3.0]").unwrap(); + assert_eq!(result.len(), 3); // 3-dimensional vector + + // Empty vector should fail + assert!(parse_vector_json("[]").is_err()); + + // Invalid JSON should fail + assert!(parse_vector_json("not json").is_err()); + } +}