feat(rust): remote client query and create_index endpoints (#1663)
Support for `query` and `create_index`. Closes [#2519](https://github.com/lancedb/sophon/issues/2519)
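The client-side surface is unchanged; both endpoints are reached through the existing `Table` methods. A minimal usage sketch (the `example` wrapper, column name, and vector values are illustrative only; an already-opened remote table handle is assumed), mirroring the calls exercised by the new tests:

```rust
use lancedb::index::Index;
use lancedb::query::{ExecutableQuery, QueryBase};
use lancedb::{DistanceType, Result, Table};

// Sketch only: `table` is assumed to be an open LanceDB Cloud (remote) table.
async fn example(table: &Table) -> Result<()> {
    // Vector search: serialized as JSON and POSTed to /table/{name}/query/;
    // the response is an Arrow IPC stream decoded into record batches.
    let _results = table
        .query()
        .limit(10)
        .nearest_to(vec![0.1, 0.2, 0.3])?
        .distance_type(DistanceType::Cosine)
        .execute()
        .await?;

    // Index creation: POSTs {"column", "index_type", "distance_type"} to
    // /table/{name}/create_index/. Index parameters beyond the distance
    // type are not forwarded yet (see the TODO in the diff).
    table
        .create_index(&["my_vector"], Index::IvfPq(Default::default()))
        .execute()
        .await?;

    Ok(())
}
```

Note that the query response is fully buffered before decoding; true streaming is blocked on https://github.com/apache/arrow-rs/issues/6420.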
@@ -38,6 +38,7 @@ arrow-arith = "52.2"
 arrow-cast = "52.2"
 async-trait = "0"
 chrono = "0.4.35"
+datafusion-common = "40.0"
 datafusion-physical-plan = "40.0"
 half = { "version" = "=2.4.1", default-features = false, features = [
     "num-traits",

@@ -19,6 +19,7 @@ arrow-ord = { workspace = true }
 arrow-cast = { workspace = true }
 arrow-ipc.workspace = true
 chrono = { workspace = true }
+datafusion-common.workspace = true
 datafusion-physical-plan.workspace = true
 object_store = { workspace = true }
 snafu = { workspace = true }

@@ -254,6 +254,12 @@ pub enum DistanceType {
     Hamming,
 }
 
+impl Default for DistanceType {
+    fn default() -> Self {
+        Self::L2
+    }
+}
+
 impl From<DistanceType> for LanceDistanceType {
     fn from(value: DistanceType) -> Self {
         match value {

@@ -23,3 +23,4 @@ pub mod table;
 pub mod util;
 
 const ARROW_STREAM_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
+const JSON_CONTENT_TYPE: &str = "application/json";

@@ -1,17 +1,25 @@
 use std::sync::{Arc, Mutex};
 
-use crate::table::dataset::DatasetReadGuard;
+use crate::index::Index;
+use crate::query::Select;
 use crate::table::AddDataMode;
+use crate::utils::{supported_btree_data_type, supported_vector_data_type};
 use crate::Error;
 use arrow_array::RecordBatchReader;
-use arrow_schema::SchemaRef;
+use arrow_ipc::reader::StreamReader;
+use arrow_schema::{DataType, SchemaRef};
 use async_trait::async_trait;
-use datafusion_physical_plan::ExecutionPlan;
+use bytes::Buf;
+use datafusion_common::DataFusionError;
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion_physical_plan::{ExecutionPlan, SendableRecordBatchStream};
 use futures::TryStreamExt;
+use http::header::CONTENT_TYPE;
 use http::StatusCode;
 use lance::arrow::json::JsonSchema;
-use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
+use lance::dataset::scanner::DatasetRecordBatchStream;
 use lance::dataset::{ColumnAlteration, NewColumnTransform};
+use lance_datafusion::exec::OneShotExec;
 use serde::{Deserialize, Serialize};
 
 use crate::{

@@ -26,7 +34,7 @@ use crate::{
 };
 
 use super::client::{HttpSend, RestfulLanceDbClient, Sender};
-use super::ARROW_STREAM_CONTENT_TYPE;
+use super::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE};
 
 #[derive(Debug)]
 pub struct RemoteTable<S: HttpSend = Sender> {

@@ -85,6 +93,93 @@ impl<S: HttpSend> RemoteTable<S> {
 
         self.client.check_response(response).await
     }
+
+    async fn read_arrow_stream(
+        &self,
+        body: reqwest::Response,
+    ) -> Result<SendableRecordBatchStream> {
+        // Assert that the content type is correct
+        let content_type = body
+            .headers()
+            .get(CONTENT_TYPE)
+            .ok_or_else(|| Error::Http {
+                message: "Missing content type".into(),
+            })?
+            .to_str()
+            .map_err(|e| Error::Http {
+                message: format!("Failed to parse content type: {}", e),
+            })?;
+        if content_type != ARROW_STREAM_CONTENT_TYPE {
+            return Err(Error::Http {
+                message: format!(
+                    "Expected content type {}, got {}",
+                    ARROW_STREAM_CONTENT_TYPE, content_type
+                ),
+            });
+        }
+
+        // There isn't a way to actually stream this data yet. I have an upstream issue:
+        // https://github.com/apache/arrow-rs/issues/6420
+        let body = body.bytes().await?;
+        let reader = StreamReader::try_new(body.reader(), None)?;
+        let schema = reader.schema();
+        let stream = futures::stream::iter(reader).map_err(DataFusionError::from);
+        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+    }
+
+    fn apply_query_params(body: &mut serde_json::Value, params: &Query) -> Result<()> {
+        if params.offset.is_some() {
+            return Err(Error::NotSupported {
+                message: "Offset is not yet supported in LanceDB Cloud".into(),
+            });
+        }
+
+        if let Some(limit) = params.limit {
+            body["k"] = serde_json::Value::Number(serde_json::Number::from(limit));
+        }
+
+        if let Some(filter) = &params.filter {
+            body["filter"] = serde_json::Value::String(filter.clone());
+        }
+
+        match &params.select {
+            Select::All => {}
+            Select::Columns(columns) => {
+                body["columns"] = serde_json::Value::Array(
+                    columns
+                        .iter()
+                        .map(|s| serde_json::Value::String(s.clone()))
+                        .collect(),
+                );
+            }
+            Select::Dynamic(pairs) => {
+                body["columns"] = serde_json::Value::Array(
+                    pairs
+                        .iter()
+                        .map(|(name, expr)| serde_json::json!([name, expr]))
+                        .collect(),
+                );
+            }
+        }
+
+        if params.fast_search {
+            body["fast_search"] = serde_json::Value::Bool(true);
+        }
+
+        if let Some(full_text_search) = &params.full_text_search {
+            if full_text_search.wand_factor.is_some() {
+                return Err(Error::NotSupported {
+                    message: "Wand factor is not yet supported in LanceDB Cloud".into(),
+                });
+            }
+            body["full_text_query"] = serde_json::json!({
+                "columns": full_text_search.columns,
+                "query": full_text_search.query,
+            })
+        }
+
+        Ok(())
+    }
 }
 
 #[derive(Deserialize)]

@@ -196,38 +291,78 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
 
         Ok(())
     }
-    async fn build_plan(
-        &self,
-        _ds_ref: &DatasetReadGuard,
-        _query: &VectorQuery,
-        _options: Option<QueryExecutionOptions>,
-    ) -> Result<Scanner> {
-        Err(Error::NotSupported {
-            message: "build_plan is not supported on LanceDB cloud.".into(),
-        })
-    }
+
     async fn create_plan(
         &self,
-        _query: &VectorQuery,
+        query: &VectorQuery,
         _options: QueryExecutionOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        Err(Error::NotSupported {
-            message: "create_plan is not supported on LanceDB cloud.".into(),
-        })
-    }
-    async fn explain_plan(&self, _query: &VectorQuery, _verbose: bool) -> Result<String> {
-        Err(Error::NotSupported {
-            message: "explain_plan is not supported on LanceDB cloud.".into(),
-        })
+        let request = self.client.post(&format!("/table/{}/query/", self.name));
+
+        let mut body = serde_json::Value::Object(Default::default());
+        Self::apply_query_params(&mut body, &query.base)?;
+
+        body["prefilter"] = query.prefilter.into();
+        body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
+        body["nprobes"] = query.nprobes.into();
+        body["refine_factor"] = query.refine_factor.into();
+
+        if let Some(vector) = query.query_vector.as_ref() {
+            let vector: Vec<f32> = match vector.data_type() {
+                DataType::Float32 => vector
+                    .as_any()
+                    .downcast_ref::<arrow_array::Float32Array>()
+                    .unwrap()
+                    .values()
+                    .iter()
+                    .cloned()
+                    .collect(),
+                _ => {
+                    return Err(Error::InvalidInput {
+                        message: "VectorQuery vector must be of type Float32".into(),
+                    })
+                }
+            };
+            body["vector"] = serde_json::json!(vector);
+        }
+
+        if let Some(vector_column) = query.column.as_ref() {
+            body["vector_column"] = serde_json::Value::String(vector_column.clone());
+        }
+
+        if !query.use_index {
+            body["bypass_vector_index"] = serde_json::Value::Bool(true);
+        }
+
+        let request = request.json(&body);
+
+        let response = self.client.send(request).await?;
+
+        let stream = self.read_arrow_stream(response).await?;
+
+        Ok(Arc::new(OneShotExec::new(stream)))
     }
+
     async fn plain_query(
         &self,
-        _query: &Query,
+        query: &Query,
         _options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream> {
-        Err(Error::NotSupported {
-            message: "plain_query is not yet supported on LanceDB cloud.".into(),
-        })
+        let request = self
+            .client
+            .post(&format!("/table/{}/query/", self.name))
+            .header(CONTENT_TYPE, JSON_CONTENT_TYPE);
+
+        let mut body = serde_json::Value::Object(Default::default());
+        Self::apply_query_params(&mut body, query)?;
+
+        let request = request.json(&body);
+
+        let response = self.client.send(request).await?;
+
+        let stream = self.read_arrow_stream(response).await?;
+
+        Ok(DatasetRecordBatchStream::new(stream))
     }
     async fn update(&self, update: UpdateBuilder) -> Result<u64> {
         let request = self.client.post(&format!("/table/{}/update/", self.name));

@@ -266,11 +401,79 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
         self.check_table_response(response).await?;
         Ok(())
     }
-    async fn create_index(&self, _index: IndexBuilder) -> Result<()> {
-        Err(Error::NotSupported {
-            message: "create_index is not yet supported on LanceDB cloud.".into(),
-        })
+
+    async fn create_index(&self, mut index: IndexBuilder) -> Result<()> {
+        let request = self
+            .client
+            .post(&format!("/table/{}/create_index/", self.name));
+
+        let column = match index.columns.len() {
+            0 => {
+                return Err(Error::InvalidInput {
+                    message: "No columns specified".into(),
+                })
+            }
+            1 => index.columns.pop().unwrap(),
+            _ => {
+                return Err(Error::NotSupported {
+                    message: "Indices over multiple columns not yet supported".into(),
+                })
+            }
+        };
+        let mut body = serde_json::json!({
+            "column": column
+        });
+
+        let (index_type, distance_type) = match index.index {
+            // TODO: Should we pass the actual index parameters? SaaS does not
+            // yet support them.
+            Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
+            Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
+            Index::BTree(_) => ("BTREE", None),
+            Index::Bitmap(_) => ("BITMAP", None),
+            Index::LabelList(_) => ("LABEL_LIST", None),
+            Index::FTS(_) => ("FTS", None),
+            Index::Auto => {
+                let schema = self.schema().await?;
+                let field = schema
+                    .field_with_name(&column)
+                    .map_err(|_| Error::InvalidInput {
+                        message: format!("Column {} not found in schema", column),
+                    })?;
+                if supported_vector_data_type(field.data_type()) {
+                    ("IVF_PQ", None)
+                } else if supported_btree_data_type(field.data_type()) {
+                    ("BTREE", None)
+                } else {
+                    return Err(Error::NotSupported {
+                        message: format!(
+                            "there are no indices supported for the field `{}` with the data type {}",
+                            field.name(),
+                            field.data_type()
+                        ),
+                    });
+                }
+            }
+            _ => {
+                return Err(Error::NotSupported {
+                    message: "Index type not supported".into(),
+                })
+            }
+        };
+        body["index_type"] = serde_json::Value::String(index_type.into());
+        if let Some(distance_type) = distance_type {
+            body["distance_type"] = serde_json::Value::String(distance_type.to_string());
+        }
+
+        let request = request.json(&body);
+
+        let response = self.client.send(request).await?;
+
+        self.check_table_response(response).await?;
+
+        Ok(())
     }
+
     async fn merge_insert(
         &self,
         params: MergeInsertBuilder,

@@ -375,9 +578,14 @@ mod tests {
     use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
     use futures::{future::BoxFuture, StreamExt, TryFutureExt};
+    use lance_index::scalar::FullTextSearchQuery;
     use reqwest::Body;
 
-    use crate::{Error, Table};
+    use crate::{
+        index::{vector::IvfPqIndexBuilder, Index},
+        query::{ExecutableQuery, QueryBase},
+        DistanceType, Error, Table,
+    };
 
     #[tokio::test]
     async fn test_not_found() {

@@ -468,6 +676,10 @@ mod tests {
         let table = Table::new_with_handler("my_table", |request| {
             assert_eq!(request.method(), "POST");
             assert_eq!(request.url().path(), "/table/my_table/count_rows/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
             assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#);
 
             http::Response::builder().status(200).body("42").unwrap()

@@ -479,6 +691,10 @@ mod tests {
         let table = Table::new_with_handler("my_table", |request| {
             assert_eq!(request.method(), "POST");
             assert_eq!(request.url().path(), "/table/my_table/count_rows/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
             assert_eq!(
                 request.body().unwrap().as_bytes().unwrap(),
                 br#"{"filter":"a > 10"}"#

@@ -613,6 +829,10 @@ mod tests {
         let table = Table::new_with_handler("my_table", |request| {
             assert_eq!(request.method(), "POST");
             assert_eq!(request.url().path(), "/table/my_table/update/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
 
             if let Some(body) = request.body().unwrap().as_bytes() {
                 let body = std::str::from_utf8(body).unwrap();

@@ -720,6 +940,10 @@ mod tests {
         let table = Table::new_with_handler("my_table", |request| {
             assert_eq!(request.method(), "POST");
             assert_eq!(request.url().path(), "/table/my_table/delete/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
 
             let body = request.body().unwrap().as_bytes().unwrap();
             let body: serde_json::Value = serde_json::from_slice(body).unwrap();

@@ -731,4 +955,201 @@ mod tests {
 
         table.delete("id in (1, 2, 3)").await.unwrap();
     }
+
+    #[tokio::test]
+    async fn test_query_vector_default_values() {
+        let expected_data = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .unwrap();
+        let expected_data_ref = expected_data.clone();
+
+        let table = Table::new_with_handler("my_table", move |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/query/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
+            let mut expected_body = serde_json::json!({
+                "prefilter": true,
+                "distance_type": "l2",
+                "nprobes": 20,
+                "refine_factor": null,
+            });
+            // Pass vector separately to make sure it matches f32 precision.
+            expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
+            assert_eq!(body, expected_body);
+
+            let response_body = write_ipc_stream(&expected_data_ref);
+            http::Response::builder()
+                .status(200)
+                .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+                .body(response_body)
+                .unwrap()
+        });
+
+        let data = table
+            .query()
+            .nearest_to(vec![0.1, 0.2, 0.3])
+            .unwrap()
+            .execute()
+            .await;
+        let data = data.unwrap().collect::<Vec<_>>().await;
+        assert_eq!(data.len(), 1);
+        assert_eq!(data[0].as_ref().unwrap(), &expected_data);
+    }
+
+    #[tokio::test]
+    async fn test_query_vector_all_params() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/query/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
+            let mut expected_body = serde_json::json!({
+                "vector_column": "my_vector",
+                "prefilter": false,
+                "k": 42,
+                "distance_type": "cosine",
+                "bypass_vector_index": true,
+                "columns": ["a", "b"],
+                "nprobes": 12,
+                "refine_factor": 2,
+            });
+            // Pass vector separately to make sure it matches f32 precision.
+            expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
+            assert_eq!(body, expected_body);
+
+            let data = RecordBatch::try_new(
+                Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+                vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+            )
+            .unwrap();
+            let response_body = write_ipc_stream(&data);
+            http::Response::builder()
+                .status(200)
+                .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+                .body(response_body)
+                .unwrap()
+        });
+
+        let _ = table
+            .query()
+            .limit(42)
+            .select(Select::columns(&["a", "b"]))
+            .nearest_to(vec![0.1, 0.2, 0.3])
+            .unwrap()
+            .column("my_vector")
+            .postfilter()
+            .distance_type(crate::DistanceType::Cosine)
+            .nprobes(12)
+            .refine_factor(2)
+            .bypass_vector_index()
+            .execute()
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_query_fts() {
+        let table = Table::new_with_handler("my_table", |request| {
+            assert_eq!(request.method(), "POST");
+            assert_eq!(request.url().path(), "/table/my_table/query/");
+            assert_eq!(
+                request.headers().get("Content-Type").unwrap(),
+                JSON_CONTENT_TYPE
+            );
+
+            let body = request.body().unwrap().as_bytes().unwrap();
+            let body: serde_json::Value = serde_json::from_slice(body).unwrap();
+            let expected_body = serde_json::json!({
+                "full_text_query": {
+                    "columns": ["a", "b"],
+                    "query": "hello world",
+                },
+                "k": 10,
+            });
+            assert_eq!(body, expected_body);
+
+            let data = RecordBatch::try_new(
+                Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
+                vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+            )
+            .unwrap();
+            let response_body = write_ipc_stream(&data);
+            http::Response::builder()
+                .status(200)
+                .header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE)
+                .body(response_body)
+                .unwrap()
+        });
+
+        let _ = table
+            .query()
+            .full_text_search(
+                FullTextSearchQuery::new("hello world".into())
+                    .columns(Some(vec!["a".into(), "b".into()])),
+            )
+            .limit(10)
+            .execute()
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_create_index() {
+        let cases = [
+            ("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
+            (
+                "IVF_PQ",
+                Some("cosine"),
+                Index::IvfPq(IvfPqIndexBuilder::default().distance_type(DistanceType::Cosine)),
+            ),
+            (
+                "IVF_HNSW_SQ",
+                Some("l2"),
+                Index::IvfHnswSq(Default::default()),
+            ),
+            // HNSW_PQ isn't yet supported on SaaS
+            ("BTREE", None, Index::BTree(Default::default())),
+            ("BITMAP", None, Index::Bitmap(Default::default())),
+            ("LABEL_LIST", None, Index::LabelList(Default::default())),
+            ("FTS", None, Index::FTS(Default::default())),
+        ];
+
+        for (index_type, distance_type, index) in cases {
+            let table = Table::new_with_handler("my_table", move |request| {
+                assert_eq!(request.method(), "POST");
+                assert_eq!(request.url().path(), "/table/my_table/create_index/");
+                assert_eq!(
+                    request.headers().get("Content-Type").unwrap(),
+                    JSON_CONTENT_TYPE
+                );
+                let body = request.body().unwrap().as_bytes().unwrap();
+                let body: serde_json::Value = serde_json::from_slice(body).unwrap();
+                let mut expected_body = serde_json::json!({
+                    "column": "a",
+                    "index_type": index_type,
+                });
+                if let Some(distance_type) = distance_type {
+                    expected_body["distance_type"] = distance_type.into();
+                }
+                assert_eq!(body, expected_body);
+
+                http::Response::builder().status(200).body("{}").unwrap()
+            });
+
+            table.create_index(&["a"], index).execute().await.unwrap();
+        }
+    }
 }

@@ -21,8 +21,9 @@ use std::sync::Arc;
 use arrow::array::AsArray;
 use arrow::datatypes::Float32Type;
 use arrow_array::{RecordBatchIterator, RecordBatchReader};
-use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use arrow_schema::{Field, Schema, SchemaRef};
 use async_trait::async_trait;
+use datafusion_physical_plan::display::DisplayableExecutionPlan;
 use datafusion_physical_plan::ExecutionPlan;
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::cleanup::RemovalStats;

@@ -66,9 +67,13 @@ use crate::index::{
 use crate::query::{
     IntoQueryVector, Query, QueryExecutionOptions, Select, VectorQuery, DEFAULT_TOP_K,
 };
-use crate::utils::{default_vector_column, PatchReadParam, PatchWriteParam};
+use crate::utils::{
+    default_vector_column, supported_bitmap_data_type, supported_btree_data_type,
+    supported_fts_data_type, supported_label_list_data_type, supported_vector_data_type,
+    PatchReadParam, PatchWriteParam,
+};
 
-use self::dataset::{DatasetConsistencyWrapper, DatasetReadGuard};
+use self::dataset::DatasetConsistencyWrapper;
 use self::merge::MergeInsertBuilder;
 
 pub(crate) mod dataset;

@@ -375,12 +380,6 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
     async fn schema(&self) -> Result<SchemaRef>;
     /// Count the number of rows in this table.
     async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
-    async fn build_plan(
-        &self,
-        ds_ref: &DatasetReadGuard,
-        query: &VectorQuery,
-        options: Option<QueryExecutionOptions>,
-    ) -> Result<Scanner>;
     async fn create_plan(
         &self,
         query: &VectorQuery,

@@ -391,7 +390,12 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
         query: &Query,
         options: QueryExecutionOptions,
     ) -> Result<DatasetRecordBatchStream>;
-    async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String>;
+    async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
+        let plan = self.create_plan(query, Default::default()).await?;
+        let display = DisplayableExecutionPlan::new(plan.as_ref());
+
+        Ok(format!("{}", display.indent(verbose)))
+    }
     async fn add(
         &self,
         add: AddDataBuilder<NoData>,

@@ -1088,46 +1092,6 @@ impl NativeTable {
         Ok(name.to_string())
     }
 
-    fn supported_btree_data_type(dtype: &DataType) -> bool {
-        dtype.is_integer()
-            || dtype.is_floating()
-            || matches!(
-                dtype,
-                DataType::Boolean
-                    | DataType::Utf8
-                    | DataType::Time32(_)
-                    | DataType::Time64(_)
-                    | DataType::Date32
-                    | DataType::Date64
-                    | DataType::Timestamp(_, _)
-            )
-    }
-
-    fn supported_bitmap_data_type(dtype: &DataType) -> bool {
-        dtype.is_integer() || matches!(dtype, DataType::Utf8)
-    }
-
-    fn supported_label_list_data_type(dtype: &DataType) -> bool {
-        match dtype {
-            DataType::List(field) => Self::supported_bitmap_data_type(field.data_type()),
-            DataType::FixedSizeList(field, _) => {
-                Self::supported_bitmap_data_type(field.data_type())
-            }
-            _ => false,
-        }
-    }
-
-    fn supported_fts_data_type(dtype: &DataType) -> bool {
-        matches!(dtype, DataType::Utf8 | DataType::LargeUtf8)
-    }
-
-    fn supported_vector_data_type(dtype: &DataType) -> bool {
-        match dtype {
-            DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
-            _ => false,
-        }
-    }
-
     /// Creates a new Table
     ///
     /// # Arguments

@@ -1386,7 +1350,7 @@ impl NativeTable {
         field: &Field,
         replace: bool,
     ) -> Result<()> {
-        if !Self::supported_vector_data_type(field.data_type()) {
+        if !supported_vector_data_type(field.data_type()) {
             return Err(Error::InvalidInput {
                 message: format!(
                     "An IVF PQ index cannot be created on the column `{}` which has data type {}",

@@ -1439,7 +1403,7 @@ impl NativeTable {
         field: &Field,
         replace: bool,
     ) -> Result<()> {
-        if !Self::supported_vector_data_type(field.data_type()) {
+        if !supported_vector_data_type(field.data_type()) {
             return Err(Error::InvalidInput {
                 message: format!(
                     "An IVF HNSW PQ index cannot be created on the column `{}` which has data type {}",

@@ -1510,7 +1474,7 @@ impl NativeTable {
         field: &Field,
         replace: bool,
     ) -> Result<()> {
-        if !Self::supported_vector_data_type(field.data_type()) {
+        if !supported_vector_data_type(field.data_type()) {
             return Err(Error::InvalidInput {
                 message: format!(
                     "An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}",

@@ -1563,10 +1527,10 @@ impl NativeTable {
     }
 
     async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
-        if Self::supported_vector_data_type(field.data_type()) {
+        if supported_vector_data_type(field.data_type()) {
             self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace)
                 .await
-        } else if Self::supported_btree_data_type(field.data_type()) {
+        } else if supported_btree_data_type(field.data_type()) {
             self.create_btree_index(field, opts).await
         } else {
             Err(Error::InvalidInput {

@@ -1580,7 +1544,7 @@ impl NativeTable {
     }
 
     async fn create_btree_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
-        if !Self::supported_btree_data_type(field.data_type()) {
+        if !supported_btree_data_type(field.data_type()) {
             return Err(Error::Schema {
                 message: format!(
                     "A BTree index cannot be created on the field `{}` which has data type {}",

@@ -1607,7 +1571,7 @@ impl NativeTable {
     }
 
     async fn create_bitmap_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
-        if !Self::supported_bitmap_data_type(field.data_type()) {
+        if !supported_bitmap_data_type(field.data_type()) {
             return Err(Error::Schema {
                 message: format!(
                     "A Bitmap index cannot be created on the field `{}` which has data type {}",

@@ -1634,7 +1598,7 @@ impl NativeTable {
     }
 
     async fn create_label_list_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
-        if !Self::supported_label_list_data_type(field.data_type()) {
+        if !supported_label_list_data_type(field.data_type()) {
             return Err(Error::Schema {
                 message: format!(
                     "A LabelList index cannot be created on the field `{}` which has data type {}",

@@ -1666,7 +1630,7 @@ impl NativeTable {
         fts_opts: FtsIndexBuilder,
         replace: bool,
     ) -> Result<()> {
-        if !Self::supported_fts_data_type(field.data_type()) {
+        if !supported_fts_data_type(field.data_type()) {
             return Err(Error::Schema {
                 message: format!(
                     "A FTS index cannot be created on the field `{}` which has data type {}",

@@ -1887,12 +1851,13 @@ impl TableInternal for NativeTable {
         Ok(res.rows_updated)
     }
 
-    async fn build_plan(
+    async fn create_plan(
         &self,
-        ds_ref: &DatasetReadGuard,
         query: &VectorQuery,
-        options: Option<QueryExecutionOptions>,
-    ) -> Result<Scanner> {
+        options: QueryExecutionOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let ds_ref = self.dataset.get().await?;
+
         let mut scanner: Scanner = ds_ref.scan();
 
         if let Some(query_vector) = query.query_vector.as_ref() {

@@ -1966,25 +1931,12 @@ impl TableInternal for NativeTable {
             scanner.with_row_id();
         }
 
-        if let Some(opts) = options {
-            scanner.batch_size(opts.max_batch_length as usize);
-        }
+        scanner.batch_size(options.max_batch_length as usize);
+
         if query.base.fast_search {
             scanner.fast_search();
         }
 
-        Ok(scanner)
-    }
-
-    async fn create_plan(
-        &self,
-        query: &VectorQuery,
-        options: QueryExecutionOptions,
-    ) -> Result<Arc<dyn ExecutionPlan>> {
-        let ds_ref = self.dataset.get().await?;
-
-        let mut scanner = self.build_plan(&ds_ref, query, Some(options)).await?;
-
         match &query.base.select {
             Select::Columns(select) => {
                 scanner.project(select.as_slice())?;

@@ -2023,16 +1975,6 @@ impl TableInternal for NativeTable {
             .await
     }
 
-    async fn explain_plan(&self, query: &VectorQuery, verbose: bool) -> Result<String> {
-        let ds_ref = self.dataset.get().await?;
-
-        let scanner = self.build_plan(&ds_ref, query, None).await?;
-
-        let plan = scanner.explain_plan(verbose).await?;
-
-        Ok(plan)
-    }
-
     async fn merge_insert(
         &self,
         params: MergeInsertBuilder,

@@ -14,7 +14,7 @@
 
 use std::sync::Arc;
 
-use arrow_schema::Schema;
+use arrow_schema::{DataType, Schema};
 use lance::dataset::{ReadParams, WriteParams};
 use lance::io::{ObjectStoreParams, WrappingObjectStore};
 use lazy_static::lazy_static;

@@ -137,6 +137,44 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
     }
 }
 
+pub fn supported_btree_data_type(dtype: &DataType) -> bool {
+    dtype.is_integer()
+        || dtype.is_floating()
+        || matches!(
+            dtype,
+            DataType::Boolean
+                | DataType::Utf8
+                | DataType::Time32(_)
+                | DataType::Time64(_)
+                | DataType::Date32
+                | DataType::Date64
+                | DataType::Timestamp(_, _)
+        )
+}
+
+pub fn supported_bitmap_data_type(dtype: &DataType) -> bool {
+    dtype.is_integer() || matches!(dtype, DataType::Utf8)
+}
+
+pub fn supported_label_list_data_type(dtype: &DataType) -> bool {
+    match dtype {
+        DataType::List(field) => supported_bitmap_data_type(field.data_type()),
+        DataType::FixedSizeList(field, _) => supported_bitmap_data_type(field.data_type()),
+        _ => false,
+    }
+}
+
+pub fn supported_fts_data_type(dtype: &DataType) -> bool {
+    matches!(dtype, DataType::Utf8 | DataType::LargeUtf8)
+}
+
+pub fn supported_vector_data_type(dtype: &DataType) -> bool {
+    match dtype {
+        DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
+        _ => false,
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;