mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
feat: support IVF_FLAT, binary vectors and hamming distance (#1955)
binary vectors and hamming distance can work on only IVF_FLAT, so introduce them all in this PR. --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -129,8 +129,12 @@ lists the indices that LanceDb supports.
|
|||||||
|
|
||||||
::: lancedb.index.LabelList
|
::: lancedb.index.LabelList
|
||||||
|
|
||||||
|
::: lancedb.index.FTS
|
||||||
|
|
||||||
::: lancedb.index.IvfPq
|
::: lancedb.index.IvfPq
|
||||||
|
|
||||||
|
::: lancedb.index.IvfFlat
|
||||||
|
|
||||||
## Querying (Asynchronous)
|
## Querying (Asynchronous)
|
||||||
|
|
||||||
Queries allow you to return data from your database. Basic queries can be
|
Queries allow you to return data from your database. Basic queries can be
|
||||||
|
|||||||
@@ -13,11 +13,15 @@ A vector search finds the approximate or exact nearest neighbors to a given quer
|
|||||||
Distance metrics are a measure of the similarity between a pair of vectors.
|
Distance metrics are a measure of the similarity between a pair of vectors.
|
||||||
Currently, LanceDB supports the following metrics:
|
Currently, LanceDB supports the following metrics:
|
||||||
|
|
||||||
| Metric | Description |
|
| Metric | Description |
|
||||||
| -------- | --------------------------------------------------------------------------- |
|
| --------- | --------------------------------------------------------------------------- |
|
||||||
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
|
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
|
||||||
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
|
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
|
||||||
| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) |
|
| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) |
|
||||||
|
| `hamming` | [Hamming Distance](https://en.wikipedia.org/wiki/Hamming_distance) |
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
The `hamming` metric is only available for binary vectors.
|
||||||
|
|
||||||
## Exhaustive search (kNN)
|
## Exhaustive search (kNN)
|
||||||
|
|
||||||
@@ -107,6 +111,31 @@ an ANN search means that using an index often involves a trade-off between recal
|
|||||||
See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ`
|
See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ`
|
||||||
indexes work in LanceDB.
|
indexes work in LanceDB.
|
||||||
|
|
||||||
|
## Binary vector
|
||||||
|
|
||||||
|
LanceDB supports binary vectors as a data type, and has the ability to search binary vectors with hamming distance. The binary vectors are stored as uint8 arrays (every 8 bits are stored as a byte):
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
The dim of the binary vector must be a multiple of 8. A vector of dim 128 will be stored as a uint8 array of size 16.
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
=== "sync API"
|
||||||
|
|
||||||
|
```python
|
||||||
|
--8<-- "python/python/tests/docs/test_binary_vector.py:imports"
|
||||||
|
|
||||||
|
--8<-- "python/python/tests/docs/test_binary_vector.py:sync_binary_vector"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "async API"
|
||||||
|
|
||||||
|
```python
|
||||||
|
--8<-- "python/python/tests/docs/test_binary_vector.py:imports"
|
||||||
|
|
||||||
|
--8<-- "python/python/tests/docs/test_binary_vector.py:async_binary_vector"
|
||||||
|
```
|
||||||
|
|
||||||
## Output search results
|
## Output search results
|
||||||
|
|
||||||
LanceDB returns vector search results via different formats commonly used in python.
|
LanceDB returns vector search results via different formats commonly used in python.
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ excluded_globs = [
|
|||||||
"../src/concepts/*.md",
|
"../src/concepts/*.md",
|
||||||
"../src/ann_indexes.md",
|
"../src/ann_indexes.md",
|
||||||
"../src/basic.md",
|
"../src/basic.md",
|
||||||
|
"../src/search.md",
|
||||||
"../src/hybrid_search/hybrid_search.md",
|
"../src/hybrid_search/hybrid_search.md",
|
||||||
"../src/reranking/*.md",
|
"../src/reranking/*.md",
|
||||||
"../src/guides/tuning_retrievers/*.md",
|
"../src/guides/tuning_retrievers/*.md",
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<Dista
|
|||||||
"l2" => Ok(DistanceType::L2),
|
"l2" => Ok(DistanceType::L2),
|
||||||
"cosine" => Ok(DistanceType::Cosine),
|
"cosine" => Ok(DistanceType::Cosine),
|
||||||
"dot" => Ok(DistanceType::Dot),
|
"dot" => Ok(DistanceType::Dot),
|
||||||
|
"hamming" => Ok(DistanceType::Hamming),
|
||||||
_ => Err(napi::Error::from_reason(format!(
|
_ => Err(napi::Error::from_reason(format!(
|
||||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
|
||||||
distance_type.as_ref()
|
distance_type.as_ref()
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -355,6 +355,97 @@ class HnswSq:
|
|||||||
ef_construction: int = 300
|
ef_construction: int = 300
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IvfFlat:
|
||||||
|
"""Describes an IVF Flat Index
|
||||||
|
|
||||||
|
This index stores raw vectors.
|
||||||
|
These vectors are grouped into partitions of similar vectors.
|
||||||
|
Each partition keeps track of a centroid which is
|
||||||
|
the average value of all vectors in the group.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
distance_type: str, default "L2"
|
||||||
|
The distance metric used to train the index
|
||||||
|
|
||||||
|
This is used when training the index to calculate the IVF partitions
|
||||||
|
(vectors are grouped in partitions with similar vectors according to this
|
||||||
|
distance type) and to calculate a subvector's code during quantization.
|
||||||
|
|
||||||
|
The distance type used to train an index MUST match the distance type used
|
||||||
|
to search the index. Failure to do so will yield inaccurate results.
|
||||||
|
|
||||||
|
The following distance types are available:
|
||||||
|
|
||||||
|
"l2" - Euclidean distance. This is a very common distance metric that
|
||||||
|
accounts for both magnitude and direction when determining the distance
|
||||||
|
between vectors. L2 distance has a range of [0, ∞).
|
||||||
|
|
||||||
|
"cosine" - Cosine distance. Cosine distance is a distance metric
|
||||||
|
calculated from the cosine similarity between two vectors. Cosine
|
||||||
|
similarity is a measure of similarity between two non-zero vectors of an
|
||||||
|
inner product space. It is defined to equal the cosine of the angle
|
||||||
|
between them. Unlike L2, the cosine distance is not affected by the
|
||||||
|
magnitude of the vectors. Cosine distance has a range of [0, 2].
|
||||||
|
|
||||||
|
Note: the cosine distance is undefined when one (or both) of the vectors
|
||||||
|
are all zeros (there is no direction). These vectors are invalid and may
|
||||||
|
never be returned from a vector search.
|
||||||
|
|
||||||
|
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
|
||||||
|
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
|
||||||
|
L2 norm is 1), then dot distance is equivalent to the cosine distance.
|
||||||
|
|
||||||
|
"hamming" - Hamming distance. Hamming distance is a distance metric
|
||||||
|
calculated as the number of positions at which the corresponding bits are
|
||||||
|
different. Hamming distance has a range of [0, vector dimension].
|
||||||
|
|
||||||
|
num_partitions: int, default sqrt(num_rows)
|
||||||
|
The number of IVF partitions to create.
|
||||||
|
|
||||||
|
This value should generally scale with the number of rows in the dataset.
|
||||||
|
By default the number of partitions is the square root of the number of
|
||||||
|
rows.
|
||||||
|
|
||||||
|
If this value is too large then the first part of the search (picking the
|
||||||
|
right partition) will be slow. If this value is too small then the second
|
||||||
|
part of the search (searching within a partition) will be slow.
|
||||||
|
|
||||||
|
max_iterations: int, default 50
|
||||||
|
Max iteration to train kmeans.
|
||||||
|
|
||||||
|
When training an IVF PQ index we use kmeans to calculate the partitions.
|
||||||
|
This parameter controls how many iterations of kmeans to run.
|
||||||
|
|
||||||
|
Increasing this might improve the quality of the index but in most cases
|
||||||
|
these extra iterations have diminishing returns.
|
||||||
|
|
||||||
|
The default value is 50.
|
||||||
|
sample_rate: int, default 256
|
||||||
|
The rate used to calculate the number of training vectors for kmeans.
|
||||||
|
|
||||||
|
When an IVF PQ index is trained, we need to calculate partitions. These
|
||||||
|
are groups of vectors that are similar to each other. To do this we use an
|
||||||
|
algorithm called kmeans.
|
||||||
|
|
||||||
|
Running kmeans on a large dataset can be slow. To speed this up we run
|
||||||
|
kmeans on a random sample of the data. This parameter controls the size of
|
||||||
|
the sample. The total number of vectors used to train the index is
|
||||||
|
`sample_rate * num_partitions`.
|
||||||
|
|
||||||
|
Increasing this value might improve the quality of the index but in most
|
||||||
|
cases the default should be sufficient.
|
||||||
|
|
||||||
|
The default value is 256.
|
||||||
|
"""
|
||||||
|
|
||||||
|
distance_type: Literal["l2", "cosine", "dot", "hamming"] = "l2"
|
||||||
|
num_partitions: Optional[int] = None
|
||||||
|
max_iterations: int = 50
|
||||||
|
sample_rate: int = 256
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IvfPq:
|
class IvfPq:
|
||||||
"""Describes an IVF PQ Index
|
"""Describes an IVF PQ Index
|
||||||
@@ -477,4 +568,4 @@ class IvfPq:
|
|||||||
sample_rate: int = 256
|
sample_rate: int = 256
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["BTree", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"]
|
__all__ = ["BTree", "IvfFlat", "IvfPq", "HnswPq", "HnswSq", "IndexConfig"]
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from lance.dependencies import _check_for_hugging_face
|
|||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
from .index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
|
from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
|
||||||
from .merge import LanceMergeInsertBuilder
|
from .merge import LanceMergeInsertBuilder
|
||||||
from .pydantic import LanceModel, model_to_dict
|
from .pydantic import LanceModel, model_to_dict
|
||||||
from .query import (
|
from .query import (
|
||||||
@@ -433,7 +433,9 @@ class Table(ABC):
|
|||||||
accelerator: Optional[str] = None,
|
accelerator: Optional[str] = None,
|
||||||
index_cache_size: Optional[int] = None,
|
index_cache_size: Optional[int] = None,
|
||||||
*,
|
*,
|
||||||
index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ",
|
index_type: Literal[
|
||||||
|
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||||
|
] = "IVF_PQ",
|
||||||
num_bits: int = 8,
|
num_bits: int = 8,
|
||||||
max_iterations: int = 50,
|
max_iterations: int = 50,
|
||||||
sample_rate: int = 256,
|
sample_rate: int = 256,
|
||||||
@@ -446,8 +448,9 @@ class Table(ABC):
|
|||||||
----------
|
----------
|
||||||
metric: str, default "L2"
|
metric: str, default "L2"
|
||||||
The distance metric to use when creating the index.
|
The distance metric to use when creating the index.
|
||||||
Valid values are "L2", "cosine", or "dot".
|
Valid values are "L2", "cosine", "dot", or "hamming".
|
||||||
L2 is euclidean distance.
|
L2 is euclidean distance.
|
||||||
|
Hamming is available only for binary vectors.
|
||||||
num_partitions: int, default 256
|
num_partitions: int, default 256
|
||||||
The number of IVF partitions to use when creating the index.
|
The number of IVF partitions to use when creating the index.
|
||||||
Default is 256.
|
Default is 256.
|
||||||
@@ -1408,7 +1411,9 @@ class LanceTable(Table):
|
|||||||
accelerator: Optional[str] = None,
|
accelerator: Optional[str] = None,
|
||||||
index_cache_size: Optional[int] = None,
|
index_cache_size: Optional[int] = None,
|
||||||
num_bits: int = 8,
|
num_bits: int = 8,
|
||||||
index_type: Literal["IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"] = "IVF_PQ",
|
index_type: Literal[
|
||||||
|
"IVF_FLAT", "IVF_PQ", "IVF_HNSW_SQ", "IVF_HNSW_PQ"
|
||||||
|
] = "IVF_PQ",
|
||||||
max_iterations: int = 50,
|
max_iterations: int = 50,
|
||||||
sample_rate: int = 256,
|
sample_rate: int = 256,
|
||||||
m: int = 20,
|
m: int = 20,
|
||||||
@@ -1432,6 +1437,13 @@ class LanceTable(Table):
|
|||||||
)
|
)
|
||||||
self.checkout_latest()
|
self.checkout_latest()
|
||||||
return
|
return
|
||||||
|
elif index_type == "IVF_FLAT":
|
||||||
|
config = IvfFlat(
|
||||||
|
distance_type=metric,
|
||||||
|
num_partitions=num_partitions,
|
||||||
|
max_iterations=max_iterations,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
)
|
||||||
elif index_type == "IVF_PQ":
|
elif index_type == "IVF_PQ":
|
||||||
config = IvfPq(
|
config = IvfPq(
|
||||||
distance_type=metric,
|
distance_type=metric,
|
||||||
@@ -2619,7 +2631,7 @@ class AsyncTable:
|
|||||||
*,
|
*,
|
||||||
replace: Optional[bool] = None,
|
replace: Optional[bool] = None,
|
||||||
config: Optional[
|
config: Optional[
|
||||||
Union[IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
|
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
|
||||||
] = None,
|
] = None,
|
||||||
):
|
):
|
||||||
"""Create an index to speed up queries
|
"""Create an index to speed up queries
|
||||||
@@ -2648,7 +2660,7 @@ class AsyncTable:
|
|||||||
"""
|
"""
|
||||||
if config is not None:
|
if config is not None:
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
config, (IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
|
config, (IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS)
|
||||||
):
|
):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"config must be an instance of IvfPq, HnswPq, HnswSq, BTree,"
|
"config must be an instance of IvfPq, HnswPq, HnswSq, BTree,"
|
||||||
|
|||||||
44
python/python/tests/docs/test_binary_vector.py
Normal file
44
python/python/tests/docs/test_binary_vector.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import shutil
|
||||||
|
|
||||||
|
# --8<-- [start:imports]
|
||||||
|
import lancedb
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
# --8<-- [end:imports]
|
||||||
|
|
||||||
|
shutil.rmtree("data/binary_lancedb", ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_binary_vector():
|
||||||
|
# --8<-- [start:sync_binary_vector]
|
||||||
|
db = lancedb.connect("data/binary_lancedb")
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"id": i,
|
||||||
|
"vector": np.random.randint(0, 256, size=16),
|
||||||
|
}
|
||||||
|
for i in range(1024)
|
||||||
|
]
|
||||||
|
tbl = db.create_table("my_binary_vectors", data=data)
|
||||||
|
query = np.random.randint(0, 256, size=16)
|
||||||
|
tbl.search(query).to_arrow()
|
||||||
|
# --8<-- [end:sync_binary_vector]
|
||||||
|
db.drop_table("my_binary_vectors")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_binary_vector_async():
|
||||||
|
# --8<-- [start:async_binary_vector]
|
||||||
|
db = await lancedb.connect_async("data/binary_lancedb")
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"id": i,
|
||||||
|
"vector": np.random.randint(0, 256, size=16),
|
||||||
|
}
|
||||||
|
for i in range(1024)
|
||||||
|
]
|
||||||
|
tbl = await db.create_table("my_binary_vectors", data=data)
|
||||||
|
query = np.random.randint(0, 256, size=16)
|
||||||
|
await tbl.query().nearest_to(query).to_arrow()
|
||||||
|
# --8<-- [end:async_binary_vector]
|
||||||
|
await db.drop_table("my_binary_vectors")
|
||||||
@@ -8,7 +8,7 @@ import pyarrow as pa
|
|||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from lancedb import AsyncConnection, AsyncTable, connect_async
|
from lancedb import AsyncConnection, AsyncTable, connect_async
|
||||||
from lancedb.index import BTree, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
|
from lancedb.index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
@@ -42,6 +42,27 @@ async def some_table(db_async):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def binary_table(db_async):
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"id": i,
|
||||||
|
"vector": [i] * 128,
|
||||||
|
}
|
||||||
|
for i in range(NROWS)
|
||||||
|
]
|
||||||
|
return await db_async.create_table(
|
||||||
|
"binary_table",
|
||||||
|
data,
|
||||||
|
schema=pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.int64()),
|
||||||
|
pa.field("vector", pa.list_(pa.uint8(), 128)),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_scalar_index(some_table: AsyncTable):
|
async def test_create_scalar_index(some_table: AsyncTable):
|
||||||
# Can create
|
# Can create
|
||||||
@@ -143,3 +164,27 @@ async def test_create_hnswsq_index(some_table: AsyncTable):
|
|||||||
await some_table.create_index("vector", config=HnswSq(num_partitions=10))
|
await some_table.create_index("vector", config=HnswSq(num_partitions=10))
|
||||||
indices = await some_table.list_indices()
|
indices = await some_table.list_indices()
|
||||||
assert len(indices) == 1
|
assert len(indices) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_index_with_binary_vectors(binary_table: AsyncTable):
|
||||||
|
await binary_table.create_index(
|
||||||
|
"vector", config=IvfFlat(distance_type="hamming", num_partitions=10)
|
||||||
|
)
|
||||||
|
indices = await binary_table.list_indices()
|
||||||
|
assert len(indices) == 1
|
||||||
|
assert indices[0].index_type == "IvfFlat"
|
||||||
|
assert indices[0].columns == ["vector"]
|
||||||
|
assert indices[0].name == "vector_idx"
|
||||||
|
|
||||||
|
stats = await binary_table.index_stats("vector_idx")
|
||||||
|
assert stats.index_type == "IVF_FLAT"
|
||||||
|
assert stats.distance_type == "hamming"
|
||||||
|
assert stats.num_indexed_rows == await binary_table.count_rows()
|
||||||
|
assert stats.num_unindexed_rows == 0
|
||||||
|
assert stats.num_indices == 1
|
||||||
|
|
||||||
|
# the dataset contains vectors with all values from 0 to 255
|
||||||
|
for v in range(256):
|
||||||
|
res = await binary_table.query().nearest_to([v] * 128).to_arrow()
|
||||||
|
assert res["id"][0].as_py() == v
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use lancedb::index::vector::IvfFlatIndexBuilder;
|
||||||
use lancedb::index::{
|
use lancedb::index::{
|
||||||
scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
|
scalar::{BTreeIndexBuilder, FtsIndexBuilder, TokenizerConfig},
|
||||||
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder, IvfPqIndexBuilder},
|
||||||
@@ -59,6 +60,18 @@ pub fn extract_index_params(source: &Option<Bound<'_, PyAny>>) -> PyResult<Lance
|
|||||||
opts.tokenizer_configs = inner_opts;
|
opts.tokenizer_configs = inner_opts;
|
||||||
Ok(LanceDbIndex::FTS(opts))
|
Ok(LanceDbIndex::FTS(opts))
|
||||||
},
|
},
|
||||||
|
"IvfFlat" => {
|
||||||
|
let params = source.extract::<IvfFlatParams>()?;
|
||||||
|
let distance_type = parse_distance_type(params.distance_type)?;
|
||||||
|
let mut ivf_flat_builder = IvfFlatIndexBuilder::default()
|
||||||
|
.distance_type(distance_type)
|
||||||
|
.max_iterations(params.max_iterations)
|
||||||
|
.sample_rate(params.sample_rate);
|
||||||
|
if let Some(num_partitions) = params.num_partitions {
|
||||||
|
ivf_flat_builder = ivf_flat_builder.num_partitions(num_partitions);
|
||||||
|
}
|
||||||
|
Ok(LanceDbIndex::IvfFlat(ivf_flat_builder))
|
||||||
|
},
|
||||||
"IvfPq" => {
|
"IvfPq" => {
|
||||||
let params = source.extract::<IvfPqParams>()?;
|
let params = source.extract::<IvfPqParams>()?;
|
||||||
let distance_type = parse_distance_type(params.distance_type)?;
|
let distance_type = parse_distance_type(params.distance_type)?;
|
||||||
@@ -129,6 +142,14 @@ struct FtsParams {
|
|||||||
ascii_folding: bool,
|
ascii_folding: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(FromPyObject)]
|
||||||
|
struct IvfFlatParams {
|
||||||
|
distance_type: String,
|
||||||
|
num_partitions: Option<u32>,
|
||||||
|
max_iterations: u32,
|
||||||
|
sample_rate: u32,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(FromPyObject)]
|
#[derive(FromPyObject)]
|
||||||
struct IvfPqParams {
|
struct IvfPqParams {
|
||||||
distance_type: String,
|
distance_type: String,
|
||||||
|
|||||||
@@ -43,8 +43,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
|
|||||||
"l2" => Ok(DistanceType::L2),
|
"l2" => Ok(DistanceType::L2),
|
||||||
"cosine" => Ok(DistanceType::Cosine),
|
"cosine" => Ok(DistanceType::Cosine),
|
||||||
"dot" => Ok(DistanceType::Dot),
|
"dot" => Ok(DistanceType::Dot),
|
||||||
|
"hamming" => Ok(DistanceType::Hamming),
|
||||||
_ => Err(PyValueError::new_err(format!(
|
_ => Err(PyValueError::new_err(format!(
|
||||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
"Invalid distance type '{}'. Must be one of l2, cosine, dot, or hamming",
|
||||||
distance_type.as_ref()
|
distance_type.as_ref()
|
||||||
))),
|
))),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ use std::sync::Arc;
|
|||||||
use scalar::FtsIndexBuilder;
|
use scalar::FtsIndexBuilder;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_with::skip_serializing_none;
|
use serde_with::skip_serializing_none;
|
||||||
|
use vector::IvfFlatIndexBuilder;
|
||||||
|
|
||||||
use crate::{table::TableInternal, DistanceType, Error, Result};
|
use crate::{table::TableInternal, DistanceType, Error, Result};
|
||||||
|
|
||||||
@@ -56,6 +57,9 @@ pub enum Index {
|
|||||||
/// Full text search index using bm25.
|
/// Full text search index using bm25.
|
||||||
FTS(FtsIndexBuilder),
|
FTS(FtsIndexBuilder),
|
||||||
|
|
||||||
|
/// IVF index
|
||||||
|
IvfFlat(IvfFlatIndexBuilder),
|
||||||
|
|
||||||
/// IVF index with Product Quantization
|
/// IVF index with Product Quantization
|
||||||
IvfPq(IvfPqIndexBuilder),
|
IvfPq(IvfPqIndexBuilder),
|
||||||
|
|
||||||
@@ -106,6 +110,8 @@ impl IndexBuilder {
|
|||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
pub enum IndexType {
|
pub enum IndexType {
|
||||||
// Vector
|
// Vector
|
||||||
|
#[serde(alias = "IVF_FLAT")]
|
||||||
|
IvfFlat,
|
||||||
#[serde(alias = "IVF_PQ")]
|
#[serde(alias = "IVF_PQ")]
|
||||||
IvfPq,
|
IvfPq,
|
||||||
#[serde(alias = "IVF_HNSW_PQ")]
|
#[serde(alias = "IVF_HNSW_PQ")]
|
||||||
@@ -127,6 +133,7 @@ pub enum IndexType {
|
|||||||
impl std::fmt::Display for IndexType {
|
impl std::fmt::Display for IndexType {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
Self::IvfFlat => write!(f, "IVF_FLAT"),
|
||||||
Self::IvfPq => write!(f, "IVF_PQ"),
|
Self::IvfPq => write!(f, "IVF_PQ"),
|
||||||
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
|
Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"),
|
||||||
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
|
Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"),
|
||||||
@@ -147,6 +154,7 @@ impl std::str::FromStr for IndexType {
|
|||||||
"BITMAP" => Ok(Self::Bitmap),
|
"BITMAP" => Ok(Self::Bitmap),
|
||||||
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
|
"LABEL_LIST" | "LABELLIST" => Ok(Self::LabelList),
|
||||||
"FTS" | "INVERTED" => Ok(Self::FTS),
|
"FTS" | "INVERTED" => Ok(Self::FTS),
|
||||||
|
"IVF_FLAT" => Ok(Self::IvfFlat),
|
||||||
"IVF_PQ" => Ok(Self::IvfPq),
|
"IVF_PQ" => Ok(Self::IvfPq),
|
||||||
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
|
"IVF_HNSW_PQ" => Ok(Self::IvfHnswPq),
|
||||||
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
|
"IVF_HNSW_SQ" => Ok(Self::IvfHnswSq),
|
||||||
|
|||||||
@@ -162,6 +162,43 @@ macro_rules! impl_hnsw_params_setter {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builder for an IVF Flat index.
|
||||||
|
///
|
||||||
|
/// This index stores raw vectors. These vectors are grouped into partitions of similar vectors.
|
||||||
|
/// Each partition keeps track of a centroid which is the average value of all vectors in the group.
|
||||||
|
///
|
||||||
|
/// During a query the centroids are compared with the query vector to find the closest partitions.
|
||||||
|
/// The raw vectors in these partitions are then searched to find the closest vectors.
|
||||||
|
///
|
||||||
|
/// The partitioning process is called IVF and the `num_partitions` parameter controls how many groups to create.
|
||||||
|
///
|
||||||
|
/// Note that training an IVF Flat index on a large dataset is a slow operation and currently is also a memory intensive operation.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct IvfFlatIndexBuilder {
|
||||||
|
pub(crate) distance_type: DistanceType,
|
||||||
|
|
||||||
|
// IVF
|
||||||
|
pub(crate) num_partitions: Option<u32>,
|
||||||
|
pub(crate) sample_rate: u32,
|
||||||
|
pub(crate) max_iterations: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for IvfFlatIndexBuilder {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
distance_type: DistanceType::L2,
|
||||||
|
num_partitions: None,
|
||||||
|
sample_rate: 256,
|
||||||
|
max_iterations: 50,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IvfFlatIndexBuilder {
|
||||||
|
impl_distance_type_setter!();
|
||||||
|
impl_ivf_params_setter!();
|
||||||
|
}
|
||||||
|
|
||||||
/// Builder for an IVF PQ index.
|
/// Builder for an IVF PQ index.
|
||||||
///
|
///
|
||||||
/// This index stores a compressed (quantized) copy of every vector. These vectors
|
/// This index stores a compressed (quantized) copy of every vector. These vectors
|
||||||
|
|||||||
@@ -18,9 +18,9 @@ use std::path::Path;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use arrow::array::AsArray;
|
use arrow::array::AsArray;
|
||||||
use arrow::datatypes::Float32Type;
|
use arrow::datatypes::{Float32Type, UInt8Type};
|
||||||
use arrow_array::{RecordBatchIterator, RecordBatchReader};
|
use arrow_array::{RecordBatchIterator, RecordBatchReader};
|
||||||
use arrow_schema::{Field, Schema, SchemaRef};
|
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use datafusion_physical_plan::display::DisplayableExecutionPlan;
|
use datafusion_physical_plan::display::DisplayableExecutionPlan;
|
||||||
use datafusion_physical_plan::projection::ProjectionExec;
|
use datafusion_physical_plan::projection::ProjectionExec;
|
||||||
@@ -58,8 +58,8 @@ use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, M
|
|||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
use crate::index::scalar::FtsIndexBuilder;
|
use crate::index::scalar::FtsIndexBuilder;
|
||||||
use crate::index::vector::{
|
use crate::index::vector::{
|
||||||
suggested_num_partitions_for_hnsw, IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder,
|
suggested_num_partitions_for_hnsw, IvfFlatIndexBuilder, IvfHnswPqIndexBuilder,
|
||||||
IvfPqIndexBuilder, VectorIndex,
|
IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex,
|
||||||
};
|
};
|
||||||
use crate::index::IndexStatistics;
|
use crate::index::IndexStatistics;
|
||||||
use crate::index::{
|
use crate::index::{
|
||||||
@@ -1306,6 +1306,44 @@ impl NativeTable {
|
|||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn create_ivf_flat_index(
|
||||||
|
&self,
|
||||||
|
index: IvfFlatIndexBuilder,
|
||||||
|
field: &Field,
|
||||||
|
replace: bool,
|
||||||
|
) -> Result<()> {
|
||||||
|
if !supported_vector_data_type(field.data_type()) {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message: format!(
|
||||||
|
"An IVF Flat index cannot be created on the column `{}` which has data type {}",
|
||||||
|
field.name(),
|
||||||
|
field.data_type()
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let num_partitions = if let Some(n) = index.num_partitions {
|
||||||
|
n
|
||||||
|
} else {
|
||||||
|
suggested_num_partitions(self.count_rows(None).await?)
|
||||||
|
};
|
||||||
|
let mut dataset = self.dataset.get_mut().await?;
|
||||||
|
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_flat(
|
||||||
|
num_partitions as usize,
|
||||||
|
index.distance_type.into(),
|
||||||
|
);
|
||||||
|
dataset
|
||||||
|
.create_index(
|
||||||
|
&[field.name()],
|
||||||
|
IndexType::Vector,
|
||||||
|
None,
|
||||||
|
&lance_idx_params,
|
||||||
|
replace,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn create_ivf_pq_index(
|
async fn create_ivf_pq_index(
|
||||||
&self,
|
&self,
|
||||||
index: IvfPqIndexBuilder,
|
index: IvfPqIndexBuilder,
|
||||||
@@ -1778,6 +1816,10 @@ impl TableInternal for NativeTable {
|
|||||||
Index::Bitmap(_) => self.create_bitmap_index(field, opts).await,
|
Index::Bitmap(_) => self.create_bitmap_index(field, opts).await,
|
||||||
Index::LabelList(_) => self.create_label_list_index(field, opts).await,
|
Index::LabelList(_) => self.create_label_list_index(field, opts).await,
|
||||||
Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await,
|
Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await,
|
||||||
|
Index::IvfFlat(ivf_flat) => {
|
||||||
|
self.create_ivf_flat_index(ivf_flat, field, opts.replace)
|
||||||
|
.await
|
||||||
|
}
|
||||||
Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
|
Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
|
||||||
Index::IvfHnswPq(ivf_hnsw_pq) => {
|
Index::IvfHnswPq(ivf_hnsw_pq) => {
|
||||||
self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace)
|
self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace)
|
||||||
@@ -1848,14 +1890,21 @@ impl TableInternal for NativeTable {
|
|||||||
message: format!("Column {} not found in dataset schema", column),
|
message: format!("Column {} not found in dataset schema", column),
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
|
let mut is_binary = false;
|
||||||
if !f.data_type().is_floating() {
|
if let arrow_schema::DataType::FixedSizeList(element, dim) = field.data_type() {
|
||||||
return Err(Error::InvalidInput {
|
match element.data_type() {
|
||||||
message: format!(
|
e_type if e_type.is_floating() => {}
|
||||||
"The data type of the vector column '{}' is not a floating point type",
|
e_type if *e_type == DataType::UInt8 => {
|
||||||
column
|
is_binary = true;
|
||||||
),
|
}
|
||||||
});
|
_ => {
|
||||||
|
return Err(Error::InvalidInput {
|
||||||
|
message: format!(
|
||||||
|
"The data type of the vector column '{}' is not a floating point type",
|
||||||
|
column
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if dim != query_vector.len() as i32 {
|
if dim != query_vector.len() as i32 {
|
||||||
return Err(Error::InvalidInput {
|
return Err(Error::InvalidInput {
|
||||||
@@ -1870,12 +1919,22 @@ impl TableInternal for NativeTable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let query_vector = query_vector.as_primitive::<Float32Type>();
|
if is_binary {
|
||||||
scanner.nearest(
|
let query_vector = arrow::compute::cast(&query_vector, &DataType::UInt8)?;
|
||||||
&column,
|
let query_vector = query_vector.as_primitive::<UInt8Type>();
|
||||||
query_vector,
|
scanner.nearest(
|
||||||
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
&column,
|
||||||
)?;
|
query_vector,
|
||||||
|
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||||
|
)?;
|
||||||
|
} else {
|
||||||
|
let query_vector = query_vector.as_primitive::<Float32Type>();
|
||||||
|
scanner.nearest(
|
||||||
|
&column,
|
||||||
|
query_vector,
|
||||||
|
query.base.limit.unwrap_or(DEFAULT_TOP_K),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
scanner.limit(
|
scanner.limit(
|
||||||
query.base.limit.map(|limit| limit as i64),
|
query.base.limit.map(|limit| limit as i64),
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
|
|||||||
.iter()
|
.iter()
|
||||||
.filter_map(|field| match field.data_type() {
|
.filter_map(|field| match field.data_type() {
|
||||||
arrow_schema::DataType::FixedSizeList(f, d)
|
arrow_schema::DataType::FixedSizeList(f, d)
|
||||||
if f.data_type().is_floating()
|
if (f.data_type().is_floating() || f.data_type() == &DataType::UInt8)
|
||||||
&& dim.map(|expect| *d == expect).unwrap_or(true) =>
|
&& dim.map(|expect| *d == expect).unwrap_or(true) =>
|
||||||
{
|
{
|
||||||
Some(field.name())
|
Some(field.name())
|
||||||
@@ -171,7 +171,9 @@ pub fn supported_fts_data_type(dtype: &DataType) -> bool {
|
|||||||
|
|
||||||
pub fn supported_vector_data_type(dtype: &DataType) -> bool {
|
pub fn supported_vector_data_type(dtype: &DataType) -> bool {
|
||||||
match dtype {
|
match dtype {
|
||||||
DataType::FixedSizeList(inner, _) => DataType::is_floating(inner.data_type()),
|
DataType::FixedSizeList(inner, _) => {
|
||||||
|
DataType::is_floating(inner.data_type()) || *inner.data_type() == DataType::UInt8
|
||||||
|
}
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user