feat(python): add search() method to async API (#2049)

Reviving #1966.

Closes #1938

The `search()` method can apply embeddings for the user. This simplifies
hybrid search, so instead of writing:

```python
vector_query = embeddings.compute_query_embeddings("flower moon")[0]
await (
    async_tbl.query()
    .nearest_to(vector_query)
    .nearest_to_text("flower moon")
    .to_pandas()
)
```

You can write:

```python
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
```

Unfortunately, we had to do a double-await here because `search()` needs
to be async. This is because it often needs to do IO to retrieve and run
an embedding function.
This commit is contained in:
Will Jones
2025-02-24 14:19:25 -08:00
committed by GitHub
parent f391ed828a
commit ecdee4d2b1
10 changed files with 461 additions and 57 deletions

View File

@@ -75,6 +75,6 @@ async def test_binary_vector_async():
query = np.random.randint(0, 2, size=256)
packed_query = np.packbits(query)
await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow()
await (await tbl.search(packed_query)).distance_type("hamming").to_arrow()
# --8<-- [end:async_binary_vector]
await db.drop_table("my_binary_vectors")

View File

@@ -53,13 +53,13 @@ async def test_binary_vector_async():
query = np.random.random(256)
# Search for the vectors within the range of [0.1, 0.5)
await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow()
await (await tbl.search(query)).distance_range(0.1, 0.5).to_arrow()
# Search for the vectors with the distance less than 0.5
await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow()
await (await tbl.search(query)).distance_range(upper_bound=0.5).to_arrow()
# Search for the vectors with the distance greater or equal to 0.1
await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow()
await (await tbl.search(query)).distance_range(lower_bound=0.1).to_arrow()
# --8<-- [end:async_distance_range]
await db.drop_table("my_table")

View File

@@ -28,3 +28,24 @@ def test_embeddings_openai():
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
# --8<-- [end:openai_embeddings]
@pytest.mark.slow
@pytest.mark.asyncio
async def test_embeddings_openai_async():
uri = "memory://"
# --8<-- [start:async_openai_embeddings]
db = await lancedb.connect_async(uri)
func = get_registry().get("openai").create(name="text-embedding-ada-002")
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = await db.create_table("words", schema=Words, mode="overwrite")
await table.add([{"text": "hello world"}, {"text": "goodbye world"}])
query = "greetings"
actual = await (await table.search(query)).limit(1).to_pydantic(Words)[0]
print(actual.text)
# --8<-- [end:async_openai_embeddings]

View File

@@ -72,8 +72,7 @@ async def test_ann_index_async():
# --8<-- [end:create_ann_index_async]
# --8<-- [start:vector_search_async]
await (
async_tbl.query()
.nearest_to(np.random.random((32)))
(await async_tbl.search(np.random.random((32))))
.limit(2)
.nprobes(20)
.refine_factor(10)
@@ -82,18 +81,14 @@ async def test_ann_index_async():
# --8<-- [end:vector_search_async]
# --8<-- [start:vector_search_async_with_filter]
await (
async_tbl.query()
.nearest_to(np.random.random((32)))
(await async_tbl.search(np.random.random((32))))
.where("item != 'item 1141'")
.to_pandas()
)
# --8<-- [end:vector_search_async_with_filter]
# --8<-- [start:vector_search_async_with_select]
await (
async_tbl.query()
.nearest_to(np.random.random((32)))
.select(["vector"])
.to_pandas()
(await async_tbl.search(np.random.random((32)))).select(["vector"]).to_pandas()
)
# --8<-- [end:vector_search_async_with_select]
@@ -164,7 +159,7 @@ async def test_scalar_index_async():
{"book_id": 3, "vector": [5.0, 6]},
]
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
(await (await async_tbl.search([1, 2])).where("book_id != 3").to_pandas())
# --8<-- [end:vector_search_with_scalar_index_async]
# --8<-- [start:update_scalar_index_async]
await async_tbl.add([{"vector": [7, 8], "book_id": 4}])

View File

@@ -126,19 +126,17 @@ async def test_pandas_and_pyarrow_async():
query_vector = [100, 100]
# Pandas DataFrame
df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas()
df = await (await async_tbl.search(query_vector)).limit(1).to_pandas()
print(df)
# --8<-- [end:vector_search_async]
# --8<-- [start:vector_search_with_filter_async]
# Apply the filter via LanceDB
results = (
await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas()
)
results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas()
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
# Apply the filter via Pandas
df = results = await async_tbl.query().nearest_to([100, 100]).to_pandas()
df = results = await (await async_tbl.search([100, 100])).to_pandas()
results = df[df.price < 15]
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
@@ -188,3 +186,26 @@ def test_polars():
# --8<-- [start:print_table_lazyform]
print(ldf.first().collect())
# --8<-- [end:print_table_lazyform]
@pytest.mark.asyncio
async def test_polars_async():
uri = "data/sample-lancedb"
db = await lancedb.connect_async(uri)
# --8<-- [start:create_table_polars_async]
data = pl.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
table = await db.create_table("pl_table_async", data=data)
# --8<-- [end:create_table_polars_async]
# --8<-- [start:vector_search_polars_async]
query = [3.0, 4.0]
result = await (await table.search(query)).limit(1).to_polars()
print(result)
print(type(result))
# --8<-- [end:vector_search_polars_async]

View File

@@ -117,12 +117,11 @@ async def test_vector_search_async():
for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32"))
]
async_tbl = await async_db.create_table("vector_search_async", data=data)
(await async_tbl.query().nearest_to(np.random.random((1536))).limit(10).to_list())
(await (await async_tbl.search(np.random.random((1536)))).limit(10).to_list())
# --8<-- [end:exhaustive_search_async]
# --8<-- [start:exhaustive_search_async_cosine]
(
await async_tbl.query()
.nearest_to(np.random.random((1536)))
await (await async_tbl.search(np.random.random((1536))))
.distance_type("cosine")
.limit(10)
.to_list()
@@ -145,13 +144,13 @@ async def test_vector_search_async():
async_tbl = await async_db.create_table("documents_async", data=data)
# --8<-- [end:create_table_async_with_nested_schema]
# --8<-- [start:search_result_async_as_pyarrow]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_arrow()
await (await async_tbl.search(np.random.randn(1536))).to_arrow()
# --8<-- [end:search_result_async_as_pyarrow]
# --8<-- [start:search_result_async_as_pandas]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_pandas()
await (await async_tbl.search(np.random.randn(1536))).to_pandas()
# --8<-- [end:search_result_async_as_pandas]
# --8<-- [start:search_result_async_as_list]
await async_tbl.query().nearest_to(np.random.randn(1536)).to_list()
await (await async_tbl.search(np.random.randn(1536))).to_list()
# --8<-- [end:search_result_async_as_list]
@@ -219,9 +218,7 @@ async def test_fts_native_async():
# async API uses our native FTS algorithm
await async_tbl.create_index("text", config=FTS())
await (
async_tbl.query().nearest_to_text("puppy").select(["text"]).limit(10).to_list()
)
await (await async_tbl.search("puppy")).select(["text"]).limit(10).to_list()
# [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}]
# ...
# --8<-- [end:basic_fts_async]
@@ -235,18 +232,11 @@ async def test_fts_native_async():
)
# --8<-- [end:fts_config_folding_async]
# --8<-- [start:fts_prefiltering_async]
await (
async_tbl.query()
.nearest_to_text("puppy")
.limit(10)
.where("text='foo'")
.to_list()
)
await (await async_tbl.search("puppy")).limit(10).where("text='foo'").to_list()
# --8<-- [end:fts_prefiltering_async]
# --8<-- [start:fts_postfiltering_async]
await (
async_tbl.query()
.nearest_to_text("puppy")
(await async_tbl.search("puppy"))
.limit(10)
.where("text='foo'")
.postfilter()
@@ -347,14 +337,8 @@ async def test_hybrid_search_async():
# Create a fts index before the hybrid search
await async_tbl.create_index("text", config=FTS())
text_query = "flower moon"
vector_query = embeddings.compute_query_embeddings(text_query)[0]
# hybrid search with default re-ranker
await (
async_tbl.query()
.nearest_to(vector_query)
.nearest_to_text(text_query)
.to_pandas()
)
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
# --8<-- [end:basic_hybrid_search_async]
# --8<-- [start:hybrid_search_pass_vector_text_async]
vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]

View File

@@ -1,25 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import List, Union
import unittest.mock as mock
from datetime import timedelta
from pathlib import Path
import lancedb
from lancedb.index import IvfPq, FTS
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.db import AsyncConnection
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.index import FTS, IvfPq
import lancedb.pydantic
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import (
AsyncFTSQuery,
AsyncHybridQuery,
AsyncQueryBase,
AsyncVectorQuery,
LanceVectorQueryBuilder,
Query,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
from utils import exception_output
@pytest.fixture(scope="module")
@@ -716,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path):
tbl = await db.create_table("test", df)
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
assert len(results) == 2
@pytest.mark.asyncio
async def test_query_search_auto(mem_db_async: AsyncConnection):
nrows = 1000
data = pa.table(
{
"text": [str(i) for i in range(nrows)],
}
)
@register("test2")
class TestEmbedding(TextEmbeddingFunction):
def ndims(self):
return 4
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
embeddings = []
for text in texts:
vec = np.array([float(text) / 1000] * self.ndims())
embeddings.append(vec)
return embeddings
registry = get_registry()
func = registry.get("test2").create()
class TestModel(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
tbl = await mem_db_async.create_table("test", data, schema=TestModel)
funcs = await tbl.embedding_functions()
assert len(funcs) == 1
# No FTS or vector index
# Search for vector -> vector query
q = [0.1] * 4
query = await tbl.search(q)
assert isinstance(query, AsyncVectorQuery)
# Search for string -> vector query
query = await tbl.search("0.1")
assert isinstance(query, AsyncVectorQuery)
await tbl.create_index("text", config=FTS())
query = await tbl.search("0.1")
assert isinstance(query, AsyncHybridQuery)
data_with_vecs = await tbl.to_arrow()
data_with_vecs = data_with_vecs.replace_schema_metadata(None)
tbl2 = await mem_db_async.create_table("test2", data_with_vecs)
with pytest.raises(
Exception,
match=(
"Cannot perform full text search unless an INVERTED index has "
"been created"
),
):
query = await (await tbl2.search("0.1")).to_arrow()
@pytest.mark.asyncio
async def test_query_search_specified(mem_db_async: AsyncConnection):
nrows, ndims = 1000, 16
data = pa.table(
{
"text": [str(i) for i in range(nrows)],
"vector": pa.FixedSizeListArray.from_arrays(
pc.random(nrows * ndims).cast(pa.float32()), ndims
),
}
)
table = await mem_db_async.create_table("test", data)
await table.create_index("text", config=FTS())
# Validate that specifying fts, vector or hybrid gets the right query.
q = [0.1] * ndims
query = await table.search(q, query_type="vector")
assert isinstance(query, AsyncVectorQuery)
query = await table.search("0.1", query_type="fts")
assert isinstance(query, AsyncFTSQuery)
with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
await table.search("0.1", query_type="foo")
with pytest.raises(
ValueError, match="Column 'vector' has no registered embedding function"
) as e:
await table.search("0.1", query_type="vector")
assert "No embedding functions are registered for any columns" in exception_output(
e
)

View File

@@ -338,6 +338,7 @@ def test_query_sync_empty_query():
"filter": "true",
"vector": [],
"columns": ["id"],
"prefilter": False,
"version": None,
}
@@ -412,6 +413,7 @@ def test_query_sync_fts():
"columns": [],
},
"k": 10,
"prefilter": True,
"vector": [],
"version": None,
}
@@ -429,6 +431,7 @@ def test_query_sync_fts():
},
"k": 42,
"vector": [],
"prefilter": True,
"with_row_id": True,
"version": None,
}
@@ -455,6 +458,7 @@ def test_query_sync_hybrid():
},
"k": 42,
"vector": [],
"prefilter": True,
"with_row_id": True,
"version": None,
}