diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index 1a9af9d8..a3523426 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -18,7 +18,7 @@ from lancedb._lancedb import ( UpdateResult, ) from lancedb.embeddings.base import EmbeddingFunctionConfig -from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList +from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, LabelList from lancedb.remote.db import LOOP import pyarrow as pa @@ -186,6 +186,8 @@ class RemoteTable(Table): accelerator: Optional[str] = None, index_type="vector", wait_timeout: Optional[timedelta] = None, + *, + num_bits: int = 8, ): """Create an index on the table. Currently, the only parameters that matter are @@ -220,11 +222,6 @@ class RemoteTable(Table): >>> table.create_index("l2", "vector") # doctest: +SKIP """ - if num_partitions is not None: - logging.warning( - "num_partitions is not supported on LanceDB cloud." - "This parameter will be tuned automatically." - ) if num_sub_vectors is not None: logging.warning( "num_sub_vectors is not supported on LanceDB cloud." @@ -244,13 +241,21 @@ class RemoteTable(Table): index_type = index_type.upper() if index_type == "VECTOR" or index_type == "IVF_PQ": - config = IvfPq(distance_type=metric) + config = IvfPq( + distance_type=metric, + num_partitions=num_partitions, + num_sub_vectors=num_sub_vectors, + num_bits=num_bits, + ) elif index_type == "IVF_HNSW_PQ": - config = HnswPq(distance_type=metric) + raise ValueError( + "IVF_HNSW_PQ is not supported on LanceDB cloud." + "Please use IVF_HNSW_SQ instead." + ) elif index_type == "IVF_HNSW_SQ": - config = HnswSq(distance_type=metric) + config = HnswSq(distance_type=metric, num_partitions=num_partitions) elif index_type == "IVF_FLAT": - config = IvfFlat(distance_type=metric) + config = IvfFlat(distance_type=metric, num_partitions=num_partitions) else: raise ValueError( f"Unknown vector index type: {index_type}. Valid options are" diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 3a4d4818..e8123b33 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -57,6 +57,8 @@ use crate::{ }; const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms"); +const METRIC_TYPE_KEY: &str = "metric_type"; +const INDEX_TYPE_KEY: &str = "index_type"; pub struct RemoteTags<'a, S: HttpSend = Sender> { inner: &'a RemoteTable, @@ -997,23 +999,53 @@ impl BaseTable for RemoteTable { "column": column }); - let (index_type, distance_type) = match index.index { + match index.index { // TODO: Should we pass the actual index parameters? SaaS does not // yet support them. - Index::IvfFlat(index) => ("IVF_FLAT", Some(index.distance_type)), - Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)), - Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)), - Index::BTree(_) => ("BTREE", None), - Index::Bitmap(_) => ("BITMAP", None), - Index::LabelList(_) => ("LABEL_LIST", None), + Index::IvfFlat(index) => { + body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_FLAT".to_string()); + body[METRIC_TYPE_KEY] = + serde_json::Value::String(index.distance_type.to_string().to_lowercase()); + if let Some(num_partitions) = index.num_partitions { + body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); + } + } + Index::IvfPq(index) => { + body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_PQ".to_string()); + body[METRIC_TYPE_KEY] = + serde_json::Value::String(index.distance_type.to_string().to_lowercase()); + if let Some(num_partitions) = index.num_partitions { + body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); + } + if let Some(num_bits) = index.num_bits { + body["num_bits"] = serde_json::Value::Number(num_bits.into()); + } + } + Index::IvfHnswSq(index) => { + body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_HNSW_SQ".to_string()); + body[METRIC_TYPE_KEY] = + serde_json::Value::String(index.distance_type.to_string().to_lowercase()); + if let Some(num_partitions) = index.num_partitions { + body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); + } + } + Index::BTree(_) => { + body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string()); + } + Index::Bitmap(_) => { + body[INDEX_TYPE_KEY] = serde_json::Value::String("BITMAP".to_string()); + } + Index::LabelList(_) => { + body[INDEX_TYPE_KEY] = serde_json::Value::String("LABEL_LIST".to_string()); + } Index::FTS(fts) => { + body[INDEX_TYPE_KEY] = serde_json::Value::String("FTS".to_string()); let params = serde_json::to_value(&fts).map_err(|e| Error::InvalidInput { message: format!("failed to serialize FTS index params {:?}", e), })?; for (key, value) in params.as_object().unwrap() { body[key] = value.clone(); } - ("FTS", None) } Index::Auto => { let schema = self.schema().await?; @@ -1023,9 +1055,11 @@ impl BaseTable for RemoteTable { message: format!("Column {} not found in schema", column), })?; if supported_vector_data_type(field.data_type()) { - ("IVF_PQ", Some(DistanceType::L2)) + body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_PQ".to_string()); + body[METRIC_TYPE_KEY] = + serde_json::Value::String(DistanceType::L2.to_string().to_lowercase()); } else if supported_btree_data_type(field.data_type()) { - ("BTREE", None) + body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string()); } else { return Err(Error::NotSupported { message: format!( @@ -1042,12 +1076,6 @@ impl BaseTable for RemoteTable { }) } }; - body["index_type"] = serde_json::Value::String(index_type.into()); - if let Some(distance_type) = distance_type { - // Phalanx expects this to be lowercase right now. - body["metric_type"] = - serde_json::Value::String(distance_type.to_string().to_lowercase()); - } let request = request.json(&body); @@ -1429,11 +1457,12 @@ mod tests { use chrono::{DateTime, Utc}; use futures::{future::BoxFuture, StreamExt, TryFutureExt}; use lance_index::scalar::inverted::query::MatchQuery; - use lance_index::scalar::FullTextSearchQuery; + use lance_index::scalar::{FullTextSearchQuery, InvertedIndexParams}; use reqwest::Body; use rstest::rstest; + use serde_json::json; - use crate::index::vector::IvfFlatIndexBuilder; + use crate::index::vector::{IvfFlatIndexBuilder, IvfHnswSqIndexBuilder}; use crate::remote::db::DEFAULT_SERVER_VERSION; use crate::remote::JSON_CONTENT_TYPE; use crate::{ @@ -2433,29 +2462,79 @@ mod tests { let cases = [ ( "IVF_FLAT", - Some("hamming"), + json!({ + "metric_type": "hamming", + }), Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)), ), - ("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())), + ( + "IVF_FLAT", + json!({ + "metric_type": "hamming", + "num_partitions": 128, + }), + Index::IvfFlat( + IvfFlatIndexBuilder::default() + .distance_type(DistanceType::Hamming) + .num_partitions(128), + ), + ), ( "IVF_PQ", - Some("cosine"), - Index::IvfPq(IvfPqIndexBuilder::default().distance_type(DistanceType::Cosine)), + json!({ + "metric_type": "l2", + }), + Index::IvfPq(Default::default()), + ), + ( + "IVF_PQ", + json!({ + "metric_type": "cosine", + "num_partitions": 128, + "num_bits": 4, + }), + Index::IvfPq( + IvfPqIndexBuilder::default() + .distance_type(DistanceType::Cosine) + .num_partitions(128) + .num_bits(4), + ), ), ( "IVF_HNSW_SQ", - Some("l2"), + json!({ + "metric_type": "l2", + }), Index::IvfHnswSq(Default::default()), ), + ( + "IVF_HNSW_SQ", + json!({ + "metric_type": "l2", + "num_partitions": 128, + }), + Index::IvfHnswSq( + IvfHnswSqIndexBuilder::default() + .distance_type(DistanceType::L2) + .num_partitions(128), + ), + ), // HNSW_PQ isn't yet supported on SaaS - ("BTREE", None, Index::BTree(Default::default())), - ("BITMAP", None, Index::Bitmap(Default::default())), - ("LABEL_LIST", None, Index::LabelList(Default::default())), - ("FTS", None, Index::FTS(Default::default())), + ("BTREE", json!({}), Index::BTree(Default::default())), + ("BITMAP", json!({}), Index::Bitmap(Default::default())), + ( + "LABEL_LIST", + json!({}), + Index::LabelList(Default::default()), + ), + ( + "FTS", + serde_json::to_value(InvertedIndexParams::default()).unwrap(), + Index::FTS(Default::default()), + ), ]; - for (index_type, distance_type, index) in cases { - let params = index.clone(); + for (index_type, expected_body, index) in cases { let table = Table::new_with_handler("my_table", move |request| { assert_eq!(request.method(), "POST"); assert_eq!(request.url().path(), "/v1/table/my_table/create_index/"); @@ -2465,19 +2544,9 @@ mod tests { ); let body = request.body().unwrap().as_bytes().unwrap(); let body: serde_json::Value = serde_json::from_slice(body).unwrap(); - let mut expected_body = serde_json::json!({ - "column": "a", - "index_type": index_type, - }); - if let Some(distance_type) = distance_type { - expected_body["metric_type"] = distance_type.to_lowercase().into(); - } - if let Index::FTS(fts) = ¶ms { - let params = serde_json::to_value(fts).unwrap(); - for (key, value) in params.as_object().unwrap() { - expected_body[key] = value.clone(); - } - } + let mut expected_body = expected_body.clone(); + expected_body["column"] = "a".into(); + expected_body[INDEX_TYPE_KEY] = index_type.into(); assert_eq!(body, expected_body);