feat(rust): remote client query and create_index endpoints (#1663)

Support for `query` and `create_index`.

Closes [#2519](https://github.com/lancedb/sophon/issues/2519)
This commit is contained in:
Will Jones
2024-09-27 09:00:22 -07:00
committed by GitHub
parent ee6c18f207
commit 1778219ea9
7 changed files with 532 additions and 122 deletions

View File

@@ -38,6 +38,7 @@ arrow-arith = "52.2"
arrow-cast = "52.2"
async-trait = "0"
chrono = "0.4.35"
datafusion-common = "40.0"
datafusion-physical-plan = "40.0"
half = { "version" = "=2.4.1", default-features = false, features = [
"num-traits",

View File

@@ -19,6 +19,7 @@ arrow-ord = { workspace = true }
arrow-cast = { workspace = true }
arrow-ipc.workspace = true
chrono = { workspace = true }
datafusion-common.workspace = true
datafusion-physical-plan.workspace = true
object_store = { workspace = true }
snafu = { workspace = true }

View File

@@ -254,6 +254,12 @@ pub enum DistanceType {
Hamming,
}
impl Default for DistanceType {
fn default() -> Self {
Self::L2
}
}
impl From<DistanceType> for LanceDistanceType {
fn from(value: DistanceType) -> Self {
match value {

View File

@@ -23,3 +23,4 @@ pub mod table;
pub mod util;
const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
const JSON_CONTENT_TYPE: &str = "application/json";

View File

@@ -1,17 +1,25 @@
use std::sync::{Arc, Mutex};
use crate::table::dataset::DatasetReadGuard;
use crate::index::Index;
use crate::query::Select;
use crate::table::AddDataMode;
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::Error;
use arrow_array::RecordBatchReader;
use arrow_schema::SchemaRef;
use arrow_ipc::reader::StreamReader;
use arrow_schema::{DataType, SchemaRef};
use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use bytes::Buf;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
use futures::TryStreamExt;
use http::header::CONTENT_TYPE;
use http::StatusCode;
use lance::arrow::json::JsonSchema;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance::dataset::{ColumnAlteration, NewColumnTransform};
use lance_datafusion::exec::OneShotExec;
use serde::{Deserialize, Serialize};
use crate::{
@@ -26,7 +34,7 @@ use crate::{
};
use super::client::{HttpSend, RestfulLanceDbClient, Sender};
use super::ARROW_STREAM_CONTENT_TYPE;
use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE};
#[derive(Debug)]
pub struct RemoteTable<S: HttpSend = Sender> {
@@ -85,6 +93,93 @@ impl<S: HttpSend> RemoteTable<S> {
self.client.check_response(response).await
}
async fn read_arrow_stream(
&self,
body: reqwest::Response,
) -> Result<SendableRecordBatchStream> {
// Assert that the content type is correct
let content_type = body
.headers()
.get(CONTENT_TYPE)
.ok_or_else(|| Error::Http {
message: "Missing content type".into(),
})?
.to_str()
.map_err(|e| Error::Http {
message: format!("Failed to parse content type: {}", e),
})?;
if content_type != ARROW_STREAM_CONTENT_TYPE {
return Err(Error::Http {
message: format!(
"Expected content type {}, got {}",
ARROW_STREAM_CONTENT_TYPE, content_type
),
});
}
// There isn't a way to actually stream this data yet. I have an upstream issue:
// https://github.com/apache/arrow-rs/issues/6420
let body = body.bytes().await?;
let reader = StreamReader::try_new(body.reader(), None)?;
let schema = reader.schema();
let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}
fn apply_query_params(body: &mut serde_json::Value, params: &Query) -> Result<()> {
if params.offset.is_some() {
return Err(Error::NotSupported {
message: "Offset is not yet supported in LanceDB Cloud".into(),
});
}
if let Some(limit) = params.limit {
body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
}
if let Some(filter) = &params.filter {
body["filter"] = serde_json::Value::String(filter.clone());
}
match &params.select {
Select::All => {}
Select::Columns(columns) => {
body["columns"] = serde_json::Value::Array(
columns
.iter()
.map(|s| serde_json::Value::String(s.clone()))
.collect(),
);
}
Select::Dynamic(pairs) => {
body["columns"] = serde_json::Value::Array(
pairs
.iter()
.map(|(name, expr)| serde_json::json!([name, expr]))
.collect(),
);
}
}
if params.fast_search {
body["fast_search"] = serde_json::Value::Bool(true);
}
if let Some(full_text_search) = &params.full_text_search {
if full_text_search.wand_factor.is_some() {
return Err(Error::NotSupported {
message: "Wand factor is not yet supported in LanceDB Cloud".into(),
});
}
body["full_text_query"] = serde_json::json!({
"columns": full_text_search.columns,
"query": full_text_search.query,
})
}
Ok(())
}
}
#[derive(Deserialize)]
@@ -196,38 +291,78 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
Ok(())
}
async fn build_plan(
&self,
_ds_ref: &DatasetReadGuard,
_query: &VectorQuery,
_options: Option<QueryExecutionOptions>,
) -> Result<Scanner> {
Err(Error::NotSupported {
message: "build_plan is not supported on LanceDB cloud.".into(),
})
}
async fn create_plan(
&self,
_query: &VectorQuery,
query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
Err(Error::NotSupported {
message: "create_plan is not supported on LanceDB cloud.".into(),
})
}
async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result<String> {
Err(Error::NotSupported {
message: "explain_plan is not supported on LanceDB cloud.".into(),
})
let request = self.client.post(&format!("/table/{}/query/", self.name));
let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, &query.base)?;
body["prefilter"] = query.prefilter.into();
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
body["nprobes"] = query.nprobes.into();
body["refine_factor"] = query.refine_factor.into();
if let Some(vector) = query.query_vector.as_ref() {
let vector: Vec<f32> = match vector.data_type() {
DataType::Float32 => vector
.as_any()
.downcast_ref::<arrow_array::Float32Array>()
.unwrap()
.values()
.iter()
.cloned()
.collect(),
_ => {
return Err(Error::InvalidInput {
message: "VectorQuery vector must be of type Float32".into(),
})
}
};
body["vector"] = serde_json::json!(vector);
}
if let Some(vector_column) = query.column.as_ref() {
body["vector_column"] = serde_json::Value::String(vector_column.clone());
}
if !query.use_index {
body["bypass_vector_index"] = serde_json::Value::Bool(true);
}
let request = request.json(&body);
let response = self.client.send(request).await?;
let stream = self.read_arrow_stream(response).await?;
Ok(Arc::new(OneShotExec::new(stream)))
}
async fn plain_query(
&self,
_query: &Query,
query: &Query,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
Err(Error::NotSupported {
message: "plain_query is not yet supported on LanceDB cloud.".into(),
})
let request = self
.client
.post(&format!("/table/{}/query/", self.name))
.header(CONTENT_TYPE, JSON_CONTENT_TYPE);
let mut body = serde_json::Value::Object(Default::default());
Self::apply_query_params(&mut body, query)?;
let request = request.json(&body);
let response = self.client.send(request).await?;
let stream = self.read_arrow_stream(response).await?;
Ok(DatasetRecordBatchStream::new(stream))
}
async fn update(&self, update: UpdateBuilder) -> Result<u64> {
let request = self.client.post(&format!("/table/{}/update/", self.name));
@@ -266,11 +401,79 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
self.check_table_response(response).await?;
Ok(())
}
async fn create_index(&self, _index: IndexBuilder) -> Result<()> {
Err(Error::NotSupported {
message: "create_index is not yet supported on LanceDB cloud.".into(),
})
async fn create_index(&self, mut index: IndexBuilder) -> Result<()> {
let request = self
.client
.post(&format!("/table/{}/create_index/", self.name));
let column = match index.columns.len() {
0 => {
return Err(Error::InvalidInput {
message: "No columns specified".into(),
})
}
1 => index.columns.pop().unwrap(),
_ => {
return Err(Error::NotSupported {
message: "Indices over multiple columns not yet supported".into(),
})
}
};
let mut body = serde_json::json!({
"column": column
});
let (index_type, distance_type) = match index.index {
// TODO: Should we pass the actual index parameters? SaaS does not
// yet support them.
Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
Index::BTree(_) => ("BTREE", None),
Index::Bitmap(_) => ("BITMAP", None),
Index::LabelList(_) => ("LABEL_LIST", None),
Index::FTS(_) => ("FTS", None),
Index::Auto => {
let schema = self.schema().await?;
let field = schema
.field_with_name(&column)
.map_err(|_| Error::InvalidInput {
message: format!("Column {} not found in schema", column),
})?;
if supported_vector_data_type(field.data_type()) {
("IVF_PQ", None)
} else if supported_btree_data_type(field.data_type()) {
("BTREE", None)
} else {
return Err(Error::NotSupported {
message: format!(
"there are no indices supported for the field `{}` with the data type {}",
field.name(),
field.data_type()
),
});
}
}
_ => {
return Err(Error::NotSupported {
message: "Index type not supported".into(),
})
}
};
body["index_type"] = serde_json::Value::String(index_type.into());
if let Some(distance_type) = distance_type {
body["distance_type"] = serde_json::Value::String(distance_type.to_string());
}
let request = request.json(&body);
let response = self.client.send(request).await?;
self.check_table_response(response).await?;
Ok(())
}
async fn merge_insert(
&self,
params: MergeInsertBuilder,
@@ -375,9 +578,14 @@ mod tests {
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
use lance_index::scalar::FullTextSearchQuery;
use reqwest::Body;
use crate::{Error, Table};
use crate::{
index::{vector::IvfPqIndexBuilder, Index},
query::{ExecutableQuery, QueryBase},
DistanceType, Error, Table,
};
#[tokio::test]
async fn test_not_found() {
@@ -468,6 +676,10 @@ mod tests {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/count_rows/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#);
http::Response::builder().status(200).body("42").unwrap()
@@ -479,6 +691,10 @@ mod tests {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/count_rows/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
assert_eq!(
request.body().unwrap().as_bytes().unwrap(),
br#"{"filter":"a > 10"}"#
@@ -613,6 +829,10 @@ mod tests {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/update/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
if let Some(body) = request.body().unwrap().as_bytes() {
let body = std::str::from_utf8(body).unwrap();
@@ -720,6 +940,10 @@ mod tests {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/delete/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
@@ -731,4 +955,201 @@ mod tests {
table.delete("id in (1, 2, 3)").await.unwrap();
}
#[tokio::test]
async fn test_query_vector_default_values() {
let expected_data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let expected_data_ref = expected_data.clone();
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/query/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let mut expected_body = serde_json::json!({
"prefilter": true,
"distance_type": "l2",
"nprobes": 20,
"refine_factor": null,
});
// Pass vector separately to make sure it matches f32 precision.
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
assert_eq!(body, expected_body);
let response_body = write_ipc_stream(&expected_data_ref);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
let data = table
.query()
.nearest_to(vec![0.1, 0.2, 0.3])
.unwrap()
.execute()
.await;
let data = data.unwrap().collect::<Vec<_>>().await;
assert_eq!(data.len(), 1);
assert_eq!(data[0].as_ref().unwrap(), &expected_data);
}
#[tokio::test]
async fn test_query_vector_all_params() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/query/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let mut expected_body = serde_json::json!({
"vector_column": "my_vector",
"prefilter": false,
"k": 42,
"distance_type": "cosine",
"bypass_vector_index": true,
"columns": ["a", "b"],
"nprobes": 12,
"refine_factor": 2,
});
// Pass vector separately to make sure it matches f32 precision.
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
assert_eq!(body, expected_body);
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let response_body = write_ipc_stream(&data);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
let _ = table
.query()
.limit(42)
.select(Select::columns(&["a", "b"]))
.nearest_to(vec![0.1, 0.2, 0.3])
.unwrap()
.column("my_vector")
.postfilter()
.distance_type(crate::DistanceType::Cosine)
.nprobes(12)
.refine_factor(2)
.bypass_vector_index()
.execute()
.await
.unwrap();
}
#[tokio::test]
async fn test_query_fts() {
let table = Table::new_with_handler("my_table", |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/query/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let expected_body = serde_json::json!({
"full_text_query": {
"columns": ["a", "b"],
"query": "hello world",
},
"k": 10,
});
assert_eq!(body, expected_body);
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let response_body = write_ipc_stream(&data);
http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
.body(response_body)
.unwrap()
});
let _ = table
.query()
.full_text_search(
FullTextSearchQuery::new("hello world".into())
.columns(Some(vec!["a".into(), "b".into()])),
)
.limit(10)
.execute()
.await
.unwrap();
}
#[tokio::test]
async fn test_create_index() {
let cases = [
("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
(
"IVF_PQ",
Some("cosine"),
Index::IvfPq(IvfPqIndexBuilder::default().distance_type(DistanceType::Cosine)),
),
(
"IVF_HNSW_SQ",
Some("l2"),
Index::IvfHnswSq(Default::default()),
),
// HNSW_PQ isn't yet supported on SaaS
("BTREE", None, Index::BTree(Default::default())),
("BITMAP", None, Index::Bitmap(Default::default())),
("LABEL_LIST", None, Index::LabelList(Default::default())),
("FTS", None, Index::FTS(Default::default())),
];
for (index_type, distance_type, index) in cases {
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/table/my_table/create_index/");
assert_eq!(
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let mut expected_body = serde_json::json!({
"column": "a",
"index_type": index_type,
});
if let Some(distance_type) = distance_type {
expected_body["distance_type"] = distance_type.into();
}
assert_eq!(body, expected_body);
http::Response::builder().status(200).body("{}").unwrap()
});
table.create_index(&["a"], index).execute().await.unwrap();
}
}
}

View File

@@ -21,8 +21,9 @@ use std::sync::Arc;
use arrow::array::AsArray;
use arrow::datatypes::Float32Type;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use arrow_schema::{Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_physical_plan::display::DisplayableExecutionPlan;
use datafusion_physical_plan::ExecutionPlan;
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats;
@@ -66,9 +67,13 @@ use crate::index::{
use crate::query::{
IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
};
use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam};
use crate::utils::{
default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type,
PatchReadParam, PatchWriteParam,
};
use self::dataset::{DatasetConsistencyWrapper, DatasetReadGuard};
use self::dataset::DatasetConsistencyWrapper;
use self::merge::MergeInsertBuilder;
pub(crate) mod dataset;
@@ -375,12 +380,6 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
async fn schema(&self) -> Result<SchemaRef>;
/// Count the number of rows in this table.
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
async fn build_plan(
&self,
ds_ref: &DatasetReadGuard,
query: &VectorQuery,
options: Option<QueryExecutionOptions>,
) -> Result<Scanner>;
async fn create_plan(
&self,
query: &VectorQuery,
@@ -391,7 +390,12 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
query: &Query,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String>;
async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
let plan = self.create_plan(query, Default::default()).await?;
let display = DisplayableExecutionPlan::new(plan.as_ref());
Ok(format!("{}", display.indent(verbose)))
}
async fn add(
&self,
add: AddDataBuilder<NoData>,
@@ -1088,46 +1092,6 @@ impl NativeTable {
Ok(name.to_string())
}
fn supported_btree_data_type(dtype: &DataType) -> bool {
dtype.is_integer()
|| dtype.is_floating()
|| matches!(
dtype,
DataType::Boolean
| DataType::Utf8
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Date32
| DataType::Date64
| DataType::Timestamp(_, _)
)
}
fn supported_bitmap_data_type(dtype: &DataType) -> bool {
dtype.is_integer() || matches!(dtype, DataType::Utf8)
}
fn supported_label_list_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::List(field) => Self::supported_bitmap_data_type(field.data_type()),
DataType::FixedSizeList(field, _) => {
Self::supported_bitmap_data_type(field.data_type())
}
_ => false,
}
}
fn supported_fts_data_type(dtype: &DataType) -> bool {
matches!(dtype, DataType::Utf8 | DataType::LargeUtf8)
}
fn supported_vector_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
_ => false,
}
}
/// Creates a new Table
///
/// # Arguments
@@ -1386,7 +1350,7 @@ impl NativeTable {
field: &Field,
replace: bool,
) -> Result<()> {
if !Self::supported_vector_data_type(field.data_type()) {
if !supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF PQ index cannot be created on the column `{}` which has data type {}",
@@ -1439,7 +1403,7 @@ impl NativeTable {
field: &Field,
replace: bool,
) -> Result<()> {
if !Self::supported_vector_data_type(field.data_type()) {
if !supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF HNSW PQ index cannot be created on the column `{}` which has data type {}",
@@ -1510,7 +1474,7 @@ impl NativeTable {
field: &Field,
replace: bool,
) -> Result<()> {
if !Self::supported_vector_data_type(field.data_type()) {
if !supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}",
@@ -1563,10 +1527,10 @@ impl NativeTable {
}
async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if Self::supported_vector_data_type(field.data_type()) {
if supported_vector_data_type(field.data_type()) {
self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace)
.await
} else if Self::supported_btree_data_type(field.data_type()) {
} else if supported_btree_data_type(field.data_type()) {
self.create_btree_index(field, opts).await
} else {
Err(Error::InvalidInput {
@@ -1580,7 +1544,7 @@ impl NativeTable {
}
async fn create_btree_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if !Self::supported_btree_data_type(field.data_type()) {
if !supported_btree_data_type(field.data_type()) {
return Err(Error::Schema {
message: format!(
"A BTree index cannot be created on the field `{}` which has data type {}",
@@ -1607,7 +1571,7 @@ impl NativeTable {
}
async fn create_bitmap_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if !Self::supported_bitmap_data_type(field.data_type()) {
if !supported_bitmap_data_type(field.data_type()) {
return Err(Error::Schema {
message: format!(
"A Bitmap index cannot be created on the field `{}` which has data type {}",
@@ -1634,7 +1598,7 @@ impl NativeTable {
}
async fn create_label_list_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
if !Self::supported_label_list_data_type(field.data_type()) {
if !supported_label_list_data_type(field.data_type()) {
return Err(Error::Schema {
message: format!(
"A LabelList index cannot be created on the field `{}` which has data type {}",
@@ -1666,7 +1630,7 @@ impl NativeTable {
fts_opts: FtsIndexBuilder,
replace: bool,
) -> Result<()> {
if !Self::supported_fts_data_type(field.data_type()) {
if !supported_fts_data_type(field.data_type()) {
return Err(Error::Schema {
message: format!(
"A FTS index cannot be created on the field `{}` which has data type {}",
@@ -1887,12 +1851,13 @@ impl TableInternal for NativeTable {
Ok(res.rows_updated)
}
async fn build_plan(
async fn create_plan(
&self,
ds_ref: &DatasetReadGuard,
query: &VectorQuery,
options: Option<QueryExecutionOptions>,
) -> Result<Scanner> {
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
@@ -1966,25 +1931,12 @@ impl TableInternal for NativeTable {
scanner.with_row_id();
}
if let Some(opts) = options {
scanner.batch_size(opts.max_batch_length as usize);
}
scanner.batch_size(options.max_batch_length as usize);
if query.base.fast_search {
scanner.fast_search();
}
Ok(scanner)
}
async fn create_plan(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let ds_ref = self.dataset.get().await?;
let mut scanner = self.build_plan(&ds_ref, query, Some(options)).await?;
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
@@ -2023,16 +1975,6 @@ impl TableInternal for NativeTable {
.await
}
async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
let ds_ref = self.dataset.get().await?;
let scanner = self.build_plan(&ds_ref, query, None).await?;
let plan = scanner.explain_plan(verbose).await?;
Ok(plan)
}
async fn merge_insert(
&self,
params: MergeInsertBuilder,

View File

@@ -14,7 +14,7 @@
use std::sync::Arc;
use arrow_schema::Schema;
use arrow_schema::{DataType, Schema};
use lance::dataset::{ReadParams, WriteParams};
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lazy_static::lazy_static;
@@ -137,6 +137,44 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
}
}
pub fn supported_btree_data_type(dtype: &DataType) -> bool {
dtype.is_integer()
|| dtype.is_floating()
|| matches!(
dtype,
DataType::Boolean
| DataType::Utf8
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Date32
| DataType::Date64
| DataType::Timestamp(_, _)
)
}
pub fn supported_bitmap_data_type(dtype: &DataType) -> bool {
dtype.is_integer() || matches!(dtype, DataType::Utf8)
}
pub fn supported_label_list_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::List(field) => supported_bitmap_data_type(field.data_type()),
DataType::FixedSizeList(field, _) => supported_bitmap_data_type(field.data_type()),
_ => false,
}
}
pub fn supported_fts_data_type(dtype: &DataType) -> bool {
matches!(dtype, DataType::Utf8 | DataType::LargeUtf8)
}
pub fn supported_vector_data_type(dtype: &DataType) -> bool {
match dtype {
DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
_ => false,
}
}
#[cfg(test)]
mod tests {
use super::*;