Compare commits

...

1 Commits

Author SHA1 Message Date
Justin Miller
c8d654db37 feat: add ADBC/Flight SQL interface with search table functions (OSS-738)
Add Arrow Flight SQL server and SQL table functions for vector/hybrid search,
enabling standard SQL workbench tools to connect to LanceDB via ADBC.

Phase 1 - Search Table Functions:
- vector_search(table, vector_json, top_k) UDTF for vector similarity search
- hybrid_search(table, vector_json, fts_json, top_k) UDTF combining both
- Generalized TableResolver with SearchQuery enum (Fts/Vector/Hybrid)
- Extended BaseTableAdapter with VectorSearchParams support

Phase 2 - Arrow Flight SQL Server:
- FlightSqlService implementation with SQL query execution
- Catalog introspection (tables, schemas, catalogs, table types)
- All three UDTFs auto-registered in session context
- Feature-gated behind "flight" flag (arrow-flight + tonic + prost)

Phase 3 - Example & Tests:
- flight_sql example: creates sample DB, starts Flight SQL server
- Tests for vector_search and hybrid_search UDTFs
- Flight SQL integration tests (basic query, table listing, session)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 01:48:35 -07:00
11 changed files with 1608 additions and 59 deletions

165
Cargo.lock generated
View File

@@ -290,6 +290,34 @@ dependencies = [
"num-traits",
]
[[package]]
name = "arrow-flight"
version = "57.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58c5b083668e6230eae3eab2fc4b5fb989974c845d0aa538dde61a4327c78675"
dependencies = [
"arrow-arith",
"arrow-array",
"arrow-buffer",
"arrow-cast",
"arrow-data",
"arrow-ipc",
"arrow-ord",
"arrow-row",
"arrow-schema",
"arrow-select",
"arrow-string",
"base64 0.22.1",
"bytes",
"futures",
"once_cell",
"paste",
"prost",
"prost-types",
"tonic",
"tonic-prost",
]
[[package]]
name = "arrow-ipc"
version = "57.3.0"
@@ -1069,7 +1097,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [
"async-trait",
"axum-core",
"axum-core 0.4.5",
"bytes",
"futures-util",
"http 1.4.0",
@@ -1078,7 +1106,7 @@ dependencies = [
"hyper 1.8.1",
"hyper-util",
"itoa",
"matchit",
"matchit 0.7.3",
"memchr",
"mime",
"percent-encoding",
@@ -1096,6 +1124,31 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
dependencies = [
"axum-core 0.5.6",
"bytes",
"futures-util",
"http 1.4.0",
"http-body 1.0.1",
"http-body-util",
"itoa",
"matchit 0.8.4",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"serde_core",
"sync_wrapper",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.4.5"
@@ -1117,6 +1170,24 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-core"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1"
dependencies = [
"bytes",
"futures-core",
"http 1.4.0",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"sync_wrapper",
"tower-layer",
"tower-service",
]
[[package]]
name = "backoff"
version = "0.4.0"
@@ -1294,7 +1365,7 @@ version = "3.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c"
dependencies = [
"darling 0.20.11",
"darling 0.23.0",
"ident_case",
"prettyplease",
"proc-macro2",
@@ -2682,7 +2753,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -2876,7 +2947,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -3719,6 +3790,19 @@ dependencies = [
"webpki-roots 1.0.6",
]
[[package]]
name = "hyper-timeout"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0"
dependencies = [
"hyper 1.8.1",
"hyper-util",
"pin-project-lite",
"tokio",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.20"
@@ -4048,7 +4132,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"serde_core",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -4539,7 +4623,7 @@ dependencies = [
"arrow-ipc",
"arrow-schema",
"async-trait",
"axum",
"axum 0.7.9",
"bytes",
"chrono",
"futures",
@@ -4638,6 +4722,7 @@ dependencies = [
"arrow-array",
"arrow-cast",
"arrow-data",
"arrow-flight",
"arrow-ipc",
"arrow-ord",
"arrow-schema",
@@ -4691,6 +4776,7 @@ dependencies = [
"pin-project",
"polars",
"polars-arrow",
"prost",
"rand 0.9.2",
"random_word 0.4.3",
"regex",
@@ -4705,6 +4791,8 @@ dependencies = [
"test-log",
"tokenizers",
"tokio",
"tokio-stream",
"tonic",
"url",
"uuid",
"walkdir",
@@ -5009,6 +5097,12 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "matrixmultiply"
version = "0.3.10"
@@ -5322,7 +5416,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -6302,7 +6396,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
dependencies = [
"heck 0.5.0",
"itertools 0.11.0",
"itertools 0.14.0",
"log",
"multimap",
"petgraph",
@@ -6321,7 +6415,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
dependencies = [
"anyhow",
"itertools 0.11.0",
"itertools 0.14.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -6527,7 +6621,7 @@ dependencies = [
"once_cell",
"socket2 0.6.3",
"tracing",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
@@ -7069,7 +7163,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.12.1",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -8088,7 +8182,7 @@ dependencies = [
"getrandom 0.4.2",
"once_cell",
"rustix 1.1.4",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -8370,6 +8464,46 @@ dependencies = [
"winnow",
]
[[package]]
name = "tonic"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec"
dependencies = [
"async-trait",
"axum 0.8.8",
"base64 0.22.1",
"bytes",
"h2 0.4.13",
"http 1.4.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.8.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"socket2 0.6.3",
"sync_wrapper",
"tokio",
"tokio-stream",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tonic-prost"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309"
dependencies = [
"bytes",
"prost",
"tonic",
]
[[package]]
name = "tower"
version = "0.5.3"
@@ -8378,9 +8512,12 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
dependencies = [
"futures-core",
"futures-util",
"indexmap 2.13.0",
"pin-project-lite",
"slab",
"sync_wrapper",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
@@ -8893,7 +9030,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]

View File

@@ -39,6 +39,8 @@ arrow-ord = "57.2"
arrow-schema = "57.2"
arrow-select = "57.2"
arrow-cast = "57.2"
arrow-flight = { version = "57.2", features = ["flight-sql-experimental"] }
tonic = "0.14"
async-trait = "0"
datafusion = { version = "52.1", default-features = false }
datafusion-catalog = "52.1"

View File

@@ -88,6 +88,10 @@ candle-transformers = { version = "0.9.1", optional = true }
candle-nn = { version = "0.9.1", optional = true }
tokenizers = { version = "0.19.1", optional = true }
semver = { workspace = true }
# For flight feature (Arrow Flight SQL server)
arrow-flight = { workspace = true, optional = true }
tonic = { workspace = true, optional = true }
prost = { version = "0.14", optional = true }
[dev-dependencies]
anyhow = "1"
@@ -104,6 +108,7 @@ datafusion.workspace = true
http-body = "1" # Matching reqwest
rstest = "0.23.0"
test-log = "0.2"
tokio-stream = "0.1"
[features]
@@ -131,6 +136,7 @@ sentence-transformers = [
"dep:candle-nn",
"dep:tokenizers",
]
flight = ["dep:arrow-flight", "dep:tonic", "dep:prost"]
[[example]]
name = "openai"
@@ -157,5 +163,9 @@ name = "ivf_pq"
name = "hybrid_search"
required-features = ["sentence-transformers"]
[[example]]
name = "flight_sql"
required-features = ["flight"]
[package.metadata.docs.rs]
all-features = true

View File

@@ -0,0 +1,99 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Example: LanceDB Arrow Flight SQL Server
//!
//! This example demonstrates how to:
//! 1. Create a LanceDB database with sample data
//! 2. Start an Arrow Flight SQL server
//! 3. Connect with a Flight SQL client
//! 4. Run SQL queries including vector_search and fts table functions
//!
//! Run with: `cargo run --features flight --example flight_sql`
use std::sync::Arc;
use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use lance_arrow::FixedSizeListArrayExt;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. Create an in-memory LanceDB database.
    // (Dropped the old `std::env::var("RUST_LOG")` line: it read the variable
    // and discarded it, initializing nothing.)
    let db = lancedb::connect("memory://flight_sql_demo")
        .execute()
        .await?;
    // 2. Create a table with text and vector data.
    // The vector column is a FixedSizeList<Float32>[4]; `dim` must match the
    // number of floats per row in `flat_values` below.
    let dim = 4i32;
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("text", DataType::Utf8, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
            true,
        ),
    ]));
    let ids = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
    let texts = StringArray::from(vec![
        "the quick brown fox jumps over the lazy dog",
        "a fast red fox leaps across the sleeping hound",
        "machine learning models process natural language",
        "neural networks learn from training data",
        "the brown dog chases the red fox through the forest",
        "deep learning algorithms improve with more data",
        "a lazy cat sleeps on the warm windowsill",
        "vector databases enable fast similarity search",
    ]);
    // Hand-crafted 4-dim embeddings: each axis loosely encodes a topic so the
    // example vector_search query '[1.0, 0.0, 0.0, 0.0]' returns the fox rows.
    let flat_values = Float32Array::from(vec![
        1.0, 0.0, 0.0, 0.0, // fox-like
        0.9, 0.1, 0.0, 0.0, // similar to fox
        0.0, 1.0, 0.0, 0.0, // ML-like
        0.0, 0.9, 0.1, 0.0, // similar to ML
        0.7, 0.3, 0.0, 0.0, // fox+dog mix
        0.0, 0.8, 0.2, 0.0, // ML-like
        0.1, 0.0, 0.0, 0.9, // cat-like
        0.0, 0.5, 0.5, 0.0, // tech-like
    ]);
    let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim)?;
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(ids), Arc::new(texts), Arc::new(vector_array)],
    )?;
    let table = db.create_table("documents", batch).execute().await?;
    // 3. Create indices (FTS on `text` so the fts() table function works).
    println!("Creating FTS index on 'text' column...");
    table
        .create_index(&["text"], lancedb::index::Index::FTS(Default::default()))
        .execute()
        .await?;
    println!("Database ready with {} rows", table.count_rows(None).await?);
    // 4. Start Flight SQL server. `serve` blocks until shutdown, so print the
    // connection instructions before handing control to it.
    let addr = "0.0.0.0:50051".parse()?;
    println!("Starting Arrow Flight SQL server on {}...", addr);
    println!();
    println!("Connect with any ADBC Flight SQL client:");
    println!(" URI: grpc://localhost:50051");
    println!();
    println!("Example SQL queries:");
    println!(" SELECT * FROM documents LIMIT 5;");
    println!(" SELECT * FROM vector_search('documents', '[1.0, 0.0, 0.0, 0.0]', 3);");
    println!(
        " SELECT * FROM fts('documents', '{{\"match\": {{\"column\": \"text\", \"terms\": \"fox\"}}}}');"
    );
    println!();
    lancedb::flight::serve(db, addr).await?;
    Ok(())
}

606
rust/lancedb/src/flight.rs Normal file
View File

@@ -0,0 +1,606 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Arrow Flight SQL server for LanceDB.
//!
//! This module provides an Arrow Flight SQL server that exposes a LanceDB
//! [`Connection`] over the Flight SQL protocol. Any ADBC, ODBC (via bridge),
//! or JDBC Flight SQL client can connect and run SQL queries — including
//! LanceDB's search table functions (`vector_search`, `fts`, `hybrid_search`).
//!
//! # Quick Start
//!
//! ```no_run
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let db = lancedb::connect("data/my-db").execute().await?;
//! let addr = "0.0.0.0:50051".parse()?;
//! lancedb::flight::serve(db, addr).await?;
//! # Ok(())
//! # }
//! ```
use std::collections::HashMap;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch, StringArray};
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightServiceServer;
use arrow_flight::sql::server::FlightSqlService;
use arrow_flight::sql::{
Any, CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables,
CommandStatementQuery, SqlInfo, TicketStatementQuery,
};
use arrow_flight::{FlightDescriptor, FlightEndpoint, FlightInfo, Ticket};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
use datafusion::prelude::SessionContext;
use datafusion_catalog::TableProvider;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use futures::StreamExt;
use futures::stream;
use log;
use prost::Message;
use tonic::transport::Server;
use tonic::{Request, Response, Status};
use crate::connection::Connection;
use crate::table::datafusion::BaseTableAdapter;
use crate::table::datafusion::udtf::fts::FtsTableFunction;
use crate::table::datafusion::udtf::hybrid_search::HybridSearchTableFunction;
use crate::table::datafusion::udtf::vector_search::VectorSearchTableFunction;
use crate::table::datafusion::udtf::{SearchQuery, TableResolver};
/// Start an Arrow Flight SQL server exposing the given LanceDB connection.
///
/// This is a convenience function that creates a server and starts listening.
/// It blocks until the server is shut down.
pub async fn serve(connection: Connection, addr: SocketAddr) -> crate::Result<()> {
let service = LanceFlightSqlService::try_new(connection).await?;
let flight_svc = FlightServiceServer::new(service);
Server::builder()
.add_service(flight_svc)
.serve(addr)
.await
.map_err(|e| crate::Error::Runtime {
message: format!("Flight SQL server error: {}", e),
})?;
Ok(())
}
/// A [`TableResolver`] backed by a fixed, pre-built map of table adapters.
#[derive(Debug)]
struct ConnectionTableResolver {
    tables: HashMap<String, Arc<BaseTableAdapter>>,
}

impl TableResolver for ConnectionTableResolver {
    /// Look up `name` and, if a search was requested, wrap the base adapter
    /// with the corresponding FTS / vector / hybrid query.
    fn resolve_table(
        &self,
        name: &str,
        search: Option<SearchQuery>,
    ) -> DataFusionResult<Arc<dyn TableProvider>> {
        let base = self
            .tables
            .get(name)
            .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
        let provider: Arc<dyn TableProvider> = match search {
            Some(SearchQuery::Fts(fts)) => Arc::new(base.with_fts_query(fts)),
            Some(SearchQuery::Vector(vq)) => Arc::new(base.with_vector_query(vq)),
            Some(SearchQuery::Hybrid { fts, vector }) => {
                Arc::new(base.with_hybrid_query(fts, vector))
            }
            None => base.clone(),
        };
        Ok(provider)
    }
}
/// Arrow Flight SQL service backed by a LanceDB connection.
///
/// Table adapters are built once in [`Self::try_new`]; tables created after
/// the service starts are not visible until a new service is constructed.
struct LanceFlightSqlService {
    /// Kept for future use (e.g., refreshing table list).
    _connection: Connection,
    /// Pre-built table adapters (refreshed on creation)
    tables: HashMap<String, Arc<BaseTableAdapter>>,
}

impl LanceFlightSqlService {
    /// Open every table in `connection` and wrap it in a [`BaseTableAdapter`].
    ///
    /// Fails if the table list cannot be fetched or any single table fails to
    /// open or adapt — the service is all-or-nothing on startup.
    async fn try_new(connection: Connection) -> crate::Result<Self> {
        let table_names = connection.table_names().execute().await?;
        let mut tables = HashMap::new();
        for name in &table_names {
            let table = connection.open_table(name).execute().await?;
            let adapter = BaseTableAdapter::try_new(table.base_table().clone()).await?;
            tables.insert(name.clone(), Arc::new(adapter));
        }
        log::info!(
            "LanceDB Flight SQL service initialized with {} tables",
            tables.len()
        );
        Ok(Self {
            _connection: connection,
            tables,
        })
    }

    /// Create a DataFusion SessionContext with all tables and UDTFs registered.
    ///
    /// A fresh context is built per call, so queries never share session state.
    fn create_session_context(&self) -> SessionContext {
        let ctx = SessionContext::new();
        // Register all tables so plain `SELECT ... FROM <table>` works.
        for (name, adapter) in &self.tables {
            if let Err(e) = ctx.register_table(name, adapter.clone() as Arc<dyn TableProvider>) {
                // Best-effort: one bad table should not take down the whole session.
                log::warn!("Failed to register table '{}': {}", name, e);
            }
        }
        // Create resolver for UDTFs
        let resolver = Arc::new(ConnectionTableResolver {
            tables: self.tables.clone(),
        });
        // Register search UDTFs: fts(...), vector_search(...), hybrid_search(...)
        ctx.register_udtf("fts", Arc::new(FtsTableFunction::new(resolver.clone())));
        ctx.register_udtf(
            "vector_search",
            Arc::new(VectorSearchTableFunction::new(resolver.clone())),
        );
        ctx.register_udtf(
            "hybrid_search",
            Arc::new(HybridSearchTableFunction::new(resolver)),
        );
        ctx
    }

    /// Execute a SQL query and return the results as a stream of FlightData.
    ///
    /// Note: this plans the query again even if `get_sql_schema` already
    /// planned it for the FlightInfo response — the two are separate RPCs.
    async fn execute_sql(
        &self,
        sql: &str,
    ) -> Result<
        Pin<Box<dyn futures::Stream<Item = Result<arrow_flight::FlightData, Status>> + Send>>,
        Status,
    > {
        let ctx = self.create_session_context();
        let df = ctx
            .sql(sql)
            .await
            .map_err(|e| Status::internal(format!("SQL planning error: {}", e)))?;
        // Capture the schema before `df` is consumed by execute_stream.
        let schema: SchemaRef = df.schema().inner().clone();
        let stream = df
            .execute_stream()
            .await
            .map_err(|e| Status::internal(format!("SQL execution error: {}", e)))?;
        // Use FlightDataEncoderBuilder to properly encode batches with schema
        let batch_stream = stream.map(|r| r.map_err(|e| FlightError::ExternalError(Box::new(e))));
        let flight_data_stream = FlightDataEncoderBuilder::new()
            .with_schema(schema)
            .build(batch_stream)
            .map(|result| result.map_err(|e| Status::internal(format!("Encoding error: {}", e))));
        Ok(Box::pin(flight_data_stream))
    }

    /// Encode a single RecordBatch into a FlightData stream (schema + data).
    ///
    /// Used by the catalog endpoints, whose results always fit in one batch.
    fn batch_to_flight_stream(
        batch: RecordBatch,
    ) -> Pin<Box<dyn futures::Stream<Item = Result<arrow_flight::FlightData, Status>> + Send>> {
        let schema = batch.schema();
        let stream = FlightDataEncoderBuilder::new()
            .with_schema(schema)
            .build(stream::once(async move { Ok(batch) }))
            .map(|result| result.map_err(|e| Status::internal(format!("Encoding error: {}", e))));
        Box::pin(stream)
    }

    /// Get the schema for a SQL query without executing it.
    ///
    /// Only plans the query (no `execute_stream`), so side-effect-free for
    /// read queries; used to fill in the FlightInfo schema.
    async fn get_sql_schema(&self, sql: &str) -> Result<ArrowSchema, Status> {
        let ctx = self.create_session_context();
        let df = ctx
            .sql(sql)
            .await
            .map_err(|e| Status::internal(format!("SQL planning error: {}", e)))?;
        Ok(df.schema().inner().as_ref().clone())
    }
}
#[tonic::async_trait]
impl FlightSqlService for LanceFlightSqlService {
type FlightService = LanceFlightSqlService;
/// Handle SQL query: return FlightInfo with schema and ticket.
async fn get_flight_info_statement(
&self,
query: CommandStatementQuery,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let sql = query.query;
log::info!("get_flight_info_statement: {}", sql);
let schema = self.get_sql_schema(&sql).await?;
// Encode the query as an Any-wrapped TicketStatementQuery for do_get_statement
let ticket = TicketStatementQuery {
statement_handle: sql.into_bytes().into(),
};
let any_msg = Any::pack(&ticket)
.map_err(|e| Status::internal(format!("Ticket encoding error: {}", e)))?;
let mut ticket_bytes = Vec::new();
any_msg
.encode(&mut ticket_bytes)
.map_err(|e| Status::internal(format!("Ticket encoding error: {}", e)))?;
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
let flight_info = FlightInfo::new()
.try_with_schema(&schema)
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
.with_endpoint(endpoint)
.with_descriptor(request.into_inner());
Ok(Response::new(flight_info))
}
/// Execute a SQL query and stream results.
async fn do_get_statement(
&self,
ticket: TicketStatementQuery,
_request: Request<Ticket>,
) -> Result<
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
Status,
> {
let sql = String::from_utf8(ticket.statement_handle.to_vec())
.map_err(|e| Status::internal(format!("Invalid ticket: {}", e)))?;
log::info!("do_get_statement: {}", sql);
let stream = self.execute_sql(&sql).await?;
Ok(Response::new(stream))
}
/// List tables in the database.
async fn get_flight_info_tables(
&self,
_query: CommandGetTables,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let schema = ArrowSchema::new(vec![
Field::new("catalog_name", DataType::Utf8, true),
Field::new("db_schema_name", DataType::Utf8, true),
Field::new("table_name", DataType::Utf8, false),
Field::new("table_type", DataType::Utf8, false),
]);
let cmd = CommandGetTables::default();
let any_msg =
Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
let mut ticket_bytes = Vec::new();
any_msg
.encode(&mut ticket_bytes)
.map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
let flight_info = FlightInfo::new()
.try_with_schema(&schema)
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
.with_endpoint(endpoint)
.with_descriptor(request.into_inner());
Ok(Response::new(flight_info))
}
async fn do_get_tables(
&self,
_query: CommandGetTables,
_request: Request<Ticket>,
) -> Result<
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
Status,
> {
let table_names: Vec<&str> = self.tables.keys().map(|s| s.as_str()).collect();
let num_tables = table_names.len();
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("catalog_name", DataType::Utf8, true),
Field::new("db_schema_name", DataType::Utf8, true),
Field::new("table_name", DataType::Utf8, false),
Field::new("table_type", DataType::Utf8, false),
]));
let catalog_names: ArrayRef =
Arc::new(StringArray::from(vec![Some("lancedb"); num_tables]));
let schema_names: ArrayRef = Arc::new(StringArray::from(vec![Some("default"); num_tables]));
let table_name_array: ArrayRef = Arc::new(StringArray::from(table_names));
let table_types: ArrayRef = Arc::new(StringArray::from(vec!["TABLE"; num_tables]));
let batch = RecordBatch::try_new(
schema,
vec![catalog_names, schema_names, table_name_array, table_types],
)
.map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;
Ok(Response::new(Self::batch_to_flight_stream(batch)))
}
/// List table types.
async fn get_flight_info_table_types(
&self,
_query: CommandGetTableTypes,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let schema = ArrowSchema::new(vec![Field::new("table_type", DataType::Utf8, false)]);
let cmd = CommandGetTableTypes::default();
let any_msg =
Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
let mut ticket_bytes = Vec::new();
any_msg
.encode(&mut ticket_bytes)
.map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
let flight_info = FlightInfo::new()
.try_with_schema(&schema)
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
.with_endpoint(endpoint)
.with_descriptor(request.into_inner());
Ok(Response::new(flight_info))
}
async fn do_get_table_types(
&self,
_query: CommandGetTableTypes,
_request: Request<Ticket>,
) -> Result<
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
Status,
> {
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"table_type",
DataType::Utf8,
false,
)]));
let table_types: ArrayRef = Arc::new(StringArray::from(vec!["TABLE"]));
let batch = RecordBatch::try_new(schema, vec![table_types])
.map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;
Ok(Response::new(Self::batch_to_flight_stream(batch)))
}
/// List catalogs.
async fn get_flight_info_catalogs(
&self,
_query: CommandGetCatalogs,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let schema = ArrowSchema::new(vec![Field::new("catalog_name", DataType::Utf8, false)]);
let cmd = CommandGetCatalogs::default();
let any_msg =
Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
let mut ticket_bytes = Vec::new();
any_msg
.encode(&mut ticket_bytes)
.map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
let flight_info = FlightInfo::new()
.try_with_schema(&schema)
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
.with_endpoint(endpoint)
.with_descriptor(request.into_inner());
Ok(Response::new(flight_info))
}
async fn do_get_catalogs(
&self,
_query: CommandGetCatalogs,
_request: Request<Ticket>,
) -> Result<
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
Status,
> {
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"catalog_name",
DataType::Utf8,
false,
)]));
let catalogs: ArrayRef = Arc::new(StringArray::from(vec!["lancedb"]));
let batch = RecordBatch::try_new(schema, vec![catalogs])
.map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;
Ok(Response::new(Self::batch_to_flight_stream(batch)))
}
/// List schemas.
async fn get_flight_info_schemas(
&self,
_query: CommandGetDbSchemas,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let schema = ArrowSchema::new(vec![
Field::new("catalog_name", DataType::Utf8, true),
Field::new("db_schema_name", DataType::Utf8, false),
]);
let cmd = CommandGetDbSchemas::default();
let any_msg =
Any::pack(&cmd).map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
let mut ticket_bytes = Vec::new();
any_msg
.encode(&mut ticket_bytes)
.map_err(|e| Status::internal(format!("Encoding error: {}", e)))?;
let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
let flight_info = FlightInfo::new()
.try_with_schema(&schema)
.map_err(|e| Status::internal(format!("Schema error: {}", e)))?
.with_endpoint(endpoint)
.with_descriptor(request.into_inner());
Ok(Response::new(flight_info))
}
async fn do_get_schemas(
&self,
_query: CommandGetDbSchemas,
_request: Request<Ticket>,
) -> Result<
Response<<Self as arrow_flight::flight_service_server::FlightService>::DoGetStream>,
Status,
> {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("catalog_name", DataType::Utf8, true),
Field::new("db_schema_name", DataType::Utf8, false),
]));
let catalogs: ArrayRef = Arc::new(StringArray::from(vec![Some("lancedb")]));
let schemas: ArrayRef = Arc::new(StringArray::from(vec!["default"]));
let batch = RecordBatch::try_new(schema, vec![catalogs, schemas])
.map_err(|e| Status::internal(format!("RecordBatch error: {}", e)))?;
Ok(Response::new(Self::batch_to_flight_stream(batch)))
}
async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float32Array, Int32Array};
    use arrow_flight::sql::client::FlightSqlServiceClient;
    use futures::TryStreamExt;
    use lance_arrow::FixedSizeListArrayExt;
    use std::time::Duration;

    /// Build an in-memory DB with one table `test_table` (id, text, 4-dim
    /// vector; 3 rows) and an FTS index on `text`.
    async fn create_test_db() -> Connection {
        let db = crate::connect("memory://flight_test")
            .execute()
            .await
            .unwrap();
        let dim = 4i32;
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("text", DataType::Utf8, false),
            Field::new(
                "vector",
                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
                true,
            ),
        ]));
        let ids = Int32Array::from(vec![1, 2, 3]);
        let texts = StringArray::from(vec!["hello world", "foo bar", "baz qux"]);
        // Three orthogonal unit vectors — one per row.
        let flat_values = Float32Array::from(vec![
            1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
        ]);
        let vectors =
            arrow_array::FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(ids), Arc::new(texts), Arc::new(vectors)],
        )
        .unwrap();
        let table = db
            .create_table("test_table", batch)
            .execute()
            .await
            .unwrap();
        // Create FTS index
        table
            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
            .execute()
            .await
            .unwrap();
        db
    }

    /// End-to-end: start a real gRPC server on an ephemeral port, connect a
    /// Flight SQL client, run a SELECT, and check rows + schema come back.
    #[tokio::test]
    async fn test_flight_sql_basic_query() {
        let db = create_test_db().await;
        let service = LanceFlightSqlService::try_new(db).await.unwrap();
        let flight_svc = FlightServiceServer::new(service);
        // Start server on a random port (port 0 lets the OS choose).
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server_handle = tokio::spawn(async move {
            Server::builder()
                .add_service(flight_svc)
                .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener))
                .await
                .unwrap();
        });
        // Give server a moment to start
        // NOTE(review): a fixed sleep can flake on slow CI; a connect-retry
        // loop would be more robust.
        tokio::time::sleep(Duration::from_millis(100)).await;
        // Connect client
        let channel = tonic::transport::Channel::from_shared(format!("http://{}", addr))
            .unwrap()
            .connect()
            .await
            .unwrap();
        let mut client = FlightSqlServiceClient::new(channel);
        // Execute SQL query (returns FlightInfo with one ticketed endpoint).
        let flight_info = client
            .execute("SELECT id, text FROM test_table".to_string(), None)
            .await
            .unwrap();
        // Fetch results using the FlightSql client's do_get
        let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
        let flight_stream = client.do_get(ticket).await.unwrap();
        let batches: Vec<RecordBatch> = flight_stream.try_collect().await.unwrap();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3, "Should return all 3 rows");
        // Verify schema
        if let Some(batch) = batches.first() {
            assert!(batch.schema().column_with_name("id").is_some());
            assert!(batch.schema().column_with_name("text").is_some());
        }
        server_handle.abort();
    }

    /// Service construction discovers the pre-existing table.
    #[tokio::test]
    async fn test_flight_sql_table_listing() {
        let db = create_test_db().await;
        let service = LanceFlightSqlService::try_new(db).await.unwrap();
        assert!(service.tables.contains_key("test_table"));
        assert_eq!(service.tables.len(), 1);
    }

    /// The per-request SessionContext can plan and run plain SQL.
    #[tokio::test]
    async fn test_flight_sql_session_context() {
        let db = create_test_db().await;
        let service = LanceFlightSqlService::try_new(db).await.unwrap();
        let ctx = service.create_session_context();
        // Test that we can execute a simple query
        let df = ctx.sql("SELECT * FROM test_table LIMIT 1").await.unwrap();
        let results = df.collect().await.unwrap();
        assert_eq!(results[0].num_rows(), 1);
    }
}

View File

@@ -170,6 +170,8 @@ pub mod dataloader;
pub mod embeddings;
pub mod error;
pub mod expr;
#[cfg(feature = "flight")]
pub mod flight;
pub mod index;
pub mod io;
pub mod ipc;

View File

@@ -26,9 +26,10 @@ use lance::dataset::{WriteMode, WriteParams};
use super::{AnyQuery, BaseTable};
use crate::{
Result,
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select},
DistanceType, Result,
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQueryRequest},
};
use arrow_array::Array;
use arrow_schema::{DataType, Field};
use lance_index::scalar::FullTextSearchQuery;
@@ -141,11 +142,31 @@ impl ExecutionPlan for MetadataEraserExec {
}
}
/// Parameters for a vector search query, used by vector_search and hybrid_search UDTFs.
#[derive(Debug, Clone)]
pub struct VectorSearchParams {
    /// The query vector to search for
    pub query_vector: Arc<dyn Array>,
    /// The column to search on (None for auto-detection)
    pub column: Option<String>,
    /// Number of results to return
    pub top_k: usize,
    /// Distance metric to use (`None` presumably falls back to the table's
    /// default metric — TODO confirm against the query builder)
    pub distance_type: Option<DistanceType>,
    /// Number of IVF partitions to search (`None` = engine default)
    pub nprobes: Option<usize>,
    /// HNSW search parameter (candidate-list size `ef`; `None` = engine default)
    pub ef: Option<usize>,
    /// Refine factor for improving recall (`None` = no re-ranking refinement)
    pub refine_factor: Option<u32>,
}
/// DataFusion `TableProvider` adapter over a LanceDB `BaseTable`.
#[derive(Debug)]
pub struct BaseTableAdapter {
    // Underlying LanceDB table that executes the actual scans.
    table: Arc<dyn BaseTable>,
    // Arrow schema exposed to DataFusion; the search constructors extend
    // it (e.g. `with_vector_query` appends a `_distance` column).
    schema: Arc<ArrowSchema>,
    // Optional full-text search applied to scans (set via `with_fts_query`).
    fts_query: Option<FullTextSearchQuery>,
    // Optional vector search applied to scans (set via `with_vector_query`).
    vector_query: Option<VectorSearchParams>,
}
impl BaseTableAdapter {
@@ -161,6 +182,7 @@ impl BaseTableAdapter {
table,
schema: Arc::new(schema),
fts_query: None,
vector_query: None,
})
}
@@ -176,6 +198,49 @@ impl BaseTableAdapter {
table: self.table.clone(),
schema,
fts_query: Some(fts_query),
vector_query: self.vector_query.clone(),
}
}
/// Create a new adapter with a vector search query applied.
///
/// The returned adapter exposes the base schema plus a nullable
/// `_distance: Float32` column, which vector scans emit.
pub fn with_vector_query(&self, vector_query: VectorSearchParams) -> Self {
    // Extend the projected schema with `_distance` before handing the
    // adapter to DataFusion, so plans see the extra output column.
    let distance = Arc::new(Field::new("_distance", DataType::Float32, true));
    let fields: Vec<_> = self
        .schema
        .fields()
        .iter()
        .cloned()
        .chain(std::iter::once(distance))
        .collect();
    Self {
        table: self.table.clone(),
        schema: Arc::new(ArrowSchema::new(fields)),
        fts_query: self.fts_query.clone(),
        vector_query: Some(vector_query),
    }
}
/// Create a new adapter with both FTS and vector search queries (hybrid search).
///
/// Vector search is the primary retrieval method. Only a `_distance`
/// column is appended to the result schema here — no `_score` column is
/// projected by this constructor.
///
/// NOTE(review): the scan path sets `full_text_search` to `None` whenever
/// a vector query is present, so the FTS query stored here may not
/// actually pre-filter results — confirm the intended hybrid semantics.
pub fn with_hybrid_query(
    &self,
    fts_query: FullTextSearchQuery,
    vector_query: VectorSearchParams,
) -> Self {
    // Add _distance column (vector search is primary)
    let mut fields = self.schema.fields().to_vec();
    fields.push(Arc::new(Field::new("_distance", DataType::Float32, true)));
    let schema = Arc::new(ArrowSchema::new(fields));
    // Both queries are stored on the adapter; `scan` decides how they are
    // combined when building the QueryRequest.
    Self {
        table: self.table.clone(),
        schema,
        fts_query: Some(fts_query),
        vector_query: Some(vector_query),
    }
}
}
@@ -201,11 +266,20 @@ impl TableProvider for BaseTableAdapter {
filters: &[Expr],
limit: Option<usize>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
// For FTS queries, disable auto-projection of _score to match DataFusion expectations
let disable_scoring = self.fts_query.is_some() && projection.is_some();
let has_scoring = self.fts_query.is_some() || self.vector_query.is_some();
let disable_scoring = has_scoring && projection.is_some();
let mut query = QueryRequest {
full_text_search: self.fts_query.clone(),
// When doing vector search, FTS cannot be combined in the same scanner
// (Lance doesn't support both nearest + full_text_search simultaneously).
// FTS is only set when there's no vector query.
let fts_for_query = if self.vector_query.is_some() {
None
} else {
self.fts_query.clone()
};
let mut base_query = QueryRequest {
full_text_search: fts_for_query,
disable_scoring_autoprojection: disable_scoring,
..Default::default()
};
@@ -215,20 +289,20 @@ impl TableProvider for BaseTableAdapter {
.iter()
.map(|i| self.schema.field(*i).name().clone())
.collect();
query.select = Select::Columns(field_names);
base_query.select = Select::Columns(field_names);
}
if !filters.is_empty() {
let first = filters.first().unwrap().clone();
let filter = filters[1..]
.iter()
.fold(first, |acc, expr| acc.and(expr.clone()));
query.filter = Some(QueryFilter::Datafusion(filter));
base_query.filter = Some(QueryFilter::Datafusion(filter));
}
if let Some(limit) = limit {
query.limit = Some(limit);
} else {
// Need to override the default of 10
query.limit = None;
base_query.limit = Some(limit);
} else if self.vector_query.is_none() {
// Need to override the default of 10 for non-vector queries
base_query.limit = None;
}
let options = QueryExecutionOptions {
@@ -236,9 +310,33 @@ impl TableProvider for BaseTableAdapter {
..Default::default()
};
// Build the appropriate query type
let any_query = if let Some(ref vq) = self.vector_query {
let vector_query = VectorQueryRequest {
base: base_query,
column: vq.column.clone(),
query_vector: vec![vq.query_vector.clone()],
minimum_nprobes: vq.nprobes.unwrap_or(20),
maximum_nprobes: vq.nprobes,
ef: vq.ef,
refine_factor: vq.refine_factor,
distance_type: vq.distance_type,
use_index: true,
..Default::default()
};
// For vector queries, use top_k as the limit if no explicit limit set
let mut vq_req = vector_query;
if limit.is_none() {
vq_req.base.limit = Some(vq.top_k);
}
AnyQuery::VectorQuery(vq_req)
} else {
AnyQuery::Query(base_query)
};
let plan = self
.table
.create_plan(&AnyQuery::Query(query), options)
.create_plan(&any_query, options)
.map_err(|err| DataFusionError::External(err.into()))
.await?;
Ok(Arc::new(MetadataEraserExec::new(plan)))

View File

@@ -2,5 +2,75 @@
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! User-Defined Table Functions (UDTFs) for DataFusion integration
//!
//! This module provides SQL table functions for LanceDB search capabilities:
//! - `fts(table_name, query_json)` — full-text search
//! - `vector_search(table_name, query_vector_json, top_k)` — vector similarity search
//! - `hybrid_search(table_name, query_vector_json, fts_query_json, top_k)` — combined search
pub mod fts;
pub mod hybrid_search;
pub mod vector_search;
use std::sync::Arc;
use datafusion_catalog::TableProvider;
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion_expr::Expr;
use lance_index::scalar::FullTextSearchQuery;
use super::VectorSearchParams;
/// Describes the type of search to apply when resolving a table.
///
/// Produced by the SQL table functions (`fts`, `vector_search`,
/// `hybrid_search`) and consumed by `TableResolver::resolve_table`.
#[derive(Debug, Clone)]
pub enum SearchQuery {
    /// Full-text search only
    Fts(FullTextSearchQuery),
    /// Vector similarity search only
    Vector(VectorSearchParams),
    /// Hybrid search combining FTS and vector search
    Hybrid {
        /// Full-text component of the hybrid query
        fts: FullTextSearchQuery,
        /// Vector-similarity component of the hybrid query
        vector: VectorSearchParams,
    },
}
/// Trait for resolving table names to TableProvider instances, optionally with a search query.
///
/// Implementors map a SQL-level table name to a provider and, when a
/// `SearchQuery` is supplied, return a provider with that search applied.
pub trait TableResolver: std::fmt::Debug + Send + Sync {
    /// Resolve a table name to a TableProvider, optionally applying a search query.
    fn resolve_table(
        &self,
        name: &str,
        search: Option<SearchQuery>,
    ) -> DataFusionResult<Arc<dyn TableProvider>>;
}
/// Extract a string literal from a DataFusion expression.
///
/// Accepts both `Utf8` and `LargeUtf8` literals; anything else yields a
/// plan error naming the offending parameter.
pub(crate) fn extract_string_literal(expr: &Expr, param_name: &str) -> DataFusionResult<String> {
    if let Expr::Literal(ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)), _) = expr {
        return Ok(s.clone());
    }
    Err(DataFusionError::Plan(format!(
        "Parameter '{}' must be a string literal, got: {:?}",
        param_name, expr
    )))
}
/// Extract a non-negative integer literal from a DataFusion expression.
///
/// Accepts any signed or unsigned integer literal width. Returns a plan
/// error when the expression is not an integer literal or when the value
/// is negative — previously a negative literal such as `-1` was cast with
/// `as usize` and silently wrapped to a huge value (e.g. `top_k` of
/// `usize::MAX`).
pub(crate) fn extract_int_literal(expr: &Expr, param_name: &str) -> DataFusionResult<usize> {
    // Widen every integer width to i128 so the range check below is uniform.
    let value: i128 = match expr {
        Expr::Literal(ScalarValue::Int8(Some(v)), _) => *v as i128,
        Expr::Literal(ScalarValue::Int16(Some(v)), _) => *v as i128,
        Expr::Literal(ScalarValue::Int32(Some(v)), _) => *v as i128,
        Expr::Literal(ScalarValue::Int64(Some(v)), _) => *v as i128,
        Expr::Literal(ScalarValue::UInt8(Some(v)), _) => *v as i128,
        Expr::Literal(ScalarValue::UInt16(Some(v)), _) => *v as i128,
        Expr::Literal(ScalarValue::UInt32(Some(v)), _) => *v as i128,
        Expr::Literal(ScalarValue::UInt64(Some(v)), _) => *v as i128,
        _ => {
            return Err(DataFusionError::Plan(format!(
                "Parameter '{}' must be an integer literal, got: {:?}",
                param_name, expr
            )));
        }
    };
    // Checked narrowing instead of `as usize`: rejects negatives (and any
    // value that does not fit in usize) with a clear error.
    usize::try_from(value).map_err(|_| {
        DataFusionError::Plan(format!(
            "Parameter '{}' must be a non-negative integer, got: {}",
            param_name, value
        ))
    })
}

View File

@@ -1,29 +1,23 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! User-Defined Table Functions (UDTFs) for LanceDB
//! Full-Text Search (FTS) table function for DataFusion SQL integration.
//!
//! This module provides table-level UDTFs that integrate with DataFusion's SQL engine.
//! Usage: `SELECT * FROM fts('table_name', '{"match": {"column": "text", "terms": "query"}}')`
use std::sync::Arc;
use datafusion::catalog::TableFunctionImpl;
use datafusion_catalog::TableProvider;
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue, plan_err};
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
use datafusion_expr::Expr;
use lance_index::scalar::FullTextSearchQuery;
/// Trait for resolving table names to TableProvider instances.
pub trait TableResolver: std::fmt::Debug + Send + Sync {
/// Resolve a table name to a TableProvider, optionally with an FTS query applied.
fn resolve_table(
&self,
name: &str,
fts_query: Option<FullTextSearchQuery>,
) -> DataFusionResult<Arc<dyn TableProvider>>;
}
use super::{SearchQuery, TableResolver, extract_string_literal};
/// Full-Text Search table function that operates on LanceDB tables
/// Full-Text Search table function that operates on LanceDB tables.
///
/// Accepts 2 parameters: `fts(table_name, fts_query_json)`
#[derive(Debug)]
pub struct FtsTableFunction {
resolver: Arc<dyn TableResolver>,
@@ -45,20 +39,8 @@ impl TableFunctionImpl for FtsTableFunction {
let query_json = extract_string_literal(&exprs[1], "fts_query")?;
let fts_query = parse_fts_query(&query_json)?;
// Resolver returns a ready-to-use TableProvider with FTS applied
self.resolver.resolve_table(&table_name, Some(fts_query))
}
}
fn extract_string_literal(expr: &Expr, param_name: &str) -> DataFusionResult<String> {
match expr {
Expr::Literal(ScalarValue::Utf8(Some(s)), _) => Ok(s.clone()),
Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => Ok(s.clone()),
_ => plan_err!(
"Parameter '{}' must be a string literal, got: {:?}",
param_name,
expr
),
self.resolver
.resolve_table(&table_name, Some(SearchQuery::Fts(fts_query)))
}
}
@@ -91,6 +73,7 @@ pub fn from_json(json: &str) -> crate::Result<lance_index::scalar::inverted::que
#[cfg(test)]
mod tests {
use super::super::{SearchQuery, TableResolver};
use super::*;
use crate::{
Connection, Table,
@@ -100,6 +83,7 @@ mod tests {
use arrow_array::{Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion::prelude::SessionContext;
use datafusion_common::DataFusionError;
/// Resolver that looks up tables in a HashMap
#[derive(Debug)]
@@ -123,7 +107,7 @@ mod tests {
fn resolve_table(
&self,
name: &str,
fts_query: Option<FullTextSearchQuery>,
search: Option<SearchQuery>,
) -> DataFusionResult<Arc<dyn TableProvider>> {
let table_provider = self
.tables
@@ -131,12 +115,10 @@ mod tests {
.cloned()
.ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
// If no FTS query, return as-is
let Some(fts_query) = fts_query else {
let Some(search) = search else {
return Ok(table_provider);
};
// Downcast to BaseTableAdapter and apply FTS query
let base_adapter = table_provider
.as_any()
.downcast_ref::<BaseTableAdapter>()
@@ -146,7 +128,15 @@ mod tests {
)
})?;
Ok(Arc::new(base_adapter.with_fts_query(fts_query)))
match search {
SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
SearchQuery::Vector(vector_query) => {
Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
}
SearchQuery::Hybrid { fts, vector } => {
Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
}
}
}
}

View File

@@ -0,0 +1,257 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Hybrid search table function for DataFusion SQL integration.
//!
//! Combines vector similarity search with full-text search:
//! ```sql
//! SELECT * FROM hybrid_search(
//! 'my_table',
//! '[0.1, 0.2, 0.3]',
//! '{"match": {"column": "text", "terms": "search query"}}',
//! 10
//! )
//! ```
use std::sync::Arc;
use arrow_array::Array;
use datafusion::catalog::TableFunctionImpl;
use datafusion_catalog::TableProvider;
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
use datafusion_expr::Expr;
use lance_index::scalar::FullTextSearchQuery;
use super::fts::from_json as fts_from_json;
use super::{SearchQuery, TableResolver, extract_int_literal, extract_string_literal};
use crate::table::datafusion::VectorSearchParams;
/// Default number of results for hybrid search when top_k is not specified.
const DEFAULT_TOP_K: usize = 10;
/// Hybrid search table function combining vector and full-text search.
///
/// Accepts 3-4 parameters: `hybrid_search(table_name, query_vector_json, fts_query_json [, top_k])`
///
/// - `table_name`: Name of the table to search
/// - `query_vector_json`: JSON array of float values, e.g. `'[0.1, 0.2, 0.3]'`
/// - `fts_query_json`: FTS query as JSON, e.g. `'{"match": {"column": "text", "terms": "query"}}'`
/// - `top_k` (optional): Number of results to return (default: 10)
#[derive(Debug)]
pub struct HybridSearchTableFunction {
    // Maps table names to providers and applies the hybrid search to them.
    resolver: Arc<dyn TableResolver>,
}
impl HybridSearchTableFunction {
    /// Create a hybrid-search table function backed by the given resolver.
    pub fn new(resolver: Arc<dyn TableResolver>) -> Self {
        Self { resolver }
    }
}
impl TableFunctionImpl for HybridSearchTableFunction {
    /// Validate the SQL arguments, parse the vector and FTS payloads, and
    /// resolve the target table with a hybrid search attached.
    fn call(&self, exprs: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
        if !(3..=4).contains(&exprs.len()) {
            return plan_err!(
                "hybrid_search() requires 3-4 parameters: hybrid_search(table_name, query_vector_json, fts_query_json [, top_k])"
            );
        }
        // Extract all argument literals up front so argument-type errors
        // are reported before any JSON parsing happens.
        let table_name = extract_string_literal(&exprs[0], "table_name")?;
        let vector_json = extract_string_literal(&exprs[1], "query_vector_json")?;
        let fts_json = extract_string_literal(&exprs[2], "fts_query_json")?;
        let top_k = match exprs.get(3) {
            Some(expr) => extract_int_literal(expr, "top_k")?,
            None => DEFAULT_TOP_K,
        };
        let vector = VectorSearchParams {
            query_vector: parse_vector_json(&vector_json)? as Arc<dyn Array>,
            column: None,
            top_k,
            distance_type: None,
            nprobes: None,
            ef: None,
            refine_factor: None,
        };
        let fts = parse_fts_query(&fts_json)?;
        self.resolver
            .resolve_table(&table_name, Some(SearchQuery::Hybrid { fts, vector }))
    }
}
/// Parse a JSON float array for hybrid search.
///
/// Thin wrapper delegating to `vector_search::parse_vector_json` so both
/// UDTFs share a single parser (and a single error message format).
fn parse_vector_json(json: &str) -> DataFusionResult<Arc<arrow_array::Float32Array>> {
    super::vector_search::parse_vector_json(json)
}
/// Decode an FTS query from its JSON form into a `FullTextSearchQuery`.
fn parse_fts_query(json: &str) -> DataFusionResult<FullTextSearchQuery> {
    match fts_from_json(json) {
        Ok(query) => Ok(FullTextSearchQuery::new_query(query)),
        Err(e) => Err(DataFusionError::Plan(format!(
            "Invalid FTS query JSON: {}. Expected format: {{\"match\": {{\"column\": \"text\", \"terms\": \"query\"}} }}",
            e
        ))),
    }
}
#[cfg(test)]
mod tests {
use super::super::fts::to_json;
use super::super::{SearchQuery, TableResolver};
use super::*;
use crate::{index::Index, table::datafusion::BaseTableAdapter};
use arrow_array::FixedSizeListArray;
use arrow_array::{Float32Array, Int32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion::prelude::SessionContext;
#[allow(unused_imports)]
use lance_arrow::FixedSizeListArrayExt;
/// Test resolver that looks up tables in a HashMap and applies search queries.
#[derive(Debug)]
struct HashMapTableResolver {
    tables: std::collections::HashMap<String, Arc<dyn TableProvider>>,
}
impl HashMapTableResolver {
    fn new() -> Self {
        Self {
            tables: std::collections::HashMap::new(),
        }
    }
    // Register a provider under `name` for later resolution.
    fn register(&mut self, name: String, table: Arc<dyn TableProvider>) {
        self.tables.insert(name, table);
    }
}
impl TableResolver for HashMapTableResolver {
    // Resolve `name`; when a search is requested, downcast to
    // BaseTableAdapter and wrap it with the matching search variant.
    fn resolve_table(
        &self,
        name: &str,
        search: Option<SearchQuery>,
    ) -> DataFusionResult<Arc<dyn TableProvider>> {
        let table_provider = self
            .tables
            .get(name)
            .cloned()
            .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
        // No search requested: return the provider untouched.
        let Some(search) = search else {
            return Ok(table_provider);
        };
        let base_adapter = table_provider
            .as_any()
            .downcast_ref::<BaseTableAdapter>()
            .ok_or_else(|| {
                DataFusionError::Internal(
                    "Expected BaseTableAdapter but got different type".to_string(),
                )
            })?;
        match search {
            SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
            SearchQuery::Vector(vector_query) => {
                Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
            }
            SearchQuery::Hybrid { fts, vector } => {
                Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
            }
        }
    }
}
#[tokio::test]
async fn test_hybrid_search_udtf() {
    // End-to-end: build a table with text + vectors, register the
    // hybrid_search UDTF, run a SQL query, and sanity-check the output.
    let dim = 4i32;
    let schema = Arc::new(ArrowSchema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("text", DataType::Utf8, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
            true,
        ),
    ]));
    let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let texts = StringArray::from(vec![
        "the quick brown fox",
        "jumps over the lazy dog",
        "a quick red fox runs",
        "the dog sleeps all day",
        "a brown fox and a quick dog",
    ]);
    // Row-major 5x4 vector data; rows 1, 3 and 5 lie close to [1,0,0,0].
    let flat_values = Float32Array::from(vec![
        1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.5,
        0.5, 0.0, 0.0,
    ]);
    let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(ids), Arc::new(texts), Arc::new(vector_array)],
    )
    .unwrap();
    let db = crate::connect("memory://test_hybrid")
        .execute()
        .await
        .unwrap();
    let table = db.create_table("docs", batch).execute().await.unwrap();
    // Create FTS index on text column
    table
        .create_index(&["text"], Index::FTS(Default::default()))
        .execute()
        .await
        .unwrap();
    let ctx = SessionContext::new();
    let mut resolver = HashMapTableResolver::new();
    let adapter = BaseTableAdapter::try_new(table.base_table().clone())
        .await
        .unwrap();
    resolver.register("docs".to_string(), Arc::new(adapter));
    let resolver = Arc::new(resolver);
    ctx.register_udtf(
        "hybrid_search",
        Arc::new(HybridSearchTableFunction::new(resolver.clone())),
    );
    // Run hybrid search: vector close to [1,0,0,0] AND FTS for "fox"
    use lance_index::scalar::inverted::query::*;
    let fts_query_struct = FtsQuery::Match(
        MatchQuery::new("fox".to_string()).with_column(Some("text".to_string())),
    );
    let fts_json = to_json(&fts_query_struct).unwrap();
    let query = format!(
        "SELECT * FROM hybrid_search('docs', '[1.0, 0.0, 0.0, 0.0]', '{}', 5)",
        fts_json
    );
    let df = ctx.sql(&query).await.unwrap();
    let results = df.collect().await.unwrap();
    assert!(!results.is_empty());
    let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
    assert!(total_rows > 0, "Should have at least one result");
    // Check schema has the expected columns
    // NOTE(review): neither result ordering nor the FTS constraint is
    // asserted here — consider verifying that only "fox" rows come back.
    let result_schema = results[0].schema();
    assert!(result_schema.column_with_name("id").is_some());
    assert!(result_schema.column_with_name("text").is_some());
    assert!(result_schema.column_with_name("vector").is_some());
}
}

View File

@@ -0,0 +1,278 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
//! Vector search table function for DataFusion SQL integration.
//!
//! Enables vector similarity search via SQL:
//! ```sql
//! SELECT * FROM vector_search('my_table', '[0.1, 0.2, 0.3, ...]', 10)
//! ```
use std::sync::Arc;
use arrow_array::{Array, Float32Array};
use datafusion::catalog::TableFunctionImpl;
use datafusion_catalog::TableProvider;
use datafusion_common::{DataFusionError, Result as DataFusionResult, plan_err};
use datafusion_expr::Expr;
use super::{SearchQuery, TableResolver, extract_int_literal, extract_string_literal};
use crate::table::datafusion::VectorSearchParams;
/// Default number of results for vector search when top_k is not specified.
const DEFAULT_TOP_K: usize = 10;
/// Vector search table function for LanceDB tables.
///
/// Accepts 2-3 parameters: `vector_search(table_name, query_vector_json [, top_k])`
///
/// - `table_name`: Name of the table to search
/// - `query_vector_json`: JSON array of float values, e.g. `'[0.1, 0.2, 0.3]'`
/// - `top_k` (optional): Number of results to return (default: 10)
#[derive(Debug)]
pub struct VectorSearchTableFunction {
    // Maps table names to providers and applies the vector search to them.
    resolver: Arc<dyn TableResolver>,
}
impl VectorSearchTableFunction {
    /// Create a vector-search table function backed by the given resolver.
    pub fn new(resolver: Arc<dyn TableResolver>) -> Self {
        Self { resolver }
    }
}
impl TableFunctionImpl for VectorSearchTableFunction {
    /// Validate the SQL arguments, parse the query vector, and resolve the
    /// target table with a vector search attached.
    fn call(&self, exprs: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
        if !(2..=3).contains(&exprs.len()) {
            return plan_err!(
                "vector_search() requires 2-3 parameters: vector_search(table_name, query_vector_json [, top_k])"
            );
        }
        let table_name = extract_string_literal(&exprs[0], "table_name")?;
        let vector_json = extract_string_literal(&exprs[1], "query_vector_json")?;
        // Optional third argument overrides the default result count.
        let top_k = match exprs.get(2) {
            Some(expr) => extract_int_literal(expr, "top_k")?,
            None => DEFAULT_TOP_K,
        };
        let params = VectorSearchParams {
            query_vector: parse_vector_json(&vector_json)? as Arc<dyn Array>,
            column: None,
            top_k,
            distance_type: None,
            nprobes: None,
            ef: None,
            refine_factor: None,
        };
        self.resolver
            .resolve_table(&table_name, Some(SearchQuery::Vector(params)))
    }
}
/// Parse a JSON array of floats into an Arrow Float32Array for vector search.
///
/// Input format: `"[0.1, 0.2, 0.3, ...]"`
///
/// Returns a Float32Array whose length equals the vector dimension.
/// This is the format expected by LanceDB's vector search internals.
///
/// # Errors
///
/// Returns a plan error when the JSON is malformed, the array is empty,
/// or any element is non-finite (NaN/infinity) — distance computation
/// against a non-finite component is degenerate.
pub(crate) fn parse_vector_json(json: &str) -> DataFusionResult<Arc<Float32Array>> {
    let values: Vec<f32> = serde_json::from_str(json).map_err(|e| {
        DataFusionError::Plan(format!(
            "Invalid vector JSON: {}. Expected format: [0.1, 0.2, 0.3, ...]",
            e
        ))
    })?;
    if values.is_empty() {
        return Err(DataFusionError::Plan(
            "Vector must not be empty".to_string(),
        ));
    }
    // Guard against NaN/infinity sneaking in (e.g. via out-of-range
    // literals that parse to infinity) rather than running a search with a
    // degenerate query vector.
    if let Some(bad) = values.iter().find(|v| !v.is_finite()) {
        return Err(DataFusionError::Plan(format!(
            "Vector must contain only finite values, got: {}",
            bad
        )));
    }
    Ok(Arc::new(Float32Array::from(values)))
}
#[cfg(test)]
mod tests {
use super::super::{SearchQuery, TableResolver};
use super::*;
use crate::table::datafusion::BaseTableAdapter;
use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use datafusion::prelude::SessionContext;
#[allow(unused_imports)]
use lance_arrow::FixedSizeListArrayExt;
/// Resolver that looks up tables in a HashMap and applies search queries
#[derive(Debug)]
struct HashMapTableResolver {
    tables: std::collections::HashMap<String, Arc<dyn TableProvider>>,
}
impl HashMapTableResolver {
    fn new() -> Self {
        Self {
            tables: std::collections::HashMap::new(),
        }
    }
    // Register a provider under `name` for later resolution.
    fn register(&mut self, name: String, table: Arc<dyn TableProvider>) {
        self.tables.insert(name, table);
    }
}
impl TableResolver for HashMapTableResolver {
    // Resolve `name`; when a search is requested, downcast to
    // BaseTableAdapter and wrap it with the matching search variant.
    fn resolve_table(
        &self,
        name: &str,
        search: Option<SearchQuery>,
    ) -> DataFusionResult<Arc<dyn TableProvider>> {
        let table_provider = self
            .tables
            .get(name)
            .cloned()
            .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))?;
        // No search requested: return the provider untouched.
        let Some(search) = search else {
            return Ok(table_provider);
        };
        let base_adapter = table_provider
            .as_any()
            .downcast_ref::<BaseTableAdapter>()
            .ok_or_else(|| {
                DataFusionError::Internal(
                    "Expected BaseTableAdapter but got different type".to_string(),
                )
            })?;
        match search {
            SearchQuery::Fts(fts_query) => Ok(Arc::new(base_adapter.with_fts_query(fts_query))),
            SearchQuery::Vector(vector_query) => {
                Ok(Arc::new(base_adapter.with_vector_query(vector_query)))
            }
            SearchQuery::Hybrid { fts, vector } => {
                Ok(Arc::new(base_adapter.with_hybrid_query(fts, vector)))
            }
        }
    }
}
/// Build a 5-row test batch with an `id: Int32` column and a 4-dimensional
/// `vector: FixedSizeList<Float32>` column.
fn make_test_data() -> (Arc<ArrowSchema>, RecordBatch) {
    let dim = 4;
    let schema = Arc::new(ArrowSchema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new(
            "vector",
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim),
            true,
        ),
    ]));
    let ids = Int32Array::from(vec![1, 2, 3, 4, 5]);
    // Create vectors: [1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1], [1,1,0,0]
    let flat_values = Float32Array::from(vec![
        1.0, 0.0, 0.0, 0.0, // vec 1
        0.0, 1.0, 0.0, 0.0, // vec 2
        0.0, 0.0, 1.0, 0.0, // vec 3
        0.0, 0.0, 0.0, 1.0, // vec 4
        1.0, 1.0, 0.0, 0.0, // vec 5
    ]);
    let vector_array = FixedSizeListArray::try_new_from_values(flat_values, dim).unwrap();
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vector_array)])
            .unwrap();
    (schema, batch)
}
#[tokio::test]
async fn test_vector_search_udtf() {
    // End-to-end: register the vector_search UDTF and verify the SQL query
    // returns at most `top_k` rows with a `_distance` column appended.
    let (_schema, batch) = make_test_data();
    let db = crate::connect("memory://test_vec").execute().await.unwrap();
    let table = db.create_table("vectors", batch).execute().await.unwrap();
    // No index needed — vector search works with brute-force scan on small tables
    // Setup DataFusion context
    let ctx = SessionContext::new();
    let mut resolver = HashMapTableResolver::new();
    let adapter = BaseTableAdapter::try_new(table.base_table().clone())
        .await
        .unwrap();
    resolver.register("vectors".to_string(), Arc::new(adapter));
    let udtf = VectorSearchTableFunction::new(Arc::new(resolver));
    ctx.register_udtf("vector_search", Arc::new(udtf));
    // Search for vectors close to [1, 0, 0, 0]
    let query = "SELECT * FROM vector_search('vectors', '[1.0, 0.0, 0.0, 0.0]', 3)";
    let df = ctx.sql(query).await.unwrap();
    let results = df.collect().await.unwrap();
    assert!(!results.is_empty());
    let batch = &results[0];
    // Should have id, vector, _distance columns
    assert!(batch.schema().column_with_name("id").is_some());
    assert!(batch.schema().column_with_name("vector").is_some());
    assert!(
        batch.schema().column_with_name("_distance").is_some(),
        "_distance column should be present"
    );
    // Should return at most 3 results
    let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
    assert!(total_rows <= 3);
    assert!(total_rows > 0);
}
#[tokio::test]
async fn test_vector_search_default_top_k() {
    // When top_k is omitted, the UDTF defaults to 10; with only 5 rows in
    // the table, every row should come back.
    let (_, batch) = make_test_data();
    let db = crate::connect("memory://test_vec_default")
        .execute()
        .await
        .unwrap();
    let table = db.create_table("vectors", batch).execute().await.unwrap();
    let ctx = SessionContext::new();
    let mut resolver = HashMapTableResolver::new();
    let adapter = BaseTableAdapter::try_new(table.base_table().clone())
        .await
        .unwrap();
    resolver.register("vectors".to_string(), Arc::new(adapter));
    let udtf = VectorSearchTableFunction::new(Arc::new(resolver));
    ctx.register_udtf("vector_search", Arc::new(udtf));
    // No top_k parameter — should default to 10
    let query = "SELECT * FROM vector_search('vectors', '[1.0, 0.0, 0.0, 0.0]')";
    let df = ctx.sql(query).await.unwrap();
    let results = df.collect().await.unwrap();
    let total_rows: usize = results.iter().map(|b| b.num_rows()).sum();
    // We only have 5 rows, so we should get all 5 back
    assert_eq!(total_rows, 5);
}
#[test]
fn test_parse_vector_json() {
    // A well-formed 3-element array parses into a 3-dimensional vector.
    let parsed = parse_vector_json("[1.0, 2.0, 3.0]").unwrap();
    assert_eq!(parsed.len(), 3);
    // Degenerate inputs must be rejected.
    for bad in ["[]", "not json"] {
        assert!(parse_vector_json(bad).is_err());
    }
}
}