diff --git a/docs/src/ann_indexes.md b/docs/src/ann_indexes.md index c66c39b9..205bfa50 100644 --- a/docs/src/ann_indexes.md +++ b/docs/src/ann_indexes.md @@ -34,6 +34,7 @@ tbl.create_index(num_partitions=256, num_sub_vectors=96) Since `create_index` has a training step, it can take a few minutes to finish for large tables. You can control the index creation by providing the following parameters: +- **metric** (default: "L2"): The distance metric to use. By default we use euclidean distance. We also support cosine distance. - **num_partitions** (default: 256): The number of partitions of the index. The number of partitions should be configured so each partition has 3-5K vectors. For example, a table with ~1M vectors should use 256 partitions. You can specify arbitrary number of partitions but powers of 2 is most conventional. A higher number leads to faster queries, but it makes index generation slower. @@ -56,6 +57,7 @@ There are a couple of parameters that can be used to fine-tune the search: e.g., for 1M vectors divided into 256 partitions, if you're looking for top 20, then refine_factor=200 reranks the whole partition.
Note: refine_factor is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored. + ```python tbl.search(np.random.random((768))) \ .limit(2) \ diff --git a/notebooks/youtube_transcript_search.ipynb b/notebooks/youtube_transcript_search.ipynb index b3bfd83d..987373e8 100644 --- a/notebooks/youtube_transcript_search.ipynb +++ b/notebooks/youtube_transcript_search.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "id": "42bf01fb", "metadata": {}, @@ -22,10 +21,10 @@ "output_type": "stream", "text": [ "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.0.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" ] } @@ -88,7 +87,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "5ac2b6a3", "metadata": {}, @@ -231,7 +229,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "2106b5bb", "metadata": {}, @@ -251,7 +248,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "39f3161f3ef54a129cd65fb296332b54", + "model_id": "c6f1c76d9567421d88911923388d2530", "version_major": 2, "version_minor": 0 }, @@ -574,7 +571,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "id": "559a095b", "metadata": {}, @@ -631,7 +627,7 @@ " " + "" ] }, "execution_count": 15, @@ -651,7 +647,7 @@ "from IPython.display import YouTubeVideo\n", "\n", "top_match = context.iloc[0]\n", - "YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=top_match[\"start\"])" + "YouTubeVideo(top_match[\"url\"].split(\"/\")[-1], start=int(top_match[\"start\"]))" ] }, { diff --git a/python/lancedb/query.py b/python/lancedb/query.py index fcaf8937..a0411c06 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -24,6 +24,7 @@ class LanceQueryBuilder: """ def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray): + self._metric = "L2" self._nprobes = 20 self._refine_factor = None self._table = table @@ -77,6 +78,21 @@ class LanceQueryBuilder: self._where = where return self + def metric(self, metric: str) -> LanceQueryBuilder: + """Set the distance metric to use. + + Parameters + ---------- + metric: str + The distance metric to use. By default "l2" is used. + + Returns + ------- + The LanceQueryBuilder object. + """ + self._metric = metric + return self + def nprobes(self, nprobes: int) -> LanceQueryBuilder: """Set the number of probes to use. @@ -123,6 +139,7 @@ class LanceQueryBuilder: "column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit, + "metric": self._metric, "nprobes": self._nprobes, "refine_factor": self._refine_factor, }, diff --git a/python/lancedb/table.py b/python/lancedb/table.py index a2eded9c..bb9fb133 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -106,11 +106,14 @@ class LanceTable: def _dataset_uri(self) -> str: return os.path.join(self._conn.uri, f"{self.name}.lance") - def create_index(self, num_partitions=256, num_sub_vectors=96): + def create_index(self, metric="L2", num_partitions=256, num_sub_vectors=96): """Create an index on the table. Parameters ---------- + metric: str, default "L2" + The distance metric to use when creating the index. Valid values are "L2" or "cosine". + L2 is euclidean distance. num_partitions: int The number of IVF partitions to use when creating the index. Default is 256. @@ -121,6 +124,7 @@ class LanceTable: self._dataset.create_index( column=VECTOR_COLUMN_NAME, index_type="IVF_PQ", + metric=metric, num_partitions=num_partitions, num_sub_vectors=num_sub_vectors, ) diff --git a/python/pyproject.toml b/python/pyproject.toml index b30d6494..2884c8ee 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "lancedb" version = "0.1" -dependencies = ["pylance>=0.4.3", "ratelimiter", "retry", "tqdm"] +dependencies = ["pylance>=0.4.4", "ratelimiter", "retry", "tqdm"] description = "lancedb" authors = [ { name = "Lance Devs", email = "dev@eto.ai" }, diff --git a/python/tests/test_query.py b/python/tests/test_query.py index c08cdd8f..9ad7c928 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -14,7 +14,9 @@ import lance from lancedb.query import LanceQueryBuilder +import numpy as np import pandas as pd +import pandas.testing as tm import pyarrow as pa import pytest @@ -60,3 +62,21 @@ def test_query_builder_with_filter(table): df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df() assert df["id"].values[0] == 2 assert all(df["vector"].values[0] == [3, 4]) + + +def test_query_builder_with_metric(table): + query = [4, 8] + df_default = LanceQueryBuilder(table, query).to_df() + df_l2 = LanceQueryBuilder(table, query).metric("l2").to_df() + tm.assert_frame_equal(df_default, df_l2) + + df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df() + assert df_cosine.score[0] == pytest.approx( + cosine_distance(query, df_cosine.vector[0]), + abs=1e-6, + ) + assert 0 <= df_cosine.score[0] <= 1 + + +def cosine_distance(vec1, vec2): + return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))