diff --git a/docs/src/ann_indexes.md b/docs/src/ann_indexes.md index 96b93e2b..be98cc45 100644 --- a/docs/src/ann_indexes.md +++ b/docs/src/ann_indexes.md @@ -28,11 +28,12 @@ tbl.create_index(num_partitions=256, num_sub_vectors=96) Since `create_index` has a training step, it can take a few minutes to finish for large tables. You can control the index creation by providing the following parameters: -- **num_partitions** (default: 256): The number of partitions of the index. The number of partitions should be configured so each partition has 3-5K vectors. For example, a table -with ~1M vectors should use 256 partitions. You can specify arbitrary number of partitions but powers of 2 is most conventional. -A higher number leads to faster queries, but it makes index generation slower. +- **metric** (default: "L2"): The distance metric to use. By default we use Euclidean distance. We also support cosine distance. +- **num_partitions** (default: 256): The number of partitions of the index. The number of partitions should be configured so each partition has 3-5K vectors. For example, a table +with ~1M vectors should use 256 partitions. You can specify an arbitrary number of partitions, but powers of 2 are most conventional. +A higher number leads to faster queries, but it makes index generation slower. - **num_sub_vectors** (default: 96): The number of subvectors (M) that will be created during Product Quantization (PQ). A larger number makes -search more accurate, but also makes the index larger and slower to build. +search more accurate, but also makes the index larger and slower to build. ## Querying an ANN Index @@ -41,8 +42,9 @@ Querying vector indexes is done via the [search](https://lancedb.github.io/lance There are a couple of parameters that can be used to fine-tune the search: - **limit** (default: 10): The amount of results that will be returned +- **metric** (default: "L2"): The distance metric to use. By default we use Euclidean distance. 
We also support cosine distance. - **nprobes** (default: 20): The number of probes used. A higher number makes search more accurate but also slower. -- **refine_factor** (default: None): Refine the results by reading extra elements and re-ranking them in memory. A higher number makes +- **refine_factor** (default: None): Refine the results by reading extra elements and re-ranking them in memory. A higher number makes search more accurate but also slower. ```python diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 21333bec..1adb8ccb 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -24,6 +24,7 @@ class LanceQueryBuilder: """ def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray): + self._metric = "L2" self._nprobes = 20 self._refine_factor = None self._table = table @@ -77,6 +78,21 @@ class LanceQueryBuilder: self._where = where return self + def metric(self, metric: str) -> LanceQueryBuilder: + """Set the distance metric to use. + + Parameters + ---------- + metric: str + The distance metric to use, either "L2" (Euclidean) or "cosine". By default "L2" is used. + + Returns + ------- + The LanceQueryBuilder object. + """ + self._metric = metric + return self + def nprobes(self, nprobes: int) -> LanceQueryBuilder: """Set the number of probes to use. @@ -118,6 +134,7 @@ class LanceQueryBuilder: "column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit, + "metric": self._metric, "nprobes": self._nprobes, "refine_factor": self._refine_factor, }, diff --git a/python/lancedb/table.py b/python/lancedb/table.py index f798fb37..f633cce5 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -106,11 +106,14 @@ class LanceTable: def _dataset_uri(self) -> str: return os.path.join(self._conn.uri, f"{self.name}.lance") - def create_index(self, num_partitions=256, num_sub_vectors=96): + def create_index(self, metric="L2", num_partitions=256, num_sub_vectors=96): """Create an index on the table. 
Parameters ---------- + metric: str, default "L2" + The distance metric to use when creating the index. Valid values are "L2" or "cosine". + L2 is Euclidean distance. num_partitions: int The number of IVF partitions to use when creating the index. Default is 256. @@ -121,6 +124,7 @@ class LanceTable: self._dataset.create_index( column=VECTOR_COLUMN_NAME, index_type="IVF_PQ", + metric=metric, num_partitions=num_partitions, num_sub_vectors=num_sub_vectors, ) diff --git a/python/pyproject.toml b/python/pyproject.toml index b30d6494..2884c8ee 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "lancedb" version = "0.1" -dependencies = ["pylance>=0.4.3", "ratelimiter", "retry", "tqdm"] +dependencies = ["pylance>=0.4.4", "ratelimiter", "retry", "tqdm"] description = "lancedb" authors = [ { name = "Lance Devs", email = "dev@eto.ai" }, diff --git a/python/tests/test_query.py b/python/tests/test_query.py index c08cdd8f..9ad7c928 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -14,7 +14,9 @@ import lance from lancedb.query import LanceQueryBuilder +import numpy as np import pandas as pd +import pandas.testing as tm import pyarrow as pa import pytest @@ -60,3 +62,21 @@ def test_query_builder_with_filter(table): df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df() assert df["id"].values[0] == 2 assert all(df["vector"].values[0] == [3, 4]) + + +def test_query_builder_with_metric(table): + query = [4, 8] + df_default = LanceQueryBuilder(table, query).to_df() + df_l2 = LanceQueryBuilder(table, query).metric("l2").to_df() + tm.assert_frame_equal(df_default, df_l2) + + df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df() + assert df_cosine.score[0] == pytest.approx( + cosine_distance(query, df_cosine.vector[0]), + abs=1e-6, + ) + assert 0 <= df_cosine.score[0] <= 1 + + +def cosine_distance(vec1, vec2): + return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * 
np.linalg.norm(vec2))