diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py index e95d3d951..84a067771 100644 --- a/python/python/lancedb/remote/table.py +++ b/python/python/lancedb/remote/table.py @@ -218,8 +218,6 @@ class RemoteTable(Table): train: bool = True, ): """Create an index on the table. - Currently, the only parameters that matter are - the metric and the vector column name. Parameters ---------- @@ -250,11 +248,6 @@ class RemoteTable(Table): >>> table.create_index("l2", "vector") # doctest: +SKIP """ - if num_sub_vectors is not None: - logging.warning( - "num_sub_vectors is not supported on LanceDB cloud." - "This parameter will be tuned automatically." - ) if accelerator is not None: logging.warning( "GPU accelerator is not yet supported on LanceDB cloud." diff --git a/rust/lancedb/src/index/scalar.rs b/rust/lancedb/src/index/scalar.rs index 980b57d25..370dd5abb 100644 --- a/rust/lancedb/src/index/scalar.rs +++ b/rust/lancedb/src/index/scalar.rs @@ -27,7 +27,7 @@ /// /// The btree index does not currently have any parameters though parameters such as the /// block size may be added in the future. -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, serde::Serialize)] pub struct BTreeIndexBuilder {} impl BTreeIndexBuilder {} @@ -39,7 +39,7 @@ impl BTreeIndexBuilder {} /// This index works best for low-cardinality (i.e., less than 1000 unique values) columns, /// where the number of unique values is small. /// The bitmap stores a list of row ids where the value is present. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, serde::Serialize)] pub struct BitmapIndexBuilder {} /// Builder for LabelList index. @@ -48,7 +48,7 @@ pub struct BitmapIndexBuilder {} /// support queries with `array_contains_all` and `array_contains_any` /// using an underlying bitmap index. /// -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, serde::Serialize)] pub struct LabelListIndexBuilder {} pub use lance_index::scalar::inverted::query::*; diff --git a/rust/lancedb/src/index/vector.rs b/rust/lancedb/src/index/vector.rs index 65990cf03..a5507f41c 100644 --- a/rust/lancedb/src/index/vector.rs +++ b/rust/lancedb/src/index/vector.rs @@ -7,6 +7,7 @@ //! Vector indices are only supported on fixed-size-list (tensor) columns of floating point //! values use lance::table::format::{IndexMetadata, Manifest}; +use serde::Serialize; use crate::DistanceType; @@ -181,14 +182,17 @@ macro_rules! impl_hnsw_params_setter { /// The partitioning process is called IVF and the `num_partitions` parameter controls how many groups to create. /// /// Note that training an IVF Flat index on a large dataset is a slow operation and currently is also a memory intensive operation. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct IvfFlatIndexBuilder { + #[serde(rename = "metric_type")] pub(crate) distance_type: DistanceType, // IVF + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_partitions: Option, pub(crate) sample_rate: u32, pub(crate) max_iterations: u32, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) target_partition_size: Option, } @@ -213,14 +217,17 @@ impl IvfFlatIndexBuilder { /// /// This index compresses vectors using scalar quantization and groups them into IVF partitions. /// It offers a balance between search performance and storage footprint. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct IvfSqIndexBuilder { + #[serde(rename = "metric_type")] pub(crate) distance_type: DistanceType, // IVF + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_partitions: Option, pub(crate) sample_rate: u32, pub(crate) max_iterations: u32, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) target_partition_size: Option, } @@ -261,18 +268,23 @@ impl IvfSqIndexBuilder { /// /// Note that training an IVF PQ index on a large dataset is a slow operation and /// currently is also a memory intensive operation. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct IvfPqIndexBuilder { + #[serde(rename = "metric_type")] pub(crate) distance_type: DistanceType, // IVF + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_partitions: Option, pub(crate) sample_rate: u32, pub(crate) max_iterations: u32, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) target_partition_size: Option, // PQ + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_sub_vectors: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_bits: Option, } @@ -323,14 +335,18 @@ pub(crate) fn suggested_num_sub_vectors(dim: u32) -> u32 { /// /// Note that training an IVF RQ index on a large dataset is a slow operation and /// currently is also a memory intensive operation. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct IvfRqIndexBuilder { // IVF + #[serde(rename = "metric_type")] pub(crate) distance_type: DistanceType, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_partitions: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_bits: Option, pub(crate) sample_rate: u32, pub(crate) max_iterations: u32, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) target_partition_size: Option, } @@ -365,13 +381,16 @@ impl IvfRqIndexBuilder { /// quickly find the closest vectors to a query vector. /// /// The PQ (product quantizer) is used to compress the vectors as the same as IVF PQ. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct IvfHnswPqIndexBuilder { // IVF + #[serde(rename = "metric_type")] pub(crate) distance_type: DistanceType, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_partitions: Option, pub(crate) sample_rate: u32, pub(crate) max_iterations: u32, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) target_partition_size: Option, // HNSW @@ -379,7 +398,9 @@ pub struct IvfHnswPqIndexBuilder { pub(crate) ef_construction: u32, // PQ + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_sub_vectors: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_bits: Option, } @@ -415,13 +436,16 @@ impl IvfHnswPqIndexBuilder { /// /// The SQ (scalar quantizer) is used to compress the vectors, /// each vector is mapped to a 8-bit integer vector, 4x compression ratio for float32 vector. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] pub struct IvfHnswSqIndexBuilder { // IVF + #[serde(rename = "metric_type")] pub(crate) distance_type: DistanceType, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) num_partitions: Option, pub(crate) sample_rate: u32, pub(crate) max_iterations: u32, + #[serde(skip_serializing_if = "Option::is_none")] pub(crate) target_partition_size: Option, // HNSW diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index aa3c3fe97..6e290c4df 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -1276,73 +1276,24 @@ impl BaseTable for RemoteTable { ); } - match index.index { - // TODO: Should we pass the actual index parameters? SaaS does not - // yet support them. - Index::IvfFlat(index) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_FLAT".to_string()); - body[METRIC_TYPE_KEY] = - serde_json::Value::String(index.distance_type.to_string().to_lowercase()); - if let Some(num_partitions) = index.num_partitions { - body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); - } - } - Index::IvfPq(index) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_PQ".to_string()); - body[METRIC_TYPE_KEY] = - serde_json::Value::String(index.distance_type.to_string().to_lowercase()); - if let Some(num_partitions) = index.num_partitions { - body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); - } - if let Some(num_bits) = index.num_bits { - body["num_bits"] = serde_json::Value::Number(num_bits.into()); - } - } - Index::IvfSq(index) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_SQ".to_string()); - body[METRIC_TYPE_KEY] = - serde_json::Value::String(index.distance_type.to_string().to_lowercase()); - if let Some(num_partitions) = index.num_partitions { - body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); - } - } - Index::IvfHnswSq(index) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_HNSW_SQ".to_string()); - body[METRIC_TYPE_KEY] = - serde_json::Value::String(index.distance_type.to_string().to_lowercase()); - if let Some(num_partitions) = index.num_partitions { - body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); - } - } - Index::IvfRq(index) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_RQ".to_string()); - body[METRIC_TYPE_KEY] = - serde_json::Value::String(index.distance_type.to_string().to_lowercase()); - if let Some(num_partitions) = index.num_partitions { - body["num_partitions"] = serde_json::Value::Number(num_partitions.into()); - } - if let Some(num_bits) = index.num_bits { - body["num_bits"] = serde_json::Value::Number(num_bits.into()); - } - } - Index::BTree(_) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string()); - } - Index::Bitmap(_) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("BITMAP".to_string()); - } - Index::LabelList(_) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("LABEL_LIST".to_string()); - } - Index::FTS(fts) => { - body[INDEX_TYPE_KEY] = serde_json::Value::String("FTS".to_string()); - let params = serde_json::to_value(&fts).map_err(|e| Error::InvalidInput { - message: format!("failed to serialize FTS index params {:?}", e), - })?; - for (key, value) in params.as_object().unwrap() { - body[key] = value.clone(); - } - } + fn to_json(params: &impl serde::Serialize) -> crate::Result { + serde_json::to_value(params).map_err(|e| Error::InvalidInput { + message: format!("failed to serialize index params {:?}", e), + }) + } + + // Map each Index variant to its wire type name and serializable params. + // Auto is special-cased since it needs schema inspection. + let (index_type_str, params) = match &index.index { + Index::IvfFlat(p) => ("IVF_FLAT", Some(to_json(p)?)), + Index::IvfPq(p) => ("IVF_PQ", Some(to_json(p)?)), + Index::IvfSq(p) => ("IVF_SQ", Some(to_json(p)?)), + Index::IvfHnswSq(p) => ("IVF_HNSW_SQ", Some(to_json(p)?)), + Index::IvfRq(p) => ("IVF_RQ", Some(to_json(p)?)), + Index::BTree(p) => ("BTREE", Some(to_json(p)?)), + Index::Bitmap(p) => ("BITMAP", Some(to_json(p)?)), + Index::LabelList(p) => ("LABEL_LIST", Some(to_json(p)?)), + Index::FTS(p) => ("FTS", Some(to_json(p)?)), Index::Auto => { let schema = self.schema().await?; let field = schema @@ -1351,11 +1302,11 @@ impl BaseTable for RemoteTable { message: format!("Column {} not found in schema", column), })?; if supported_vector_data_type(field.data_type()) { - body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_PQ".to_string()); body[METRIC_TYPE_KEY] = serde_json::Value::String(DistanceType::L2.to_string().to_lowercase()); + ("IVF_PQ", None) } else if supported_btree_data_type(field.data_type()) { - body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string()); + ("BTREE", None) } else { return Err(Error::NotSupported { message: format!( @@ -1373,6 +1324,13 @@ impl BaseTable for RemoteTable { } }; + body[INDEX_TYPE_KEY] = index_type_str.into(); + if let Some(params) = params { + for (key, value) in params.as_object().expect("params should be a JSON object") { + body[key] = value.clone(); + } + } + let request = request.json(&body); let (request_id, response) = self.send(request, true).await?; @@ -1833,7 +1791,9 @@ mod tests { use rstest::rstest; use serde_json::json; - use crate::index::vector::{IvfFlatIndexBuilder, IvfHnswSqIndexBuilder}; + use crate::index::vector::{ + IvfFlatIndexBuilder, IvfHnswSqIndexBuilder, IvfRqIndexBuilder, IvfSqIndexBuilder, + }; use crate::remote::db::DEFAULT_SERVER_VERSION; use crate::remote::JSON_CONTENT_TYPE; use crate::utils::background_cache::clock; @@ -2995,6 +2955,8 @@ mod tests { "IVF_FLAT", json!({ "metric_type": "hamming", + "sample_rate": 256, + "max_iterations": 50, }), Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)), ), @@ -3003,6 +2965,8 @@ mod tests { json!({ "metric_type": "hamming", "num_partitions": 128, + "sample_rate": 256, + "max_iterations": 50, }), Index::IvfFlat( IvfFlatIndexBuilder::default() @@ -3014,6 +2978,8 @@ mod tests { "IVF_PQ", json!({ "metric_type": "l2", + "sample_rate": 256, + "max_iterations": 50, }), Index::IvfPq(Default::default()), ), @@ -3023,6 +2989,8 @@ mod tests { "metric_type": "cosine", "num_partitions": 128, "num_bits": 4, + "sample_rate": 256, + "max_iterations": 50, }), Index::IvfPq( IvfPqIndexBuilder::default() @@ -3031,10 +2999,29 @@ mod tests { .num_bits(4), ), ), + ( + "IVF_PQ", + json!({ + "metric_type": "l2", + "num_sub_vectors": 16, + "sample_rate": 512, + "max_iterations": 100, + }), + Index::IvfPq( + IvfPqIndexBuilder::default() + .num_sub_vectors(16) + .sample_rate(512) + .max_iterations(100), + ), + ), ( "IVF_HNSW_SQ", json!({ "metric_type": "l2", + "sample_rate": 256, + "max_iterations": 50, + "m": 20, + "ef_construction": 300, }), Index::IvfHnswSq(Default::default()), ), @@ -3043,11 +3030,65 @@ mod tests { json!({ "metric_type": "l2", "num_partitions": 128, + "sample_rate": 256, + "max_iterations": 50, + "m": 40, + "ef_construction": 500, }), Index::IvfHnswSq( IvfHnswSqIndexBuilder::default() .distance_type(DistanceType::L2) - .num_partitions(128), + .num_partitions(128) + .num_edges(40) + .ef_construction(500), + ), + ), + ( + "IVF_SQ", + json!({ + "metric_type": "l2", + "sample_rate": 256, + "max_iterations": 50, + }), + Index::IvfSq(Default::default()), + ), + ( + "IVF_SQ", + json!({ + "metric_type": "cosine", + "num_partitions": 64, + "sample_rate": 256, + "max_iterations": 50, + }), + Index::IvfSq( + IvfSqIndexBuilder::default() + .distance_type(DistanceType::Cosine) + .num_partitions(64), + ), + ), + ( + "IVF_RQ", + json!({ + "metric_type": "l2", + "sample_rate": 256, + "max_iterations": 50, + }), + Index::IvfRq(Default::default()), + ), + ( + "IVF_RQ", + json!({ + "metric_type": "cosine", + "num_partitions": 64, + "num_bits": 8, + "sample_rate": 256, + "max_iterations": 50, + }), + Index::IvfRq( + IvfRqIndexBuilder::default() + .distance_type(DistanceType::Cosine) + .num_partitions(64) + .num_bits(8), ), ), // HNSW_PQ isn't yet supported on SaaS