From e70fd4feccd7187018d4b45618b775ff180dd41d Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 25 Dec 2024 02:36:20 +0800 Subject: [PATCH] feat: support IVF_FLAT, binary vectors and hamming distance (#1955) binary vectors and hamming distance can work on only IVF_FLAT, so introduce them all in this PR. --------- Signed-off-by: BubbleCal --- docs/src/python/python.md | 4 + docs/src/search.md | 39 +++++++- docs/test/md_testing.py | 1 + nodejs/src/util.rs | 3 +- python/python/lancedb/index.py | 93 +++++++++++++++++- python/python/lancedb/table.py | 24 +++-- .../python/tests/docs/test_binary_vector.py | 44 +++++++++ python/python/tests/test_index.py | 47 ++++++++- python/src/index.rs | 21 ++++ python/src/util.rs | 3 +- rust/lancedb/src/index.rs | 8 ++ rust/lancedb/src/index/vector.rs | 37 ++++++++ rust/lancedb/src/table.rs | 95 +++++++++++++++---- rust/lancedb/src/utils.rs | 6 +- 14 files changed, 390 insertions(+), 35 deletions(-) create mode 100644 python/python/tests/docs/test_binary_vector.py diff --git a/docs/src/python/python.md b/docs/src/python/python.md index c250c5f3..9a9dcf55 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -129,8 +129,12 @@ lists the indices that LanceDb supports. ::: lancedb.index.LabelList +::: lancedb.index.FTS + ::: lancedb.index.IvfPq +::: lancedb.index.IvfFlat + ## Querying (Asynchronous) Queries allow you to return data from your database. Basic queries can be diff --git a/docs/src/search.md b/docs/src/search.md index f207a942..3420abab 100644 --- a/docs/src/search.md +++ b/docs/src/search.md @@ -13,11 +13,15 @@ A vector search finds the approximate or exact nearest neighbors to a given quer Distance metrics are a measure of the similarity between a pair of vectors. 
Currently, LanceDB supports the following metrics: -| Metric | Description | -| -------- | --------------------------------------------------------------------------- | -| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) | -| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) | -| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) | +| Metric | Description | +| --------- | --------------------------------------------------------------------------- | +| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) | +| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) | +| `dot` | [Dot Product](https://en.wikipedia.org/wiki/Dot_product) | +| `hamming` | [Hamming Distance](https://en.wikipedia.org/wiki/Hamming_distance) | + +!!! note + The `hamming` metric is only available for binary vectors. ## Exhaustive search (kNN) @@ -107,6 +111,31 @@ an ANN search means that using an index often involves a trade-off between recal See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ` indexes work in LanceDB. +## Binary vector + +LanceDB supports binary vectors as a data type, and has the ability to search binary vectors with Hamming distance. The binary vectors are stored as uint8 arrays (every 8 bits are stored as a byte): + +!!! note + The dimension of the binary vector must be a multiple of 8. A vector of dimension 128 will be stored as a uint8 array of size 16. 
+ +=== "Python" + + === "sync API" + + ```python + --8<-- "python/python/tests/docs/test_binary_vector.py:imports" + + --8<-- "python/python/tests/docs/test_binary_vector.py:sync_binary_vector" + ``` + + === "async API" + + ```python + --8<-- "python/python/tests/docs/test_binary_vector.py:imports" + + --8<-- "python/python/tests/docs/test_binary_vector.py:async_binary_vector" + ``` + ## Output search results LanceDB returns vector search results via different formats commonly used in python. diff --git a/docs/test/md_testing.py b/docs/test/md_testing.py index 871c2ccd..08008177 100755 --- a/docs/test/md_testing.py +++ b/docs/test/md_testing.py @@ -16,6 +16,7 @@ excluded_globs = [ "../src/concepts/*.md", "../src/ann_indexes.md", "../src/basic.md", + "../src/search.md", "../src/hybrid_search/hybrid_search.md", "../src/reranking/*.md", "../src/guides/tuning_retrievers/*.md", diff --git a/nodejs/src/util.rs b/nodejs/src/util.rs index 7cca8752..9fb3681b 100644 --- a/nodejs/src/util.rs +++ b/nodejs/src/util.rs @@ -5,8 +5,9 @@ pub fn parse_distance_type(distance_type: impl AsRef) -> napi::Result Ok(DistanceType::L2), "cosine" => Ok(DistanceType::Cosine), "dot" => Ok(DistanceType::Dot), + "hamming" => Ok(DistanceType::Hamming), _ => Err(napi::Error::from_reason(format!( - "Invalid distance type '{}'. Must be one of l2, cosine, or dot", + "Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming", distance_type.as_ref() ))), } diff --git a/python/python/lancedb/index.py b/python/python/lancedb/index.py index c34d6ad8..63834b8d 100644 --- a/python/python/lancedb/index.py +++ b/python/python/lancedb/index.py @@ -355,6 +355,97 @@ class HnswSq: ef_construction: int = 300 +@dataclass +class IvfFlat: + """Describes an IVF Flat Index + + This index stores raw vectors. + These vectors are grouped into partitions of similar vectors. + Each partition keeps track of a centroid which is + the average value of all vectors in the group. 
+ + Attributes + ---------- + distance_type: str, default "l2" + The distance metric used to train the index + + This is used when training the index to calculate the IVF partitions + (vectors are grouped in partitions with similar vectors according to this + distance type) and to compare the raw vectors within a partition during search. + + The distance type used to train an index MUST match the distance type used + to search the index. Failure to do so will yield inaccurate results. + + The following distance types are available: + + "l2" - Euclidean distance. This is a very common distance metric that + accounts for both magnitude and direction when determining the distance + between vectors. L2 distance has a range of [0, ∞). + + "cosine" - Cosine distance. Cosine distance is a distance metric + calculated from the cosine similarity between two vectors. Cosine + similarity is a measure of similarity between two non-zero vectors of an + inner product space. It is defined to equal the cosine of the angle + between them. Unlike L2, the cosine distance is not affected by the + magnitude of the vectors. Cosine distance has a range of [0, 2]. + + Note: the cosine distance is undefined when one (or both) of the vectors + are all zeros (there is no direction). These vectors are invalid and may + never be returned from a vector search. + + "dot" - Dot product. Dot distance is the dot product of two vectors. Dot + distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their + L2 norm is 1), then dot distance is equivalent to the cosine distance. + + "hamming" - Hamming distance. Hamming distance is a distance metric + calculated as the number of positions at which the corresponding bits are + different. Hamming distance has a range of [0, vector dimension]. + + num_partitions: int, default sqrt(num_rows) + The number of IVF partitions to create. + + This value should generally scale with the number of rows in the dataset. 
+ By default the number of partitions is the square root of the number of + rows. + + If this value is too large then the first part of the search (picking the + right partition) will be slow. If this value is too small then the second + part of the search (searching within a partition) will be slow. + + max_iterations: int, default 50 + Max iterations to train kmeans. + + When training an IVF Flat index we use kmeans to calculate the partitions. + This parameter controls how many iterations of kmeans to run. + + Increasing this might improve the quality of the index but in most cases + these extra iterations have diminishing returns. + + The default value is 50. + sample_rate: int, default 256 + The rate used to calculate the number of training vectors for kmeans. + + When an IVF Flat index is trained, we need to calculate partitions. These + are groups of vectors that are similar to each other. To do this we use an + algorithm called kmeans. + + Running kmeans on a large dataset can be slow. To speed this up we run + kmeans on a random sample of the data. This parameter controls the size of + the sample. The total number of vectors used to train the index is + `sample_rate * num_partitions`. + + Increasing this value might improve the quality of the index but in most + cases the default should be sufficient. + + The default value is 256. 
+ """ + + distance_type: Literal["l2", "cosine", "dot", "hamming"] = "l2" + num_partitions: Optional[int] = None + max_iterations: int = 50 + sample_rate: int = 256 + + @dataclass class IvfPq: """Describes an IVF PQ Index @@ -477,4 +568,4 @@ class IvfPq: sample_rate: int = 256 -__all__ = ["BTree", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"] +__all__ = ["BTree", "IvfFlat", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"] diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index ebc04320..fe1c7588 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -34,7 +34,7 @@ from lance.dependencies import _check_for_hugging_face from .common import DATA, VEC, VECTOR_COLUMN_NAME from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry -from .index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS +from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS from .merge import LanceMergeInsertBuilder from .pydantic import LanceModel, model_to_dict from .query import ( @@ -433,7 +433,9 @@ class Table(ABC): accelerator: Optional[str] = None, index_cache_size: Optional[int] = None, *, - index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ", + index_type: Literal[ + "IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ" + ] = "IVF_PQ", num_bits: int = 8, max_iterations: int = 50, sample_rate: int = 256, @@ -446,8 +448,9 @@ class Table(ABC): ---------- metric: str, default "L2" The distance metric to use when creating the index. - Valid values are "L2", "cosine", or "dot". + Valid values are "L2", "cosine", "dot", or "hamming". L2 is euclidean distance. + Hamming is available only for binary vectors. num_partitions: int, default 256 The number of IVF partitions to use when creating the index. Default is 256. 
@@ -1408,7 +1411,9 @@ class LanceTable(Table): accelerator: Optional[str] = None, index_cache_size: Optional[int] = None, num_bits: int = 8, - index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ", + index_type: Literal[ + "IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ" + ] = "IVF_PQ", max_iterations: int = 50, sample_rate: int = 256, m: int = 20, @@ -1432,6 +1437,13 @@ class LanceTable(Table): ) self.checkout_latest() return + elif index_type == "IVF_FLAT": + config = IvfFlat( + distance_type=metric, + num_partitions=num_partitions, + max_iterations=max_iterations, + sample_rate=sample_rate, + ) elif index_type == "IVF_PQ": config = IvfPq( distance_type=metric, @@ -2619,7 +2631,7 @@ class AsyncTable: *, replace: Optional[bool] = None, config: Optional[ - Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS] + Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS] ] = None, ): """Create an index to speed up queries @@ -2648,7 +2660,7 @@ class AsyncTable: """ if config is not None: if not isinstance( - config, (IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS) + config, (IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS) ): raise TypeError( "config must be an instance of IvfPq, HnswPq, HnswSq, BTree," diff --git a/python/python/tests/docs/test_binary_vector.py b/python/python/tests/docs/test_binary_vector.py new file mode 100644 index 00000000..69b466e8 --- /dev/null +++ b/python/python/tests/docs/test_binary_vector.py @@ -0,0 +1,44 @@ +import shutil + +# --8<-- [start:imports] +import lancedb +import numpy as np +import pytest +# --8<-- [end:imports] + +shutil.rmtree("data/binary_lancedb", ignore_errors=True) + + +def test_binary_vector(): + # --8<-- [start:sync_binary_vector] + db = lancedb.connect("data/binary_lancedb") + data = [ + { + "id": i, + "vector": np.random.randint(0, 256, size=16), + } + for i in range(1024) + ] + tbl = db.create_table("my_binary_vectors", data=data) + query = np.random.randint(0, 
256, size=16) + tbl.search(query).to_arrow() + # --8<-- [end:sync_binary_vector] + db.drop_table("my_binary_vectors") + + +@pytest.mark.asyncio +async def test_binary_vector_async(): + # --8<-- [start:async_binary_vector] + db = await lancedb.connect_async("data/binary_lancedb") + data = [ + { + "id": i, + "vector": np.random.randint(0, 256, size=16), + } + for i in range(1024) + ] + tbl = await db.create_table("my_binary_vectors", data=data) + query = np.random.randint(0, 256, size=16) + await tbl.query().nearest_to(query).to_arrow() + # --8<-- [end:async_binary_vector] + await db.drop_table("my_binary_vectors") diff --git a/python/python/tests/test_index.py b/python/python/tests/test_index.py index 4c0caf7e..6cdee77c 100644 --- a/python/python/tests/test_index.py +++ b/python/python/tests/test_index.py @@ -8,7 +8,7 @@ import pyarrow as pa import pytest import pytest_asyncio from lancedb import AsyncConnection, AsyncTable, connect_async -from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq +from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq @pytest_asyncio.fixture @@ -42,6 +42,27 @@ async def some_table(db_async): ) +@pytest_asyncio.fixture +async def binary_table(db_async): + data = [ + { + "id": i, + "vector": [i] * 128, + } + for i in range(NROWS) + ] + return await db_async.create_table( + "binary_table", + data, + schema=pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("vector", pa.list_(pa.uint8(), 128)), + ] + ), + ) + + @pytest.mark.asyncio async def test_create_scalar_index(some_table: AsyncTable): # Can create @@ -143,3 +164,27 @@ async def test_create_hnswsq_index(some_table: AsyncTable): await some_table.create_index("vector", config=HnswSq(num_partitions=10)) indices = await some_table.list_indices() assert len(indices) == 1 + + +@pytest.mark.asyncio +async def test_create_index_with_binary_vectors(binary_table: AsyncTable): + await binary_table.create_index( + "vector", 
config=IvfFlat(distance_type="hamming", num_partitions=10) + ) + indices = await binary_table.list_indices() + assert len(indices) == 1 + assert indices[0].index_type == "IvfFlat" + assert indices[0].columns == ["vector"] + assert indices[0].name == "vector_idx" + + stats = await binary_table.index_stats("vector_idx") + assert stats.index_type == "IVF_FLAT" + assert stats.distance_type == "hamming" + assert stats.num_indexed_rows == await binary_table.count_rows() + assert stats.num_unindexed_rows == 0 + assert stats.num_indices == 1 + + # the dataset contains vectors with all values from 0 to 255 + for v in range(256): + res = await binary_table.query().nearest_to([v] * 128).to_arrow() + assert res["id"][0].as_py() == v diff --git a/python/src/index.rs b/python/src/index.rs index be6c2269..48b5e074 100644 --- a/python/src/index.rs +++ b/python/src/index.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use lancedb::index::vector::IvfFlatIndexBuilder; use lancedb::index::{ scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig}, vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder}, @@ -59,6 +60,18 @@ pub fn extract_index_params(source: &Option>) -> PyResult { + let params = source.extract::()?; + let distance_type = parse_distance_type(params.distance_type)?; + let mut ivf_flat_builder = IvfFlatIndexBuilder::default() + .distance_type(distance_type) + .max_iterations(params.max_iterations) + .sample_rate(params.sample_rate); + if let Some(num_partitions) = params.num_partitions { + ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions); + } + Ok(LanceDbIndex::IvfFlat(ivf_flat_builder)) + }, "IvfPq" => { let params = source.extract::()?; let distance_type = parse_distance_type(params.distance_type)?; @@ -129,6 +142,14 @@ struct FtsParams { ascii_folding: bool, } +#[derive(FromPyObject)] +struct IvfFlatParams { + distance_type: String, + num_partitions: 
Option, + max_iterations: u32, + sample_rate: u32, +} + #[derive(FromPyObject)] struct IvfPqParams { distance_type: String, diff --git a/python/src/util.rs b/python/src/util.rs index 60bf8b00..a649f875 100644 --- a/python/src/util.rs +++ b/python/src/util.rs @@ -43,8 +43,9 @@ pub fn parse_distance_type(distance_type: impl AsRef) -> PyResult Ok(DistanceType::L2), "cosine" => Ok(DistanceType::Cosine), "dot" => Ok(DistanceType::Dot), + "hamming" => Ok(DistanceType::Hamming), _ => Err(PyValueError::new_err(format!( - "Invalid distance type '{}'. Must be one of l2, cosine, or dot", + "Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming", distance_type.as_ref() ))), } diff --git a/rust/lancedb/src/index.rs b/rust/lancedb/src/index.rs index 432e01c2..201ee605 100644 --- a/rust/lancedb/src/index.rs +++ b/rust/lancedb/src/index.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use scalar::FtsIndexBuilder; use serde::Deserialize; use serde_with::skip_serializing_none; +use vector::IvfFlatIndexBuilder; use crate::{table::TableInternal, DistanceType, Error, Result}; @@ -56,6 +57,9 @@ pub enum Index { /// Full text search index using bm25. 
FTS(FtsIndexBuilder), + /// IVF index + IvfFlat(IvfFlatIndexBuilder), + /// IVF index with Product Quantization IvfPq(IvfPqIndexBuilder), @@ -106,6 +110,8 @@ impl IndexBuilder { #[derive(Debug, Clone, PartialEq, Deserialize)] pub enum IndexType { // Vector + #[serde(alias = "IVF_FLAT")] + IvfFlat, #[serde(alias = "IVF_PQ")] IvfPq, #[serde(alias = "IVF_HNSW_PQ")] @@ -127,6 +133,7 @@ pub enum IndexType { impl std::fmt::Display for IndexType { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { + Self::IvfFlat => write!(f, "IVF_FLAT"), Self::IvfPq => write!(f, "IVF_PQ"), Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"), Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"), @@ -147,6 +154,7 @@ impl std::str::FromStr for IndexType { "BITMAP" => Ok(Self::Bitmap), "LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList), "FTS" | "INVERTED" => Ok(Self::FTS), + "IVF_FLAT" => Ok(Self::IvfFlat), "IVF_PQ" => Ok(Self::IvfPq), "IVF_HNSW_PQ" => Ok(Self::IvfHnswPq), "IVF_HNSW_SQ" => Ok(Self::IvfHnswSq), diff --git a/rust/lancedb/src/index/vector.rs b/rust/lancedb/src/index/vector.rs index f338026c..e7f0b6de 100644 --- a/rust/lancedb/src/index/vector.rs +++ b/rust/lancedb/src/index/vector.rs @@ -162,6 +162,43 @@ macro_rules! impl_hnsw_params_setter { }; } +/// Builder for an IVF Flat index. +/// +/// This index stores raw vectors. These vectors are grouped into partitions of similar vectors. +/// Each partition keeps track of a centroid which is the average value of all vectors in the group. +/// +/// During a query the centroids are compared with the query vector to find the closest partitions. +/// The raw vectors in these partitions are then searched to find the closest vectors. +/// +/// The partitioning process is called IVF and the `num_partitions` parameter controls how many groups to create. +/// +/// Note that training an IVF Flat index on a large dataset is a slow operation and currently is also a memory intensive operation. 
+#[derive(Debug, Clone)] +pub struct IvfFlatIndexBuilder { + pub(crate) distance_type: DistanceType, + + // IVF + pub(crate) num_partitions: Option, + pub(crate) sample_rate: u32, + pub(crate) max_iterations: u32, +} + +impl Default for IvfFlatIndexBuilder { + fn default() -> Self { + Self { + distance_type: DistanceType::L2, + num_partitions: None, + sample_rate: 256, + max_iterations: 50, + } + } +} + +impl IvfFlatIndexBuilder { + impl_distance_type_setter!(); + impl_ivf_params_setter!(); +} + /// Builder for an IVF PQ index. /// /// This index stores a compressed (quantized) copy of every vector. These vectors diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 6ba2f241..4385270e 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -18,9 +18,9 @@ use std::path::Path; use std::sync::Arc; use arrow::array::AsArray; -use arrow::datatypes::Float32Type; +use arrow::datatypes::{Float32Type, UInt8Type}; use arrow_array::{RecordBatchIterator, RecordBatchReader}; -use arrow_schema::{Field, Schema, SchemaRef}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_physical_plan::projection::ProjectionExec; @@ -58,8 +58,8 @@ use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, M use crate::error::{Error, Result}; use crate::index::scalar::FtsIndexBuilder; use crate::index::vector::{ - suggested_num_partitions_for_hnsw, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, - IvfPqIndexBuilder, VectorIndex, + suggested_num_partitions_for_hnsw, IvfFlatIndexBuilder, IvfHnswPqIndexBuilder, + IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex, }; use crate::index::IndexStatistics; use crate::index::{ @@ -1306,6 +1306,44 @@ impl NativeTable { .collect()) } + async fn create_ivf_flat_index( + &self, + index: IvfFlatIndexBuilder, + field: &Field, + replace: bool, + ) -> Result<()> { + if 
!supported_vector_data_type(field.data_type()) { + return Err(Error::InvalidInput { + message: format!( + "An IVF Flat index cannot be created on the column `{}` which has data type {}", + field.name(), + field.data_type() + ), + }); + } + + let num_partitions = if let Some(n) = index.num_partitions { + n + } else { + suggested_num_partitions(self.count_rows(None).await?) + }; + let mut dataset = self.dataset.get_mut().await?; + let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_flat( + num_partitions as usize, + index.distance_type.into(), + ); + dataset + .create_index( + &[field.name()], + IndexType::Vector, + None, + &lance_idx_params, + replace, + ) + .await?; + Ok(()) + } + async fn create_ivf_pq_index( &self, index: IvfPqIndexBuilder, @@ -1778,6 +1816,10 @@ impl TableInternal for NativeTable { Index::Bitmap(_) => self.create_bitmap_index(field, opts).await, Index::LabelList(_) => self.create_label_list_index(field, opts).await, Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await, + Index::IvfFlat(ivf_flat) => { + self.create_ivf_flat_index(ivf_flat, field, opts.replace) + .await + } Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await, Index::IvfHnswPq(ivf_hnsw_pq) => { self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace) @@ -1848,14 +1890,21 @@ impl TableInternal for NativeTable { message: format!("Column {} not found in dataset schema", column), })?; - if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() { - if !f.data_type().is_floating() { - return Err(Error::InvalidInput { - message: format!( - "The data type of the vector column '{}' is not a floating point type", - column - ), - }); + let mut is_binary = false; + if let arrow_schema::DataType::FixedSizeList(element, dim) = field.data_type() { + match element.data_type() { + e_type if e_type.is_floating() => {} + e_type if *e_type == DataType::UInt8 => { + is_binary = true; + } + _ => { + 
return Err(Error::InvalidInput { + message: format!( + "The data type of the vector column '{}' is not a floating point type", + column + ), + }); + } } if dim != query_vector.len() as i32 { return Err(Error::InvalidInput { @@ -1870,12 +1919,22 @@ impl TableInternal for NativeTable { } } - let query_vector = query_vector.as_primitive::(); - scanner.nearest( - &column, - query_vector, - query.base.limit.unwrap_or(DEFAULT_TOP_K), - )?; + if is_binary { + let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?; + let query_vector = query_vector.as_primitive::(); + scanner.nearest( + &column, + query_vector, + query.base.limit.unwrap_or(DEFAULT_TOP_K), + )?; + } else { + let query_vector = query_vector.as_primitive::(); + scanner.nearest( + &column, + query_vector, + query.base.limit.unwrap_or(DEFAULT_TOP_K), + )?; + } } scanner.limit( query.base.limit.map(|limit| limit as i64), diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index 2ba006e2..d1019a9f 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -110,7 +110,7 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result .iter() .filter_map(|field| match field.data_type() { arrow_schema::DataType::FixedSizeList(f, d) - if f.data_type().is_floating() + if (f.data_type().is_floating() || f.data_type() == &DataType::UInt8) && dim.map(|expect| *d == expect).unwrap_or(true) => { Some(field.name()) @@ -171,7 +171,9 @@ pub fn supported_fts_data_type(dtype: &DataType) -> bool { pub fn supported_vector_data_type(dtype: &DataType) -> bool { match dtype { - DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()), + DataType::FixedSizeList(inner, _) => { + DataType::is_floating(inner.data_type()) || *inner.data_type() == DataType::UInt8 + } _ => false, } }