feat(python): add search() method to async API (#2049)

Reviving #1966.

Closes #1938

The `search()` method can apply embeddings for the user. This simplifies
hybrid search, so instead of writing:

```python
vector_query = embeddings.compute_query_embeddings("flower moon")[0]
await (
    async_tbl.query()
    .nearest_to(vector_query)
    .nearest_to_text("flower moon")
    .to_pandas()
)
```

You can write:

```python
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
```

Unfortunately, we had to do a double-await here because `search()` needs
to be async. This is because it often needs to do IO to retrieve and run
an embedding function.
This commit is contained in:
Will Jones
2025-02-24 14:19:25 -08:00
committed by GitHub
parent f391ed828a
commit ecdee4d2b1
10 changed files with 461 additions and 57 deletions

View File

@@ -9,23 +9,50 @@ LanceDB supports [Polars](https://github.com/pola-rs/polars), a blazingly fast D
First, we connect to a LanceDB database.
=== "Sync API"
```py
--8<-- "python/python/tests/docs/test_python.py:import-lancedb"
--8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
```
=== "Async API"
```py
--8<-- "python/python/tests/docs/test_python.py:import-lancedb"
--8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb_async"
```
```py
--8<-- "python/python/tests/docs/test_python.py:import-lancedb"
--8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
```
We can load a Polars `DataFrame` to LanceDB directly.
```py
--8<-- "python/python/tests/docs/test_python.py:import-polars"
--8<-- "python/python/tests/docs/test_python.py:create_table_polars"
```
=== "Sync API"
```py
--8<-- "python/python/tests/docs/test_python.py:import-polars"
--8<-- "python/python/tests/docs/test_python.py:create_table_polars"
```
=== "Async API"
```py
--8<-- "python/python/tests/docs/test_python.py:import-polars"
--8<-- "python/python/tests/docs/test_python.py:create_table_polars_async"
```
We can now perform similarity search via the LanceDB Python API.
```py
--8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
```
=== "Sync API"
```py
--8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
```
=== "Async API"
```py
--8<-- "python/python/tests/docs/test_python.py:vector_search_polars_async"
```
In addition to the selected columns, LanceDB also returns a vector
and the `_distance` column, which is the distance between the query
@@ -112,4 +139,3 @@ The reason it's beneficial to not convert the LanceDB Table
to a DataFrame is because the table can potentially be way larger
than memory, and Polars LazyFrames allow us to work with such
larger-than-memory datasets by not loading them into memory all at once.

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
@@ -30,6 +31,7 @@ from .dependencies import _check_for_pandas
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs as pa_fs
import numpy as np
from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
@@ -39,6 +41,8 @@ from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import (
AsyncFTSQuery,
AsyncHybridQuery,
AsyncQuery,
AsyncVectorQuery,
LanceEmptyQueryBuilder,
@@ -2702,6 +2706,19 @@ class AsyncTable:
"""
return await self._inner.schema()
async def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
    """
    Return the embedding functions configured on this table.

    The configuration is parsed from the metadata attached to the table
    schema by the embedding function registry.

    Returns
    -------
    funcs: Dict[str, EmbeddingFunctionConfig]
        A mapping of the vector column to the embedding function
        or empty dict if not configured.
    """
    table_schema = await self.schema()
    registry = EmbeddingFunctionRegistry.get_instance()
    return registry.parse_functions(table_schema.metadata)
async def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
@@ -2931,6 +2948,234 @@ class AsyncTable:
return LanceMergeInsertBuilder(self, on)
@overload
async def search(
    self,
    query: Optional[str] = None,
    vector_column_name: Optional[str] = None,
    query_type: Literal["auto"] = ...,
    ordering_field_name: Optional[str] = None,
    fts_columns: Optional[Union[str, List[str]]] = None,
) -> Union[AsyncHybridQuery, AsyncFTSQuery, AsyncVectorQuery]: ...

@overload
async def search(
    self,
    query: Optional[str] = None,
    vector_column_name: Optional[str] = None,
    query_type: Literal["hybrid"] = ...,
    ordering_field_name: Optional[str] = None,
    fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncHybridQuery: ...

@overload
async def search(
    self,
    query: Optional[Union[VEC, "PIL.Image.Image", Tuple]] = None,
    vector_column_name: Optional[str] = None,
    query_type: Literal["auto"] = ...,
    ordering_field_name: Optional[str] = None,
    fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncVectorQuery: ...

@overload
async def search(
    self,
    query: Optional[str] = None,
    vector_column_name: Optional[str] = None,
    query_type: Literal["fts"] = ...,
    ordering_field_name: Optional[str] = None,
    fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncFTSQuery: ...

@overload
async def search(
    self,
    query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
    vector_column_name: Optional[str] = None,
    query_type: Literal["vector"] = ...,
    ordering_field_name: Optional[str] = None,
    fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncVectorQuery: ...

async def search(
    self,
    query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
    vector_column_name: Optional[str] = None,
    query_type: QueryType = "auto",
    ordering_field_name: Optional[str] = None,
    fts_columns: Optional[Union[str, List[str]]] = None,
) -> AsyncQuery:
    """Create a search query to find the nearest neighbors
    of the given query vector. We currently support [vector search][search]
    and [full-text search][experimental-full-text-search].

    All query options are defined in [AsyncQuery][lancedb.query.AsyncQuery].

    This method is a coroutine because it may need to read the table schema
    and indices and run an embedding function before the query can be built.

    Parameters
    ----------
    query: list/np.ndarray/str/PIL.Image.Image, default None
        The targeted vector to search for.

        - *default None*.
        Acceptable types are: list, np.ndarray, PIL.Image.Image

        - If None (and `query_type` is "auto") then a plain query is
          returned, so that select/where/limit clauses can be applied to
          filter the table.
    vector_column_name: str, optional
        The name of the vector column to search.

        The vector column needs to be a pyarrow fixed size list type

        - If not specified then the vector column is inferred from
        the table schema

        - If the table has multiple vector columns then the *vector_column_name*
        needs to be specified. Otherwise, an error is raised.
    query_type: str
        *default "auto"*.
        Acceptable types are: "vector", "fts", "hybrid", or "auto"

        - If "auto" then the query type is inferred from the query;

            - If `query` is a list/np.ndarray then the query type is
            "vector";

            - If `query` is a PIL.Image.Image then either do vector search,
            or raise an error if no corresponding embedding function is found.

            - If `query` is a string and the vector column has a registered
            embedding function, the query type is "hybrid" when an FTS
            index exists on the embedding's source column and "vector"
            otherwise. If no embedding function is registered for the
            column, the query type is "fts".
    ordering_field_name: str, optional
        Not referenced by this implementation.
    fts_columns: str or list of str, optional
        Passed as `columns` to `nearest_to_text` for "fts" and "hybrid"
        queries. An empty list is passed when not given.

    Returns
    -------
    AsyncQuery
        A query builder object representing the query.

    Raises
    ------
    ValueError
        If `query_type` is not recognised, if a hybrid search is requested
        with a vector instead of text, or if an embedding is required but
        the vector column has no registered embedding function.
    """

    def is_embedding(query):
        return isinstance(query, (list, np.ndarray, pa.Array, pa.ChunkedArray))

    async def find_embedding_func(
        vector_column_name: Optional[str],
        query_type: QueryType,
        query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]],
    ):
        # Resolve the vector column and look up its embedding function.
        # Returns (column name, function config or None, all registered funcs).
        schema = await self.schema()
        vector_column_name = infer_vector_column_name(
            schema=schema,
            query_type=query_type,
            query=query,
            vector_column_name=vector_column_name,
        )
        funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(
            schema.metadata
        )
        return vector_column_name, funcs.get(vector_column_name), funcs

    async def get_embedding_func(
        vector_column_name: Optional[str],
        query_type: QueryType,
        query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]],
    ) -> Tuple[str, EmbeddingFunctionConfig]:
        # Same lookup as find_embedding_func, but a missing embedding
        # function is an error.
        vector_column_name, func, funcs = await find_embedding_func(
            vector_column_name, query_type, query
        )
        if func is None:
            error = ValueError(
                f"Column '{vector_column_name}' has no registered "
                "embedding function."
            )
            if funcs:
                add_note(
                    error,
                    "Embedding functions are registered for columns: "
                    f"{list(funcs.keys())}",
                )
            else:
                add_note(
                    error, "No embedding functions are registered for any columns."
                )
            raise error
        return vector_column_name, func

    async def make_embedding(embedding, query):
        if embedding is None:
            return None
        loop = asyncio.get_running_loop()
        # This function is likely to block, since it either calls an expensive
        # function or makes an HTTP request to an embeddings REST API.
        embeddings = await loop.run_in_executor(
            None,
            embedding.function.compute_query_embeddings_with_retry,
            query,
        )
        return embeddings[0]

    # Without a query there is nothing to search for: hand back a plain
    # query so the caller can still apply select/where/limit.
    if query is None and query_type == "auto":
        return self.query()

    if query_type == "auto":
        # Infer the query type.
        if is_embedding(query):
            vector_query = query
            query_type = "vector"
        elif isinstance(query, str):
            (
                indices,
                (inferred_column, embedding_conf, _),
            ) = await asyncio.gather(
                self.list_indices(),
                find_embedding_func(vector_column_name, "auto", query),
            )
            if embedding_conf is None:
                # If the column has no registered embedding function,
                # then it's an FTS query.
                query_type = "fts"
            else:
                vector_column_name = inferred_column
                vector_query = await make_embedding(embedding_conf, query)
                has_fts_index = any(
                    i.columns[0] == embedding_conf.source_column
                    and i.index_type == "FTS"
                    for i in indices
                )
                query_type = "hybrid" if has_fts_index else "vector"
        else:
            # It's an image or something else embeddable: the vector must
            # come from the registered embedding function (an error is
            # raised when there is none).
            vector_column_name, embedding_conf = await get_embedding_func(
                vector_column_name, "vector", query
            )
            vector_query = await make_embedding(embedding_conf, query)
            query_type = "vector"
    elif query_type == "vector":
        if is_embedding(query):
            vector_query = query
        else:
            vector_column_name, embedding_conf = await get_embedding_func(
                vector_column_name, query_type, query
            )
            vector_query = await make_embedding(embedding_conf, query)
    elif query_type == "hybrid":
        if is_embedding(query):
            raise ValueError("Hybrid search requires a text query")
        else:
            vector_column_name, embedding_conf = await get_embedding_func(
                vector_column_name, query_type, query
            )
            vector_query = await make_embedding(embedding_conf, query)

    if query_type == "vector":
        builder = self.query().nearest_to(vector_query)
        if vector_column_name:
            builder = builder.column(vector_column_name)
        return builder
    elif query_type == "fts":
        return self.query().nearest_to_text(query, columns=fts_columns or [])
    elif query_type == "hybrid":
        builder = self.query().nearest_to(vector_query)
        if vector_column_name:
            builder = builder.column(vector_column_name)
        return builder.nearest_to_text(query, columns=fts_columns or [])
    else:
        raise ValueError(f"Unknown query type: '{query_type}'")
def vector_search(
self,
query_vector: Union[VEC, Tuple],

View File

@@ -75,6 +75,6 @@ async def test_binary_vector_async():
query = np.random.randint(0, 2, size=256)
packed_query = np.packbits(query)
await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow()
await (await tbl.search(packed_query)).distance_type("hamming").to_arrow()
# --8<-- [end:async_binary_vector]
await db.drop_table("my_binary_vectors")

View File

@@ -53,13 +53,13 @@ async def test_binary_vector_async():
query = np.random.random(256)
# Search for the vectors within the range of [0.1, 0.5)
await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow()
await (await tbl.search(query)).distance_range(0.1, 0.5).to_arrow()
# Search for the vectors with the distance less than 0.5
await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow()
await (await tbl.search(query)).distance_range(upper_bound=0.5).to_arrow()
# Search for the vectors with the distance greater or equal to 0.1
await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow()
await (await tbl.search(query)).distance_range(lower_bound=0.1).to_arrow()
# --8<-- [end:async_distance_range]
await db.drop_table("my_table")

View File

@@ -28,3 +28,24 @@ def test_embeddings_openai():
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
# --8<-- [end:openai_embeddings]
@pytest.mark.slow
@pytest.mark.asyncio
async def test_embeddings_openai_async():
    """Documentation example: search an async table using OpenAI embeddings."""
    uri = "memory://"
    # --8<-- [start:async_openai_embeddings]
    db = await lancedb.connect_async(uri)
    func = get_registry().get("openai").create(name="text-embedding-ada-002")

    class Words(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    table = await db.create_table("words", schema=Words, mode="overwrite")
    await table.add([{"text": "hello world"}, {"text": "goodbye world"}])

    query = "greetings"
    # Await the results before indexing: `await f()[0]` would subscript the
    # coroutine itself rather than the awaited list.
    results = await (await table.search(query)).limit(1).to_pydantic(Words)
    actual = results[0]
    print(actual.text)
    # --8<-- [end:async_openai_embeddings]

View File

@@ -72,8 +72,7 @@ async def test_ann_index_async():
# --8<-- [end:create_ann_index_async]
# --8<-- [start:vector_search_async]
await (
async_tbl.query()
.nearest_to(np.random.random((32)))
(await async_tbl.search(np.random.random((32))))
.limit(2)
.nprobes(20)
.refine_factor(10)
@@ -82,18 +81,14 @@ async def test_ann_index_async():
# --8<-- [end:vector_search_async]
# --8<-- [start:vector_search_async_with_filter]
await (
async_tbl.query()
.nearest_to(np.random.random((32)))
(await async_tbl.search(np.random.random((32))))
.where("item != 'item 1141'")
.to_pandas()
)
# --8<-- [end:vector_search_async_with_filter]
# --8<-- [start:vector_search_async_with_select]
await (
async_tbl.query()
.nearest_to(np.random.random((32)))
.select(["vector"])
.to_pandas()
(await async_tbl.search(np.random.random((32)))).select(["vector"]).to_pandas()
)
# --8<-- [end:vector_search_async_with_select]
@@ -164,7 +159,7 @@ async def test_scalar_index_async():
{"book_id": 3, "vector": [5.0, 6]},
]
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
(await (await async_tbl.search([1, 2])).where("book_id != 3").to_pandas())
# --8<-- [end:vector_search_with_scalar_index_async]
# --8<-- [start:update_scalar_index_async]
await async_tbl.add([{"vector": [7, 8], "book_id": 4}])

View File

@@ -126,19 +126,17 @@ async def test_pandas_and_pyarrow_async():
query_vector = [100, 100]
# Pandas DataFrame
df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas()
df = await (await async_tbl.search(query_vector)).limit(1).to_pandas()
print(df)
# --8<-- [end:vector_search_async]
# --8<-- [start:vector_search_with_filter_async]
# Apply the filter via LanceDB
results = (
await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas()
)
results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas()
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
# Apply the filter via Pandas
df = results = await async_tbl.query().nearest_to([100, 100]).to_pandas()
df = results = await (await async_tbl.search([100, 100])).to_pandas()
results = df[df.price < 15]
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
@@ -188,3 +186,26 @@ def test_polars():
# --8<-- [start:print_table_lazyform]
print(ldf.first().collect())
# --8<-- [end:print_table_lazyform]
@pytest.mark.asyncio
async def test_polars_async():
    """Documentation example: create and search an async table with Polars."""
    db = await lancedb.connect_async("data/sample-lancedb")

    # --8<-- [start:create_table_polars_async]
    columns = {
        "vector": [[3.1, 4.1], [5.9, 26.5]],
        "item": ["foo", "bar"],
        "price": [10.0, 20.0],
    }
    data = pl.DataFrame(columns)
    table = await db.create_table("pl_table_async", data=data)
    # --8<-- [end:create_table_polars_async]

    # --8<-- [start:vector_search_polars_async]
    query = [3.0, 4.0]
    search_query = await table.search(query)
    result = await search_query.limit(1).to_polars()
    print(result)
    print(type(result))
    # --8<-- [end:vector_search_polars_async]

View File

@@ -117,12 +117,11 @@ async def test_vector_search_async():
for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32"))
]
async_tbl = await async_db.create_table("vector_search_async", data=data)
(await async_tbl.query().nearest_to(np.random.random((1536))).limit(10).to_list())
(await (await async_tbl.search(np.random.random((1536)))).limit(10).to_list())
# --8<-- [end:exhaustive_search_async]
# --8<-- [start:exhaustive_search_async_cosine]
(
await async_tbl.query()
.nearest_to(np.random.random((1536)))
await (await async_tbl.search(np.random.random((1536))))
.distance_type("cosine")
.limit(10)
.to_list()
@@ -145,13 +144,13 @@ async def test_vector_search_async():
async_tbl = await async_db.create_table("documents_async", data=data)
# --8<-- [end:create_table_async_with_nested_schema]
# --8<-- [start:search_result_async_as_pyarrow]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_arrow()
await (await async_tbl.search(np.random.randn(1536))).to_arrow()
# --8<-- [end:search_result_async_as_pyarrow]
# --8<-- [start:search_result_async_as_pandas]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_pandas()
await (await async_tbl.search(np.random.randn(1536))).to_pandas()
# --8<-- [end:search_result_async_as_pandas]
# --8<-- [start:search_result_async_as_list]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_list()
await (await async_tbl.search(np.random.randn(1536))).to_list()
# --8<-- [end:search_result_async_as_list]
@@ -219,9 +218,7 @@ async def test_fts_native_async():
# async API uses our native FTS algorithm
await async_tbl.create_index("text", config=FTS())
await (
async_tbl.query().nearest_to_text("puppy").select(["text"]).limit(10).to_list()
)
await (await async_tbl.search("puppy")).select(["text"]).limit(10).to_list()
# [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}]
# ...
# --8<-- [end:basic_fts_async]
@@ -235,18 +232,11 @@ async def test_fts_native_async():
)
# --8<-- [end:fts_config_folding_async]
# --8<-- [start:fts_prefiltering_async]
await (
async_tbl.query()
.nearest_to_text("puppy")
.limit(10)
.where("text='foo'")
.to_list()
)
await (await async_tbl.search("puppy")).limit(10).where("text='foo'").to_list()
# --8<-- [end:fts_prefiltering_async]
# --8<-- [start:fts_postfiltering_async]
await (
async_tbl.query()
.nearest_to_text("puppy")
(await async_tbl.search("puppy"))
.limit(10)
.where("text='foo'")
.postfilter()
@@ -347,14 +337,8 @@ async def test_hybrid_search_async():
# Create a fts index before the hybrid search
await async_tbl.create_index("text", config=FTS())
text_query = "flower moon"
vector_query = embeddings.compute_query_embeddings(text_query)[0]
# hybrid search with default re-ranker
await (
async_tbl.query()
.nearest_to(vector_query)
.nearest_to_text(text_query)
.to_pandas()
)
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
# --8<-- [end:basic_hybrid_search_async]
# --8<-- [start:hybrid_search_pass_vector_text_async]
vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]

View File

@@ -1,25 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import List, Union
import unittest.mock as mock
from datetime import timedelta
from pathlib import Path
import lancedb
from lancedb.index import IvfPq, FTS
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.db import AsyncConnection
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.index import FTS, IvfPq
import lancedb.pydantic
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import (
AsyncFTSQuery,
AsyncHybridQuery,
AsyncQueryBase,
AsyncVectorQuery,
LanceVectorQueryBuilder,
Query,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
from utils import exception_output
@pytest.fixture(scope="module")
@@ -716,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path):
tbl = await db.create_table("test", df)
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
assert len(results) == 2
@pytest.mark.asyncio
async def test_query_search_auto(mem_db_async: AsyncConnection):
    """`search()` with the default query type picks vector, hybrid or FTS."""
    nrows = 1000
    data = pa.table({"text": [str(i) for i in range(nrows)]})

    @register("test2")
    class TestEmbedding(TextEmbeddingFunction):
        def ndims(self):
            return 4

        def generate_embeddings(
            self, texts: Union[List[str], np.ndarray]
        ) -> List[np.array]:
            # Each text is a number; spread text / 1000 over every dimension.
            return [np.array([float(text) / 1000] * self.ndims()) for text in texts]

    func = get_registry().get("test2").create()

    class TestModel(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    tbl = await mem_db_async.create_table("test", data, schema=TestModel)

    registered = await tbl.embedding_functions()
    assert len(registered) == 1

    # No FTS or vector index yet.
    # A raw vector gives a vector query.
    vector_result = await tbl.search([0.1] * 4)
    assert isinstance(vector_result, AsyncVectorQuery)

    # A string is embedded, so it also gives a vector query.
    text_result = await tbl.search("0.1")
    assert isinstance(text_result, AsyncVectorQuery)

    # Once the source column has an FTS index, a string gives a hybrid query.
    await tbl.create_index("text", config=FTS())
    hybrid_result = await tbl.search("0.1")
    assert isinstance(hybrid_result, AsyncHybridQuery)

    # Same data with the schema metadata removed: the string search on this
    # table is expected to fail for lack of an INVERTED index.
    arrow_data = await tbl.to_arrow()
    arrow_data = arrow_data.replace_schema_metadata(None)
    tbl2 = await mem_db_async.create_table("test2", arrow_data)
    expected_message = (
        "Cannot perform full text search unless an INVERTED index has "
        "been created"
    )
    with pytest.raises(Exception, match=expected_message):
        await (await tbl2.search("0.1")).to_arrow()
@pytest.mark.asyncio
async def test_query_search_specified(mem_db_async: AsyncConnection):
    """An explicit `query_type` yields the matching query class or an error."""
    nrows, ndims = 1000, 16
    text_column = [str(i) for i in range(nrows)]
    vector_column = pa.FixedSizeListArray.from_arrays(
        pc.random(nrows * ndims).cast(pa.float32()), ndims
    )
    data = pa.table({"text": text_column, "vector": vector_column})
    table = await mem_db_async.create_table("test", data)
    await table.create_index("text", config=FTS())

    # Validate that specifying fts, vector or hybrid gets the right query.
    vector_result = await table.search([0.1] * ndims, query_type="vector")
    assert isinstance(vector_result, AsyncVectorQuery)

    fts_result = await table.search("0.1", query_type="fts")
    assert isinstance(fts_result, AsyncFTSQuery)

    with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
        await table.search("0.1", query_type="foo")

    # A text query needs an embedding function for vector search; this table
    # was created without one.
    with pytest.raises(
        ValueError, match="Column 'vector' has no registered embedding function"
    ) as e:
        await table.search("0.1", query_type="vector")

    rendered = exception_output(e)
    assert "No embedding functions are registered for any columns" in rendered

View File

@@ -338,6 +338,7 @@ def test_query_sync_empty_query():
"filter": "true",
"vector": [],
"columns": ["id"],
"prefilter": False,
"version": None,
}
@@ -412,6 +413,7 @@ def test_query_sync_fts():
"columns": [],
},
"k": 10,
"prefilter": True,
"vector": [],
"version": None,
}
@@ -429,6 +431,7 @@ def test_query_sync_fts():
},
"k": 42,
"vector": [],
"prefilter": True,
"with_row_id": True,
"version": None,
}
@@ -455,6 +458,7 @@ def test_query_sync_hybrid():
},
"k": 42,
"vector": [],
"prefilter": True,
"with_row_id": True,
"version": None,
}