feat: support IVF_FLAT, binary vectors and hamming distance (#1955)

binary vectors and hamming distance can work on only IVF_FLAT, so
introduce them all in this PR.

---------

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2024-12-25 02:36:20 +08:00
committed by GitHub
parent ac0068b80e
commit e70fd4fecc
14 changed files with 390 additions and 35 deletions

View File

@@ -129,8 +129,12 @@ lists the indices that LanceDb supports.
::: lancedb.index.LabelList ::: lancedb.index.LabelList
::: lancedb.index.FTS
::: lancedb.index.IvfPq ::: lancedb.index.IvfPq
::: lancedb.index.IvfFlat
## Querying (Asynchronous) ## Querying (Asynchronous)
Queries allow you to return data from your database. Basic queries can be Queries allow you to return data from your database. Basic queries can be

View File

@@ -13,11 +13,15 @@ A vector search finds the approximate or exact nearest neighbors to a given quer
Distance metrics are a measure of the similarity between a pair of vectors. Distance metrics are a measure of the similarity between a pair of vectors.
Currently, LanceDB supports the following metrics: Currently, LanceDB supports the following metrics:
| Metric | Description | | Metric | Description |
| -------- | --------------------------------------------------------------------------- | | --------- | --------------------------------------------------------------------------- |
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) | | `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) | | `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) | | `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) |
| `hamming` | [Hamming Distance](https://en.wikipedia.org/wiki/Hamming_distance) |
!!! note
The `hamming` metric is only available for binary vectors.
## Exhaustive search (kNN) ## Exhaustive search (kNN)
@@ -107,6 +111,31 @@ an ANN search means that using an index often involves a trade-off between recal
See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ` See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ`
indexes work in LanceDB. indexes work in LanceDB.
## Binary vector
LanceDB supports binary vectors as a data type, and has the ability to search binary vectors with hamming distance. The binary vectors are stored as uint8 arrays (every 8 bits are stored as a byte):
!!! note
The dim of the binary vector must be a multiple of 8. A vector of dim 128 will be stored as a uint8 array of size 16.
=== "Python"
=== "sync API"
```python
--8<-- "python/python/tests/docs/test_binary_vector.py:imports"
--8<-- "python/python/tests/docs/test_binary_vector.py:sync_binary_vector"
```
=== "async API"
```python
--8<-- "python/python/tests/docs/test_binary_vector.py:imports"
--8<-- "python/python/tests/docs/test_binary_vector.py:async_binary_vector"
```
## Output search results ## Output search results
LanceDB returns vector search results via different formats commonly used in python. LanceDB returns vector search results via different formats commonly used in python.

View File

@@ -16,6 +16,7 @@ excluded_globs = [
"../src/concepts/*.md", "../src/concepts/*.md",
"../src/ann_indexes.md", "../src/ann_indexes.md",
"../src/basic.md", "../src/basic.md",
"../src/search.md",
"../src/hybrid_search/hybrid_search.md", "../src/hybrid_search/hybrid_search.md",
"../src/reranking/*.md", "../src/reranking/*.md",
"../src/guides/tuning_retrievers/*.md", "../src/guides/tuning_retrievers/*.md",

View File

@@ -5,8 +5,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
"l2" => Ok(DistanceType::L2), "l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine), "cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot), "dot" => Ok(DistanceType::Dot),
"hamming" => Ok(DistanceType::Hamming),
_ => Err(napi::Error::from_reason(format!( _ => Err(napi::Error::from_reason(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot", "Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
distance_type.as_ref() distance_type.as_ref()
))), ))),
} }

View File

@@ -355,6 +355,97 @@ class HnswSq:
ef_construction: int = 300 ef_construction: int = 300
@dataclass
class IvfFlat:
"""Describes an IVF Flat Index
This index stores raw vectors.
These vectors are grouped into partitions of similar vectors.
Each partition keeps track of a centroid which is
the average value of all vectors in the group.
Attributes
----------
distance_type: str, default "L2"
The distance metric used to train the index
This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and to calculate a subvector's code during quantization.
The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.
The following distance types are available:
"l2" - Euclidean distance. This is a very common distance metric that
accounts for both magnitude and direction when determining the distance
between vectors. L2 distance has a range of [0, ∞).
"cosine" - Cosine distance. Cosine distance is a distance metric
calculated from the cosine similarity between two vectors. Cosine
similarity is a measure of similarity between two non-zero vectors of an
inner product space. It is defined to equal the cosine of the angle
between them. Unlike L2, the cosine distance is not affected by the
magnitude of the vectors. Cosine distance has a range of [0, 2].
Note: the cosine distance is undefined when one (or both) of the vectors
are all zeros (there is no direction). These vectors are invalid and may
never be returned from a vector search.
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
L2 norm is 1), then dot distance is equivalent to the cosine distance.
"hamming" - Hamming distance. Hamming distance is a distance metric
calculated as the number of positions at which the corresponding bits are
different. Hamming distance has a range of [0, vector dimension].
num_partitions: int, default sqrt(num_rows)
The number of IVF partitions to create.
This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.
If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.
max_iterations: int, default 50
Max iteration to train kmeans.
When training an IVF PQ index we use kmeans to calculate the partitions.
This parameter controls how many iterations of kmeans to run.
Increasing this might improve the quality of the index but in most cases
these extra iterations have diminishing returns.
The default value is 50.
sample_rate: int, default 256
The rate used to calculate the number of training vectors for kmeans.
When an IVF PQ index is trained, we need to calculate partitions. These
are groups of vectors that are similar to each other. To do this we use an
algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we run
kmeans on a random sample of the data. This parameter controls the size of
the sample. The total number of vectors used to train the index is
`sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in most
cases the default should be sufficient.
The default value is 256.
"""
distance_type: Literal["l2", "cosine", "dot", "hamming"] = "l2"
num_partitions: Optional[int] = None
max_iterations: int = 50
sample_rate: int = 256
@dataclass @dataclass
class IvfPq: class IvfPq:
"""Describes an IVF PQ Index """Describes an IVF PQ Index
@@ -477,4 +568,4 @@ class IvfPq:
sample_rate: int = 256 sample_rate: int = 256
__all__ = ["BTree", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"] __all__ = ["BTree", "IvfFlat", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"]

View File

@@ -34,7 +34,7 @@ from lance.dependencies import _check_for_hugging_face
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .merge import LanceMergeInsertBuilder from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict from .pydantic import LanceModel, model_to_dict
from .query import ( from .query import (
@@ -433,7 +433,9 @@ class Table(ABC):
accelerator: Optional[str] = None, accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None, index_cache_size: Optional[int] = None,
*, *,
index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ", index_type: Literal[
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = "IVF_PQ",
num_bits: int = 8, num_bits: int = 8,
max_iterations: int = 50, max_iterations: int = 50,
sample_rate: int = 256, sample_rate: int = 256,
@@ -446,8 +448,9 @@ class Table(ABC):
---------- ----------
metric: str, default "L2" metric: str, default "L2"
The distance metric to use when creating the index. The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot". Valid values are "L2", "cosine", "dot", or "hamming".
L2 is euclidean distance. L2 is euclidean distance.
Hamming is available only for binary vectors.
num_partitions: int, default 256 num_partitions: int, default 256
The number of IVF partitions to use when creating the index. The number of IVF partitions to use when creating the index.
Default is 256. Default is 256.
@@ -1408,7 +1411,9 @@ class LanceTable(Table):
accelerator: Optional[str] = None, accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None, index_cache_size: Optional[int] = None,
num_bits: int = 8, num_bits: int = 8,
index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ", index_type: Literal[
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
] = "IVF_PQ",
max_iterations: int = 50, max_iterations: int = 50,
sample_rate: int = 256, sample_rate: int = 256,
m: int = 20, m: int = 20,
@@ -1432,6 +1437,13 @@ class LanceTable(Table):
) )
self.checkout_latest() self.checkout_latest()
return return
elif index_type == "IVF_FLAT":
config = IvfFlat(
distance_type=metric,
num_partitions=num_partitions,
max_iterations=max_iterations,
sample_rate=sample_rate,
)
elif index_type == "IVF_PQ": elif index_type == "IVF_PQ":
config = IvfPq( config = IvfPq(
distance_type=metric, distance_type=metric,
@@ -2619,7 +2631,7 @@ class AsyncTable:
*, *,
replace: Optional[bool] = None, replace: Optional[bool] = None,
config: Optional[ config: Optional[
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS] Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
] = None, ] = None,
): ):
"""Create an index to speed up queries """Create an index to speed up queries
@@ -2648,7 +2660,7 @@ class AsyncTable:
""" """
if config is not None: if config is not None:
if not isinstance( if not isinstance(
config, (IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS) config, (IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
): ):
raise TypeError( raise TypeError(
"config must be an instance of IvfPq, HnswPq, HnswSq, BTree," "config must be an instance of IvfPq, HnswPq, HnswSq, BTree,"

View File

@@ -0,0 +1,44 @@
import shutil
# --8<-- [start:imports]
import lancedb
import numpy as np
import pytest
# --8<-- [end:imports]
shutil.rmtree("data/binary_lancedb", ignore_errors=True)
def test_binary_vector():
# --8<-- [start:sync_binary_vector]
db = lancedb.connect("data/binary_lancedb")
data = [
{
"id": i,
"vector": np.random.randint(0, 256, size=16),
}
for i in range(1024)
]
tbl = db.create_table("my_binary_vectors", data=data)
query = np.random.randint(0, 256, size=16)
tbl.search(query).to_arrow()
# --8<-- [end:sync_binary_vector]
db.drop_table("my_binary_vectors")
@pytest.mark.asyncio
async def test_binary_vector_async():
# --8<-- [start:async_binary_vector]
db = await lancedb.connect_async("data/binary_lancedb")
data = [
{
"id": i,
"vector": np.random.randint(0, 256, size=16),
}
for i in range(1024)
]
tbl = await db.create_table("my_binary_vectors", data=data)
query = np.random.randint(0, 256, size=16)
await tbl.query().nearest_to(query).to_arrow()
# --8<-- [end:async_binary_vector]
await db.drop_table("my_binary_vectors")

View File

@@ -8,7 +8,7 @@ import pyarrow as pa
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
@pytest_asyncio.fixture @pytest_asyncio.fixture
@@ -42,6 +42,27 @@ async def some_table(db_async):
) )
@pytest_asyncio.fixture
async def binary_table(db_async):
data = [
{
"id": i,
"vector": [i] * 128,
}
for i in range(NROWS)
]
return await db_async.create_table(
"binary_table",
data,
schema=pa.schema(
[
pa.field("id", pa.int64()),
pa.field("vector", pa.list_(pa.uint8(), 128)),
]
),
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_scalar_index(some_table: AsyncTable): async def test_create_scalar_index(some_table: AsyncTable):
# Can create # Can create
@@ -143,3 +164,27 @@ async def test_create_hnswsq_index(some_table: AsyncTable):
await some_table.create_index("vector", config=HnswSq(num_partitions=10)) await some_table.create_index("vector", config=HnswSq(num_partitions=10))
indices = await some_table.list_indices() indices = await some_table.list_indices()
assert len(indices) == 1 assert len(indices) == 1
@pytest.mark.asyncio
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
await binary_table.create_index(
"vector", config=IvfFlat(distance_type="hamming", num_partitions=10)
)
indices = await binary_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "IvfFlat"
assert indices[0].columns == ["vector"]
assert indices[0].name == "vector_idx"
stats = await binary_table.index_stats("vector_idx")
assert stats.index_type == "IVF_FLAT"
assert stats.distance_type == "hamming"
assert stats.num_indexed_rows == await binary_table.count_rows()
assert stats.num_unindexed_rows == 0
assert stats.num_indices == 1
# the dataset contains vectors with all values from 0 to 255
for v in range(256):
res = await binary_table.query().nearest_to([v] * 128).to_arrow()
assert res["id"][0].as_py() == v

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use lancedb::index::vector::IvfFlatIndexBuilder;
use lancedb::index::{ use lancedb::index::{
scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig}, scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder}, vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
@@ -59,6 +60,18 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
opts.tokenizer_configs = inner_opts; opts.tokenizer_configs = inner_opts;
Ok(LanceDbIndex::FTS(opts)) Ok(LanceDbIndex::FTS(opts))
}, },
"IvfFlat" => {
let params = source.extract::<IvfFlatParams>()?;
let distance_type = parse_distance_type(params.distance_type)?;
let mut ivf_flat_builder = IvfFlatIndexBuilder::default()
.distance_type(distance_type)
.max_iterations(params.max_iterations)
.sample_rate(params.sample_rate);
if let Some(num_partitions) = params.num_partitions {
ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
}
Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
},
"IvfPq" => { "IvfPq" => {
let params = source.extract::<IvfPqParams>()?; let params = source.extract::<IvfPqParams>()?;
let distance_type = parse_distance_type(params.distance_type)?; let distance_type = parse_distance_type(params.distance_type)?;
@@ -129,6 +142,14 @@ struct FtsParams {
ascii_folding: bool, ascii_folding: bool,
} }
#[derive(FromPyObject)]
struct IvfFlatParams {
distance_type: String,
num_partitions: Option<u32>,
max_iterations: u32,
sample_rate: u32,
}
#[derive(FromPyObject)] #[derive(FromPyObject)]
struct IvfPqParams { struct IvfPqParams {
distance_type: String, distance_type: String,

View File

@@ -43,8 +43,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
"l2" => Ok(DistanceType::L2), "l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine), "cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot), "dot" => Ok(DistanceType::Dot),
"hamming" => Ok(DistanceType::Hamming),
_ => Err(PyValueError::new_err(format!( _ => Err(PyValueError::new_err(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot", "Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
distance_type.as_ref() distance_type.as_ref()
))), ))),
} }

View File

@@ -17,6 +17,7 @@ use std::sync::Arc;
use scalar::FtsIndexBuilder; use scalar::FtsIndexBuilder;
use serde::Deserialize; use serde::Deserialize;
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use vector::IvfFlatIndexBuilder;
use crate::{table::TableInternal, DistanceType, Error, Result}; use crate::{table::TableInternal, DistanceType, Error, Result};
@@ -56,6 +57,9 @@ pub enum Index {
/// Full text search index using bm25. /// Full text search index using bm25.
FTS(FtsIndexBuilder), FTS(FtsIndexBuilder),
/// IVF index
IvfFlat(IvfFlatIndexBuilder),
/// IVF index with Product Quantization /// IVF index with Product Quantization
IvfPq(IvfPqIndexBuilder), IvfPq(IvfPqIndexBuilder),
@@ -106,6 +110,8 @@ impl IndexBuilder {
#[derive(Debug, Clone, PartialEq, Deserialize)] #[derive(Debug, Clone, PartialEq, Deserialize)]
pub enum IndexType { pub enum IndexType {
// Vector // Vector
#[serde(alias = "IVF_FLAT")]
IvfFlat,
#[serde(alias = "IVF_PQ")] #[serde(alias = "IVF_PQ")]
IvfPq, IvfPq,
#[serde(alias = "IVF_HNSW_PQ")] #[serde(alias = "IVF_HNSW_PQ")]
@@ -127,6 +133,7 @@ pub enum IndexType {
impl std::fmt::Display for IndexType { impl std::fmt::Display for IndexType {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self { match self {
Self::IvfFlat => write!(f, "IVF_FLAT"),
Self::IvfPq => write!(f, "IVF_PQ"), Self::IvfPq => write!(f, "IVF_PQ"),
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"), Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"), Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
@@ -147,6 +154,7 @@ impl std::str::FromStr for IndexType {
"BITMAP" => Ok(Self::Bitmap), "BITMAP" => Ok(Self::Bitmap),
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList), "LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
"FTS" | "INVERTED" => Ok(Self::FTS), "FTS" | "INVERTED" => Ok(Self::FTS),
"IVF_FLAT" => Ok(Self::IvfFlat),
"IVF_PQ" => Ok(Self::IvfPq), "IVF_PQ" => Ok(Self::IvfPq),
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq), "IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq), "IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),

View File

@@ -162,6 +162,43 @@ macro_rules! impl_hnsw_params_setter {
}; };
} }
/// Builder for an IVF Flat index.
///
/// This index stores raw vectors. These vectors are grouped into partitions of similar vectors.
/// Each partition keeps track of a centroid which is the average value of all vectors in the group.
///
/// During a query the centroids are compared with the query vector to find the closest partitions.
/// The raw vectors in these partitions are then searched to find the closest vectors.
///
/// The partitioning process is called IVF and the `num_partitions` parameter controls how many groups to create.
///
/// Note that training an IVF Flat index on a large dataset is a slow operation and currently is also a memory intensive operation.
#[derive(Debug, Clone)]
pub struct IvfFlatIndexBuilder {
pub(crate) distance_type: DistanceType,
// IVF
pub(crate) num_partitions: Option<u32>,
pub(crate) sample_rate: u32,
pub(crate) max_iterations: u32,
}
impl Default for IvfFlatIndexBuilder {
fn default() -> Self {
Self {
distance_type: DistanceType::L2,
num_partitions: None,
sample_rate: 256,
max_iterations: 50,
}
}
}
impl IvfFlatIndexBuilder {
impl_distance_type_setter!();
impl_ivf_params_setter!();
}
/// Builder for an IVF PQ index. /// Builder for an IVF PQ index.
/// ///
/// This index stores a compressed (quantized) copy of every vector. These vectors /// This index stores a compressed (quantized) copy of every vector. These vectors

View File

@@ -18,9 +18,9 @@ use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use arrow::array::AsArray; use arrow::array::AsArray;
use arrow::datatypes::Float32Type; use arrow::datatypes::{Float32Type, UInt8Type};
use arrow_array::{RecordBatchIterator, RecordBatchReader}; use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{Field, Schema, SchemaRef}; use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait; use async_trait::async_trait;
use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_physical_plan::display::DisplayableExecutionPlan;
use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::projection::ProjectionExec;
@@ -58,8 +58,8 @@ use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, M
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::index::scalar::FtsIndexBuilder; use crate::index::scalar::FtsIndexBuilder;
use crate::index::vector::{ use crate::index::vector::{
suggested_num_partitions_for_hnsw, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, suggested_num_partitions_for_hnsw, IvfFlatIndexBuilder, IvfHnswPqIndexBuilder,
IvfPqIndexBuilder, VectorIndex, IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex,
}; };
use crate::index::IndexStatistics; use crate::index::IndexStatistics;
use crate::index::{ use crate::index::{
@@ -1306,6 +1306,44 @@ impl NativeTable {
.collect()) .collect())
} }
async fn create_ivf_flat_index(
&self,
index: IvfFlatIndexBuilder,
field: &Field,
replace: bool,
) -> Result<()> {
if !supported_vector_data_type(field.data_type()) {
return Err(Error::InvalidInput {
message: format!(
"An IVF Flat index cannot be created on the column `{}` which has data type {}",
field.name(),
field.data_type()
),
});
}
let num_partitions = if let Some(n) = index.num_partitions {
n
} else {
suggested_num_partitions(self.count_rows(None).await?)
};
let mut dataset = self.dataset.get_mut().await?;
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_flat(
num_partitions as usize,
index.distance_type.into(),
);
dataset
.create_index(
&[field.name()],
IndexType::Vector,
None,
&lance_idx_params,
replace,
)
.await?;
Ok(())
}
async fn create_ivf_pq_index( async fn create_ivf_pq_index(
&self, &self,
index: IvfPqIndexBuilder, index: IvfPqIndexBuilder,
@@ -1778,6 +1816,10 @@ impl TableInternal for NativeTable {
Index::Bitmap(_) => self.create_bitmap_index(field, opts).await, Index::Bitmap(_) => self.create_bitmap_index(field, opts).await,
Index::LabelList(_) => self.create_label_list_index(field, opts).await, Index::LabelList(_) => self.create_label_list_index(field, opts).await,
Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await, Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await,
Index::IvfFlat(ivf_flat) => {
self.create_ivf_flat_index(ivf_flat, field, opts.replace)
.await
}
Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await, Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
Index::IvfHnswPq(ivf_hnsw_pq) => { Index::IvfHnswPq(ivf_hnsw_pq) => {
self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace) self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace)
@@ -1848,14 +1890,21 @@ impl TableInternal for NativeTable {
message: format!("Column {} not found in dataset schema", column), message: format!("Column {} not found in dataset schema", column),
})?; })?;
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() { let mut is_binary = false;
if !f.data_type().is_floating() { if let arrow_schema::DataType::FixedSizeList(element, dim) = field.data_type() {
return Err(Error::InvalidInput { match element.data_type() {
message: format!( e_type if e_type.is_floating() => {}
"The data type of the vector column '{}' is not a floating point type", e_type if *e_type == DataType::UInt8 => {
column is_binary = true;
), }
}); _ => {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
} }
if dim != query_vector.len() as i32 { if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput { return Err(Error::InvalidInput {
@@ -1870,12 +1919,22 @@ impl TableInternal for NativeTable {
} }
} }
let query_vector = query_vector.as_primitive::<Float32Type>(); if is_binary {
scanner.nearest( let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?;
&column, let query_vector = query_vector.as_primitive::<UInt8Type>();
query_vector, scanner.nearest(
query.base.limit.unwrap_or(DEFAULT_TOP_K), &column,
)?; query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
}
} }
scanner.limit( scanner.limit(
query.base.limit.map(|limit| limit as i64), query.base.limit.map(|limit| limit as i64),

View File

@@ -110,7 +110,7 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
.iter() .iter()
.filter_map(|field| match field.data_type() { .filter_map(|field| match field.data_type() {
arrow_schema::DataType::FixedSizeList(f, d) arrow_schema::DataType::FixedSizeList(f, d)
if f.data_type().is_floating() if (f.data_type().is_floating() || f.data_type() == &DataType::UInt8)
&& dim.map(|expect| *d == expect).unwrap_or(true) => && dim.map(|expect| *d == expect).unwrap_or(true) =>
{ {
Some(field.name()) Some(field.name())
@@ -171,7 +171,9 @@ pub fn supported_fts_data_type(dtype: &DataType) -> bool {
pub fn supported_vector_data_type(dtype: &DataType) -> bool { pub fn supported_vector_data_type(dtype: &DataType) -> bool {
match dtype { match dtype {
DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()), DataType::FixedSizeList(inner, _) => {
DataType::is_floating(inner.data_type()) || *inner.data_type() == DataType::UInt8
}
_ => false, _ => false,
} }
} }