From 1778219ea97edfa78d5ea06aaceb3d10ac682070 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 27 Sep 2024 09:00:22 -0700 Subject: [PATCH] feat(rust): remote client `query` and `create_index` endpoints (#1663) Support for `query` and `create_index`. Closes [#2519](https://github.com/lancedb/sophon/issues/2519) --- Cargo.toml | 1 + rust/lancedb/Cargo.toml | 1 + rust/lancedb/src/lib.rs | 6 + rust/lancedb/src/remote.rs | 1 + rust/lancedb/src/remote/table.rs | 487 ++++++++++++++++++++++++++++--- rust/lancedb/src/table.rs | 118 ++------ rust/lancedb/src/utils.rs | 40 ++- 7 files changed, 532 insertions(+), 122 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3c563e7a..95a00f40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ arrow-arith = "52.2" arrow-cast = "52.2" async-trait = "0" chrono = "0.4.35" +datafusion-common = "40.0" datafusion-physical-plan = "40.0" half = { "version" = "=2.4.1", default-features = false, features = [ "num-traits", diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 08413dfc..2eee4227 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -19,6 +19,7 @@ arrow-ord = { workspace = true } arrow-cast = { workspace = true } arrow-ipc.workspace = true chrono = { workspace = true } +datafusion-common.workspace = true datafusion-physical-plan.workspace = true object_store = { workspace = true } snafu = { workspace = true } diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index c552c30b..52771fa5 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -254,6 +254,12 @@ pub enum DistanceType { Hamming, } +impl Default for DistanceType { + fn default() -> Self { + Self::L2 + } +} + impl From for LanceDistanceType { fn from(value: DistanceType) -> Self { match value { diff --git a/rust/lancedb/src/remote.rs b/rust/lancedb/src/remote.rs index ce00a370..2ef92b55 100644 --- a/rust/lancedb/src/remote.rs +++ b/rust/lancedb/src/remote.rs @@ -23,3 +23,4 @@ pub mod table; pub mod util; const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream"; +const JSON_CONTENT_TYPE: &str = "application/json"; diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 88abfc38..7adea431 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1,17 +1,25 @@ use std::sync::{Arc, Mutex}; -use crate::table::dataset::DatasetReadGuard; +use crate::index::Index; +use crate::query::Select; use crate::table::AddDataMode; +use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::Error; use arrow_array::RecordBatchReader; -use arrow_schema::SchemaRef; +use arrow_ipc::reader::StreamReader; +use arrow_schema::{DataType, SchemaRef}; use async_trait::async_trait; -use datafusion_physical_plan::ExecutionPlan; +use bytes::Buf; +use datafusion_common::DataFusionError; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use futures::TryStreamExt; use http::header::CONTENT_TYPE; use http::StatusCode; use lance::arrow::json::JsonSchema; -use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; +use lance::dataset::scanner::DatasetRecordBatchStream; use lance::dataset::{ColumnAlteration, NewColumnTransform}; +use lance_datafusion::exec::OneShotExec; use serde::{Deserialize, Serialize}; use crate::{ @@ -26,7 +34,7 @@ use crate::{ }; use super::client::{HttpSend, RestfulLanceDbClient, Sender}; -use super::ARROW_STREAM_CONTENT_TYPE; +use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE}; #[derive(Debug)] pub struct RemoteTable { @@ -85,6 +93,93 @@ impl RemoteTable { self.client.check_response(response).await } + + async fn read_arrow_stream( + &self, + body: reqwest::Response, + ) -> Result { + // Assert that the content type is correct + let content_type = body + .headers() + .get(CONTENT_TYPE) + .ok_or_else(|| Error::Http { + message: "Missing content type".into(), + })? + .to_str() + .map_err(|e| Error::Http { + message: format!("Failed to parse content type: {}", e), + })?; + if content_type != ARROW_STREAM_CONTENT_TYPE { + return Err(Error::Http { + message: format!( + "Expected content type {}, got {}", + ARROW_STREAM_CONTENT_TYPE, content_type + ), + }); + } + + // There isn't a way to actually stream this data yet. I have an upstream issue: + // https://github.com/apache/arrow-rs/issues/6420 + let body = body.bytes().await?; + let reader = StreamReader::try_new(body.reader(), None)?; + let schema = reader.schema(); + let stream = futures::stream::iter(reader).map_err(DataFusionError::from); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } + + fn apply_query_params(body: &mut serde_json::Value, params: &Query) -> Result<()> { + if params.offset.is_some() { + return Err(Error::NotSupported { + message: "Offset is not yet supported in LanceDB Cloud".into(), + }); + } + + if let Some(limit) = params.limit { + body["k"] = serde_json::Value::Number(serde_json::Number::from(limit)); + } + + if let Some(filter) = ¶ms.filter { + body["filter"] = serde_json::Value::String(filter.clone()); + } + + match ¶ms.select { + Select::All => {} + Select::Columns(columns) => { + body["columns"] = serde_json::Value::Array( + columns + .iter() + .map(|s| serde_json::Value::String(s.clone())) + .collect(), + ); + } + Select::Dynamic(pairs) => { + body["columns"] = serde_json::Value::Array( + pairs + .iter() + .map(|(name, expr)| serde_json::json!([name, expr])) + .collect(), + ); + } + } + + if params.fast_search { + body["fast_search"] = serde_json::Value::Bool(true); + } + + if let Some(full_text_search) = ¶ms.full_text_search { + if full_text_search.wand_factor.is_some() { + return Err(Error::NotSupported { + message: "Wand factor is not yet supported in LanceDB Cloud".into(), + }); + } + body["full_text_query"] = serde_json::json!({ + "columns": full_text_search.columns, + "query": full_text_search.query, + }) + } + + Ok(()) + } } #[derive(Deserialize)] @@ -196,38 +291,78 @@ impl TableInternal for RemoteTable { Ok(()) } - async fn build_plan( - &self, - _ds_ref: &DatasetReadGuard, - _query: &VectorQuery, - _options: Option, - ) -> Result { - Err(Error::NotSupported { - message: "build_plan is not supported on LanceDB cloud.".into(), - }) - } + async fn create_plan( &self, - _query: &VectorQuery, + query: &VectorQuery, _options: QueryExecutionOptions, ) -> Result> { - Err(Error::NotSupported { - message: "create_plan is not supported on LanceDB cloud.".into(), - }) - } - async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result { - Err(Error::NotSupported { - message: "explain_plan is not supported on LanceDB cloud.".into(), - }) + let request = self.client.post(&format!("/table/{}/query/", self.name)); + + let mut body = serde_json::Value::Object(Default::default()); + Self::apply_query_params(&mut body, &query.base)?; + + body["prefilter"] = query.prefilter.into(); + body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); + body["nprobes"] = query.nprobes.into(); + body["refine_factor"] = query.refine_factor.into(); + + if let Some(vector) = query.query_vector.as_ref() { + let vector: Vec = match vector.data_type() { + DataType::Float32 => vector + .as_any() + .downcast_ref::() + .unwrap() + .values() + .iter() + .cloned() + .collect(), + _ => { + return Err(Error::InvalidInput { + message: "VectorQuery vector must be of type Float32".into(), + }) + } + }; + body["vector"] = serde_json::json!(vector); + } + + if let Some(vector_column) = query.column.as_ref() { + body["vector_column"] = serde_json::Value::String(vector_column.clone()); + } + + if !query.use_index { + body["bypass_vector_index"] = serde_json::Value::Bool(true); + } + + let request = request.json(&body); + + let response = self.client.send(request).await?; + + let stream = self.read_arrow_stream(response).await?; + + Ok(Arc::new(OneShotExec::new(stream))) } + async fn plain_query( &self, - _query: &Query, + query: &Query, _options: QueryExecutionOptions, ) -> Result { - Err(Error::NotSupported { - message: "plain_query is not yet supported on LanceDB cloud.".into(), - }) + let request = self + .client + .post(&format!("/table/{}/query/", self.name)) + .header(CONTENT_TYPE, JSON_CONTENT_TYPE); + + let mut body = serde_json::Value::Object(Default::default()); + Self::apply_query_params(&mut body, query)?; + + let request = request.json(&body); + + let response = self.client.send(request).await?; + + let stream = self.read_arrow_stream(response).await?; + + Ok(DatasetRecordBatchStream::new(stream)) } async fn update(&self, update: UpdateBuilder) -> Result { let request = self.client.post(&format!("/table/{}/update/", self.name)); @@ -266,11 +401,79 @@ impl TableInternal for RemoteTable { self.check_table_response(response).await?; Ok(()) } - async fn create_index(&self, _index: IndexBuilder) -> Result<()> { - Err(Error::NotSupported { - message: "create_index is not yet supported on LanceDB cloud.".into(), - }) + + async fn create_index(&self, mut index: IndexBuilder) -> Result<()> { + let request = self + .client + .post(&format!("/table/{}/create_index/", self.name)); + + let column = match index.columns.len() { + 0 => { + return Err(Error::InvalidInput { + message: "No columns specified".into(), + }) + } + 1 => index.columns.pop().unwrap(), + _ => { + return Err(Error::NotSupported { + message: "Indices over multiple columns not yet supported".into(), + }) + } + }; + let mut body = serde_json::json!({ + "column": column + }); + + let (index_type, distance_type) = match index.index { + // TODO: Should we pass the actual index parameters? SaaS does not + // yet support them. + Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)), + Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)), + Index::BTree(_) => ("BTREE", None), + Index::Bitmap(_) => ("BITMAP", None), + Index::LabelList(_) => ("LABEL_LIST", None), + Index::FTS(_) => ("FTS", None), + Index::Auto => { + let schema = self.schema().await?; + let field = schema + .field_with_name(&column) + .map_err(|_| Error::InvalidInput { + message: format!("Column {} not found in schema", column), + })?; + if supported_vector_data_type(field.data_type()) { + ("IVF_PQ", None) + } else if supported_btree_data_type(field.data_type()) { + ("BTREE", None) + } else { + return Err(Error::NotSupported { + message: format!( + "there are no indices supported for the field `{}` with the data type {}", + field.name(), + field.data_type() + ), + }); + } + } + _ => { + return Err(Error::NotSupported { + message: "Index type not supported".into(), + }) + } + }; + body["index_type"] = serde_json::Value::String(index_type.into()); + if let Some(distance_type) = distance_type { + body["distance_type"] = serde_json::Value::String(distance_type.to_string()); + } + + let request = request.json(&body); + + let response = self.client.send(request).await?; + + self.check_table_response(response).await?; + + Ok(()) } + async fn merge_insert( &self, params: MergeInsertBuilder, @@ -375,9 +578,14 @@ mod tests { use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator}; use arrow_schema::{DataType, Field, Schema}; use futures::{future::BoxFuture, StreamExt, TryFutureExt}; + use lance_index::scalar::FullTextSearchQuery; use reqwest::Body; - use crate::{Error, Table}; + use crate::{ + index::{vector::IvfPqIndexBuilder, Index}, + query::{ExecutableQuery, QueryBase}, + DistanceType, Error, Table, + }; #[tokio::test] async fn test_not_found() { @@ -468,6 +676,10 @@ mod tests { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); assert_eq!(request.url().path(), "/table/my_table/count_rows/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#); http::Response::builder().status(200).body("42").unwrap() @@ -479,6 +691,10 @@ mod tests { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); assert_eq!(request.url().path(), "/table/my_table/count_rows/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); assert_eq!( request.body().unwrap().as_bytes().unwrap(), br#"{"filter":"a > 10"}"# @@ -613,6 +829,10 @@ mod tests { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); assert_eq!(request.url().path(), "/table/my_table/update/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); if let Some(body) = request.body().unwrap().as_bytes() { let body = std::str::from_utf8(body).unwrap(); @@ -720,6 +940,10 @@ mod tests { let table = Table::new_with_handler("my_table", |request| { assert_eq!(request.method(), "POST"); assert_eq!(request.url().path(), "/table/my_table/delete/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); let body = request.body().unwrap().as_bytes().unwrap(); let body: serde_json::Value = serde_json::from_slice(body).unwrap(); @@ -731,4 +955,201 @@ mod tests { table.delete("id in (1, 2, 3)").await.unwrap(); } + + #[tokio::test] + async fn test_query_vector_default_values() { + let expected_data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let expected_data_ref = expected_data.clone(); + + let table = Table::new_with_handler("my_table", move |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/table/my_table/query/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let mut expected_body = serde_json::json!({ + "prefilter": true, + "distance_type": "l2", + "nprobes": 20, + "refine_factor": null, + }); + // Pass vector separately to make sure it matches f32 precision. + expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); + assert_eq!(body, expected_body); + + let response_body = write_ipc_stream(&expected_data_ref); + http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) + .body(response_body) + .unwrap() + }); + + let data = table + .query() + .nearest_to(vec![0.1, 0.2, 0.3]) + .unwrap() + .execute() + .await; + let data = data.unwrap().collect::>().await; + assert_eq!(data.len(), 1); + assert_eq!(data[0].as_ref().unwrap(), &expected_data); + } + + #[tokio::test] + async fn test_query_vector_all_params() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/table/my_table/query/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let mut expected_body = serde_json::json!({ + "vector_column": "my_vector", + "prefilter": false, + "k": 42, + "distance_type": "cosine", + "bypass_vector_index": true, + "columns": ["a", "b"], + "nprobes": 12, + "refine_factor": 2, + }); + // Pass vector separately to make sure it matches f32 precision. + expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into(); + assert_eq!(body, expected_body); + + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let response_body = write_ipc_stream(&data); + http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) + .body(response_body) + .unwrap() + }); + + let _ = table + .query() + .limit(42) + .select(Select::columns(&["a", "b"])) + .nearest_to(vec![0.1, 0.2, 0.3]) + .unwrap() + .column("my_vector") + .postfilter() + .distance_type(crate::DistanceType::Cosine) + .nprobes(12) + .refine_factor(2) + .bypass_vector_index() + .execute() + .await + .unwrap(); + } + + #[tokio::test] + async fn test_query_fts() { + let table = Table::new_with_handler("my_table", |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/table/my_table/query/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let expected_body = serde_json::json!({ + "full_text_query": { + "columns": ["a", "b"], + "query": "hello world", + }, + "k": 10, + }); + assert_eq!(body, expected_body); + + let data = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + let response_body = write_ipc_stream(&data); + http::Response::builder() + .status(200) + .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE) + .body(response_body) + .unwrap() + }); + + let _ = table + .query() + .full_text_search( + FullTextSearchQuery::new("hello world".into()) + .columns(Some(vec!["a".into(), "b".into()])), + ) + .limit(10) + .execute() + .await + .unwrap(); + } + + #[tokio::test] + async fn test_create_index() { + let cases = [ + ("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())), + ( + "IVF_PQ", + Some("cosine"), + Index::IvfPq(IvfPqIndexBuilder::default().distance_type(DistanceType::Cosine)), + ), + ( + "IVF_HNSW_SQ", + Some("l2"), + Index::IvfHnswSq(Default::default()), + ), + // HNSW_PQ isn't yet supported on SaaS + ("BTREE", None, Index::BTree(Default::default())), + ("BITMAP", None, Index::Bitmap(Default::default())), + ("LABEL_LIST", None, Index::LabelList(Default::default())), + ("FTS", None, Index::FTS(Default::default())), + ]; + + for (index_type, distance_type, index) in cases { + let table = Table::new_with_handler("my_table", move |request| { + assert_eq!(request.method(), "POST"); + assert_eq!(request.url().path(), "/table/my_table/create_index/"); + assert_eq!( + request.headers().get("Content-Type").unwrap(), + JSON_CONTENT_TYPE + ); + let body = request.body().unwrap().as_bytes().unwrap(); + let body: serde_json::Value = serde_json::from_slice(body).unwrap(); + let mut expected_body = serde_json::json!({ + "column": "a", + "index_type": index_type, + }); + if let Some(distance_type) = distance_type { + expected_body["distance_type"] = distance_type.into(); + } + assert_eq!(body, expected_body); + + http::Response::builder().status(200).body("{}").unwrap() + }); + + table.create_index(&["a"], index).execute().await.unwrap(); + } + } } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index daff561d..e4ec93d4 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -21,8 +21,9 @@ use std::sync::Arc; use arrow::array::AsArray; use arrow::datatypes::Float32Type; use arrow_array::{RecordBatchIterator, RecordBatchReader}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::{Field, Schema, SchemaRef}; use async_trait::async_trait; +use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_physical_plan::ExecutionPlan; use lance::dataset::builder::DatasetBuilder; use lance::dataset::cleanup::RemovalStats; @@ -66,9 +67,13 @@ use crate::index::{ use crate::query::{ IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K, }; -use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam}; +use crate::utils::{ + default_vector_column, supported_bitmap_data_type, supported_btree_data_type, + supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type, + PatchReadParam, PatchWriteParam, +}; -use self::dataset::{DatasetConsistencyWrapper, DatasetReadGuard}; +use self::dataset::DatasetConsistencyWrapper; use self::merge::MergeInsertBuilder; pub(crate) mod dataset; @@ -375,12 +380,6 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn async fn schema(&self) -> Result; /// Count the number of rows in this table. async fn count_rows(&self, filter: Option) -> Result; - async fn build_plan( - &self, - ds_ref: &DatasetReadGuard, - query: &VectorQuery, - options: Option, - ) -> Result; async fn create_plan( &self, query: &VectorQuery, @@ -391,7 +390,12 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn query: &Query, options: QueryExecutionOptions, ) -> Result; - async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result; + async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result { + let plan = self.create_plan(query, Default::default()).await?; + let display = DisplayableExecutionPlan::new(plan.as_ref()); + + Ok(format!("{}", display.indent(verbose))) + } async fn add( &self, add: AddDataBuilder, @@ -1088,46 +1092,6 @@ impl NativeTable { Ok(name.to_string()) } - fn supported_btree_data_type(dtype: &DataType) -> bool { - dtype.is_integer() - || dtype.is_floating() - || matches!( - dtype, - DataType::Boolean - | DataType::Utf8 - | DataType::Time32(_) - | DataType::Time64(_) - | DataType::Date32 - | DataType::Date64 - | DataType::Timestamp(_, _) - ) - } - - fn supported_bitmap_data_type(dtype: &DataType) -> bool { - dtype.is_integer() || matches!(dtype, DataType::Utf8) - } - - fn supported_label_list_data_type(dtype: &DataType) -> bool { - match dtype { - DataType::List(field) => Self::supported_bitmap_data_type(field.data_type()), - DataType::FixedSizeList(field, _) => { - Self::supported_bitmap_data_type(field.data_type()) - } - _ => false, - } - } - - fn supported_fts_data_type(dtype: &DataType) -> bool { - matches!(dtype, DataType::Utf8 | DataType::LargeUtf8) - } - - fn supported_vector_data_type(dtype: &DataType) -> bool { - match dtype { - DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()), - _ => false, - } - } - /// Creates a new Table /// /// # Arguments @@ -1386,7 +1350,7 @@ impl NativeTable { field: &Field, replace: bool, ) -> Result<()> { - if !Self::supported_vector_data_type(field.data_type()) { + if !supported_vector_data_type(field.data_type()) { return Err(Error::InvalidInput { message: format!( "An IVF PQ index cannot be created on the column `{}` which has data type {}", @@ -1439,7 +1403,7 @@ impl NativeTable { field: &Field, replace: bool, ) -> Result<()> { - if !Self::supported_vector_data_type(field.data_type()) { + if !supported_vector_data_type(field.data_type()) { return Err(Error::InvalidInput { message: format!( "An IVF HNSW PQ index cannot be created on the column `{}` which has data type {}", @@ -1510,7 +1474,7 @@ impl NativeTable { field: &Field, replace: bool, ) -> Result<()> { - if !Self::supported_vector_data_type(field.data_type()) { + if !supported_vector_data_type(field.data_type()) { return Err(Error::InvalidInput { message: format!( "An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}", @@ -1563,10 +1527,10 @@ impl NativeTable { } async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> { - if Self::supported_vector_data_type(field.data_type()) { + if supported_vector_data_type(field.data_type()) { self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace) .await - } else if Self::supported_btree_data_type(field.data_type()) { + } else if supported_btree_data_type(field.data_type()) { self.create_btree_index(field, opts).await } else { Err(Error::InvalidInput { @@ -1580,7 +1544,7 @@ impl NativeTable { } async fn create_btree_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> { - if !Self::supported_btree_data_type(field.data_type()) { + if !supported_btree_data_type(field.data_type()) { return Err(Error::Schema { message: format!( "A BTree index cannot be created on the field `{}` which has data type {}", @@ -1607,7 +1571,7 @@ impl NativeTable { } async fn create_bitmap_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> { - if !Self::supported_bitmap_data_type(field.data_type()) { + if !supported_bitmap_data_type(field.data_type()) { return Err(Error::Schema { message: format!( "A Bitmap index cannot be created on the field `{}` which has data type {}", @@ -1634,7 +1598,7 @@ impl NativeTable { } async fn create_label_list_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> { - if !Self::supported_label_list_data_type(field.data_type()) { + if !supported_label_list_data_type(field.data_type()) { return Err(Error::Schema { message: format!( "A LabelList index cannot be created on the field `{}` which has data type {}", @@ -1666,7 +1630,7 @@ impl NativeTable { fts_opts: FtsIndexBuilder, replace: bool, ) -> Result<()> { - if !Self::supported_fts_data_type(field.data_type()) { + if !supported_fts_data_type(field.data_type()) { return Err(Error::Schema { message: format!( "A FTS index cannot be created on the field `{}` which has data type {}", @@ -1887,12 +1851,13 @@ impl TableInternal for NativeTable { Ok(res.rows_updated) } - async fn build_plan( + async fn create_plan( &self, - ds_ref: &DatasetReadGuard, query: &VectorQuery, - options: Option, - ) -> Result { + options: QueryExecutionOptions, + ) -> Result> { + let ds_ref = self.dataset.get().await?; + let mut scanner: Scanner = ds_ref.scan(); if let Some(query_vector) = query.query_vector.as_ref() { @@ -1966,25 +1931,12 @@ impl TableInternal for NativeTable { scanner.with_row_id(); } - if let Some(opts) = options { - scanner.batch_size(opts.max_batch_length as usize); - } + scanner.batch_size(options.max_batch_length as usize); + if query.base.fast_search { scanner.fast_search(); } - Ok(scanner) - } - - async fn create_plan( - &self, - query: &VectorQuery, - options: QueryExecutionOptions, - ) -> Result> { - let ds_ref = self.dataset.get().await?; - - let mut scanner = self.build_plan(&ds_ref, query, Some(options)).await?; - match &query.base.select { Select::Columns(select) => { scanner.project(select.as_slice())?; @@ -2023,16 +1975,6 @@ impl TableInternal for NativeTable { .await } - async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result { - let ds_ref = self.dataset.get().await?; - - let scanner = self.build_plan(&ds_ref, query, None).await?; - - let plan = scanner.explain_plan(verbose).await?; - - Ok(plan) - } - async fn merge_insert( &self, params: MergeInsertBuilder, diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index 09f3f276..f165a367 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -14,7 +14,7 @@ use std::sync::Arc; -use arrow_schema::Schema; +use arrow_schema::{DataType, Schema}; use lance::dataset::{ReadParams, WriteParams}; use lance::io::{ObjectStoreParams, WrappingObjectStore}; use lazy_static::lazy_static; @@ -137,6 +137,44 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result } } +pub fn supported_btree_data_type(dtype: &DataType) -> bool { + dtype.is_integer() + || dtype.is_floating() + || matches!( + dtype, + DataType::Boolean + | DataType::Utf8 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Date32 + | DataType::Date64 + | DataType::Timestamp(_, _) + ) +} + +pub fn supported_bitmap_data_type(dtype: &DataType) -> bool { + dtype.is_integer() || matches!(dtype, DataType::Utf8) +} + +pub fn supported_label_list_data_type(dtype: &DataType) -> bool { + match dtype { + DataType::List(field) => supported_bitmap_data_type(field.data_type()), + DataType::FixedSizeList(field, _) => supported_bitmap_data_type(field.data_type()), + _ => false, + } +} + +pub fn supported_fts_data_type(dtype: &DataType) -> bool { + matches!(dtype, DataType::Utf8 | DataType::LargeUtf8) +} + +pub fn supported_vector_data_type(dtype: &DataType) -> bool { + match dtype { + DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()), + _ => false, + } +} + #[cfg(test)] mod tests { use super::*;