feat: support specifying num_partitions and num_bits (#2488)

This commit is contained in:
BubbleCal
2025-07-09 11:36:09 +08:00
committed by GitHub
parent b64252d4fd
commit cab36d94b2
2 changed files with 126 additions and 52 deletions

View File

@@ -18,7 +18,7 @@ from lancedb._lancedb import (
UpdateResult,
)
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.index import FTS, BTree, Bitmap, HnswPq, HnswSq, IvfFlat, IvfPq, LabelList
from lancedb.index import FTS, BTree, Bitmap, HnswSq, IvfFlat, IvfPq, LabelList
from lancedb.remote.db import LOOP
import pyarrow as pa
@@ -186,6 +186,8 @@ class RemoteTable(Table):
accelerator: Optional[str] = None,
index_type="vector",
wait_timeout: Optional[timedelta] = None,
*,
num_bits: int = 8,
):
"""Create an index on the table.
Currently, the only parameters that matter are
@@ -220,11 +222,6 @@ class RemoteTable(Table):
>>> table.create_index("l2", "vector") # doctest: +SKIP
"""
if num_partitions is not None:
logging.warning(
"num_partitions is not supported on LanceDB cloud."
"This parameter will be tuned automatically."
)
if num_sub_vectors is not None:
logging.warning(
"num_sub_vectors is not supported on LanceDB cloud."
@@ -244,13 +241,21 @@ class RemoteTable(Table):
index_type = index_type.upper()
if index_type == "VECTOR" or index_type == "IVF_PQ":
config = IvfPq(distance_type=metric)
config = IvfPq(
distance_type=metric,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
num_bits=num_bits,
)
elif index_type == "IVF_HNSW_PQ":
config = HnswPq(distance_type=metric)
# IVF_HNSW_PQ is rejected for remote tables; the message points the caller at
# the supported alternative. The trailing space on the first literal keeps the
# two implicitly concatenated sentences from running together.
raise ValueError(
    "IVF_HNSW_PQ is not supported on LanceDB cloud. "
    "Please use IVF_HNSW_SQ instead."
)
elif index_type == "IVF_HNSW_SQ":
config = HnswSq(distance_type=metric)
config = HnswSq(distance_type=metric, num_partitions=num_partitions)
elif index_type == "IVF_FLAT":
config = IvfFlat(distance_type=metric)
config = IvfFlat(distance_type=metric, num_partitions=num_partitions)
else:
raise ValueError(
f"Unknown vector index type: {index_type}. Valid options are"

View File

@@ -57,6 +57,8 @@ use crate::{
};
const REQUEST_TIMEOUT_HEADER: HeaderName = HeaderName::from_static("x-request-timeout-ms");
const METRIC_TYPE_KEY: &str = "metric_type";
const INDEX_TYPE_KEY: &str = "index_type";
pub struct RemoteTags<'a, S: HttpSend = Sender> {
inner: &'a RemoteTable<S>,
@@ -997,23 +999,53 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
"column": column
});
let (index_type, distance_type) = match index.index {
match index.index {
// TODO: Only num_partitions (and num_bits for IVF_PQ) are forwarded today.
// Should we pass the remaining index parameters once SaaS supports them?
Index::IvfFlat(index) => ("IVF_FLAT", Some(index.distance_type)),
Index::IvfPq(index) => ("IVF_PQ", Some(index.distance_type)),
Index::IvfHnswSq(index) => ("IVF_HNSW_SQ", Some(index.distance_type)),
Index::BTree(_) => ("BTREE", None),
Index::Bitmap(_) => ("BITMAP", None),
Index::LabelList(_) => ("LABEL_LIST", None),
Index::IvfFlat(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_FLAT".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
}
Index::IvfPq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_PQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
if let Some(num_bits) = index.num_bits {
body["num_bits"] = serde_json::Value::Number(num_bits.into());
}
}
Index::IvfHnswSq(index) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_HNSW_SQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(index.distance_type.to_string().to_lowercase());
if let Some(num_partitions) = index.num_partitions {
body["num_partitions"] = serde_json::Value::Number(num_partitions.into());
}
}
Index::BTree(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
}
Index::Bitmap(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("BITMAP".to_string());
}
Index::LabelList(_) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("LABEL_LIST".to_string());
}
Index::FTS(fts) => {
body[INDEX_TYPE_KEY] = serde_json::Value::String("FTS".to_string());
let params = serde_json::to_value(&fts).map_err(|e| Error::InvalidInput {
message: format!("failed to serialize FTS index params {:?}", e),
})?;
for (key, value) in params.as_object().unwrap() {
body[key] = value.clone();
}
("FTS", None)
}
Index::Auto => {
let schema = self.schema().await?;
@@ -1023,9 +1055,11 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
message: format!("Column {} not found in schema", column),
})?;
if supported_vector_data_type(field.data_type()) {
("IVF_PQ", Some(DistanceType::L2))
body[INDEX_TYPE_KEY] = serde_json::Value::String("IVF_PQ".to_string());
body[METRIC_TYPE_KEY] =
serde_json::Value::String(DistanceType::L2.to_string().to_lowercase());
} else if supported_btree_data_type(field.data_type()) {
("BTREE", None)
body[INDEX_TYPE_KEY] = serde_json::Value::String("BTREE".to_string());
} else {
return Err(Error::NotSupported {
message: format!(
@@ -1042,12 +1076,6 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
})
}
};
body["index_type"] = serde_json::Value::String(index_type.into());
if let Some(distance_type) = distance_type {
// Phalanx expects this to be lowercase right now.
body["metric_type"] =
serde_json::Value::String(distance_type.to_string().to_lowercase());
}
let request = request.json(&body);
@@ -1429,11 +1457,12 @@ mod tests {
use chrono::{DateTime, Utc};
use futures::{future::BoxFuture, StreamExt, TryFutureExt};
use lance_index::scalar::inverted::query::MatchQuery;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::{FullTextSearchQuery, InvertedIndexParams};
use reqwest::Body;
use rstest::rstest;
use serde_json::json;
use crate::index::vector::IvfFlatIndexBuilder;
use crate::index::vector::{IvfFlatIndexBuilder, IvfHnswSqIndexBuilder};
use crate::remote::db::DEFAULT_SERVER_VERSION;
use crate::remote::JSON_CONTENT_TYPE;
use crate::{
@@ -2433,29 +2462,79 @@ mod tests {
let cases = [
(
"IVF_FLAT",
Some("hamming"),
json!({
"metric_type": "hamming",
}),
Index::IvfFlat(IvfFlatIndexBuilder::default().distance_type(DistanceType::Hamming)),
),
("IVF_PQ", Some("l2"), Index::IvfPq(Default::default())),
(
"IVF_FLAT",
json!({
"metric_type": "hamming",
"num_partitions": 128,
}),
Index::IvfFlat(
IvfFlatIndexBuilder::default()
.distance_type(DistanceType::Hamming)
.num_partitions(128),
),
),
(
"IVF_PQ",
Some("cosine"),
Index::IvfPq(IvfPqIndexBuilder::default().distance_type(DistanceType::Cosine)),
json!({
"metric_type": "l2",
}),
Index::IvfPq(Default::default()),
),
(
"IVF_PQ",
json!({
"metric_type": "cosine",
"num_partitions": 128,
"num_bits": 4,
}),
Index::IvfPq(
IvfPqIndexBuilder::default()
.distance_type(DistanceType::Cosine)
.num_partitions(128)
.num_bits(4),
),
),
(
"IVF_HNSW_SQ",
Some("l2"),
json!({
"metric_type": "l2",
}),
Index::IvfHnswSq(Default::default()),
),
(
"IVF_HNSW_SQ",
json!({
"metric_type": "l2",
"num_partitions": 128,
}),
Index::IvfHnswSq(
IvfHnswSqIndexBuilder::default()
.distance_type(DistanceType::L2)
.num_partitions(128),
),
),
// HNSW_PQ isn't yet supported on SaaS
("BTREE", None, Index::BTree(Default::default())),
("BITMAP", None, Index::Bitmap(Default::default())),
("LABEL_LIST", None, Index::LabelList(Default::default())),
("FTS", None, Index::FTS(Default::default())),
("BTREE", json!({}), Index::BTree(Default::default())),
("BITMAP", json!({}), Index::Bitmap(Default::default())),
(
"LABEL_LIST",
json!({}),
Index::LabelList(Default::default()),
),
(
"FTS",
serde_json::to_value(InvertedIndexParams::default()).unwrap(),
Index::FTS(Default::default()),
),
];
for (index_type, distance_type, index) in cases {
let params = index.clone();
for (index_type, expected_body, index) in cases {
let table = Table::new_with_handler("my_table", move |request| {
assert_eq!(request.method(), "POST");
assert_eq!(request.url().path(), "/v1/table/my_table/create_index/");
@@ -2465,19 +2544,9 @@ mod tests {
);
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let mut expected_body = serde_json::json!({
"column": "a",
"index_type": index_type,
});
if let Some(distance_type) = distance_type {
expected_body["metric_type"] = distance_type.to_lowercase().into();
}
if let Index::FTS(fts) = &params {
let params = serde_json::to_value(fts).unwrap();
for (key, value) in params.as_object().unwrap() {
expected_body[key] = value.clone();
}
}
let mut expected_body = expected_body.clone();
expected_body["column"] = "a".into();
expected_body[INDEX_TYPE_KEY] = index_type.into();
assert_eq!(body, expected_body);