feat(python): add search() method to async API (#2049)

Reviving #1966.

Closes #1938

The `search()` method can apply embeddings for the user. This simplifies
hybrid search, so instead of writing:

```python
vector_query = embeddings.compute_query_embeddings("flower moon")[0]
await (
    async_tbl.query()
    .nearest_to(vector_query)
    .nearest_to_text("flower moon")
    .to_pandas()
)
```

You can write:

```python
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
```

Unfortunately, this requires a double `await`: the inner one is needed
because `search()` itself must be async (it often needs to do IO to retrieve
and run an embedding function), and the outer one executes the returned query.
This commit is contained in:
Will Jones
2025-02-24 14:19:25 -08:00
committed by GitHub
parent f391ed828a
commit ecdee4d2b1
10 changed files with 461 additions and 57 deletions

View File

@@ -1,25 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import List, Union
import unittest.mock as mock
from datetime import timedelta
from pathlib import Path
import lancedb
from lancedb.index import IvfPq, FTS
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.db import AsyncConnection
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.index import FTS, IvfPq
import lancedb.pydantic
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import (
AsyncFTSQuery,
AsyncHybridQuery,
AsyncQueryBase,
AsyncVectorQuery,
LanceVectorQueryBuilder,
Query,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
from utils import exception_output
@pytest.fixture(scope="module")
@@ -716,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path):
tbl = await db.create_table("test", df)
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
assert len(results) == 2
@pytest.mark.asyncio
async def test_query_search_auto(mem_db_async: AsyncConnection):
    """Check which query type ``search()`` builds when ``query_type`` is omitted.

    A table is created with a registered text embedding function and then
    searched in several states:

    * a vector query yields an ``AsyncVectorQuery``;
    * a string query with no FTS index yields an ``AsyncVectorQuery``;
    * a string query once an FTS index exists yields an ``AsyncHybridQuery``;
    * a string query on a copy of the table whose schema metadata was
      stripped fails with an error about the missing INVERTED index.
    """
    nrows = 1000
    data = pa.table(
        {
            "text": [str(i) for i in range(nrows)],
        }
    )

    @register("test2")
    class TestEmbedding(TextEmbeddingFunction):
        def ndims(self):
            return 4

        def generate_embeddings(
            self, texts: Union[List[str], np.ndarray]
        ) -> List[np.ndarray]:
            # Deterministic stand-in for a real model: each text is parsed as
            # a number, scaled by 1/1000 and repeated across all dimensions.
            return [np.array([float(text) / 1000] * self.ndims()) for text in texts]

    registry = get_registry()
    func = registry.get("test2").create()

    class TestModel(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    tbl = await mem_db_async.create_table("test", data, schema=TestModel)

    funcs = await tbl.embedding_functions()
    assert len(funcs) == 1

    # No FTS or vector index
    # Search for vector -> vector query
    q = [0.1] * 4
    query = await tbl.search(q)
    assert isinstance(query, AsyncVectorQuery)

    # Search for string -> vector query
    query = await tbl.search("0.1")
    assert isinstance(query, AsyncVectorQuery)

    # Once an FTS index exists on the text column, a string search is hybrid.
    await tbl.create_index("text", config=FTS())

    query = await tbl.search("0.1")
    assert isinstance(query, AsyncHybridQuery)

    # Copy the data (including the computed vectors) into a second table with
    # the schema metadata removed.
    # NOTE(review): presumably this drops the embedding function registration,
    # so the string search falls back to full text search - confirm.
    data_with_vecs = await tbl.to_arrow()
    data_with_vecs = data_with_vecs.replace_schema_metadata(None)
    tbl2 = await mem_db_async.create_table("test2", data_with_vecs)
    with pytest.raises(
        Exception,
        match=(
            "Cannot perform full text search unless an INVERTED index has "
            "been created"
        ),
    ):
        # The result is never bound: this call is expected to raise.
        await (await tbl2.search("0.1")).to_arrow()
@pytest.mark.asyncio
async def test_query_search_specified(mem_db_async: AsyncConnection):
    """Check that an explicit ``query_type`` selects the matching query class.

    The table has a text column with an FTS index and a random vector column,
    but no embedding function. The test covers the ``"vector"`` and ``"fts"``
    query types, an unknown query type, and a string passed with
    ``query_type="vector"``.
    """
    row_count = 1000
    dim = 16

    text_values = [str(i) for i in range(row_count)]
    flat_vectors = pc.random(row_count * dim).cast(pa.float32())
    vector_values = pa.FixedSizeListArray.from_arrays(flat_vectors, dim)
    source = pa.table({"text": text_values, "vector": vector_values})

    tbl = await mem_db_async.create_table("test", source)
    await tbl.create_index("text", config=FTS())

    # Validate that specifying fts, vector or hybrid gets the right query.
    vector_input = [0.1] * dim
    vector_query = await tbl.search(vector_input, query_type="vector")
    assert isinstance(vector_query, AsyncVectorQuery)

    fts_query = await tbl.search("0.1", query_type="fts")
    assert isinstance(fts_query, AsyncFTSQuery)

    # An unrecognised query type is rejected.
    with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
        await tbl.search("0.1", query_type="foo")

    # A string cannot be used for a vector search here: no embedding function
    # is registered to turn it into a vector.
    with pytest.raises(
        ValueError, match="Column 'vector' has no registered embedding function"
    ) as exc_info:
        await tbl.search("0.1", query_type="vector")

    # Checked outside the ``raises`` block so that it actually runs.
    rendered = exception_output(exc_info)
    assert "No embedding functions are registered for any columns" in rendered