mirror of
https://github.com/lancedb/lancedb.git
synced 2026-07-03 19:10:41 +00:00
feat(python): add search() method to async API (#2049)
Reviving #1966. Closes #1938 The `search()` method can apply embeddings for the user. This simplifies hybrid search, so instead of writing: ```python vector_query = embeddings.compute_query_embeddings("flower moon")[0] await ( async_tbl.query() .nearest_to(vector_query) .nearest_to_text("flower moon") .to_pandas() ) ``` You can write: ```python await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas() ``` Unfortunately, we had to do a double-await here because `search()` needs to be async. This is because it often needs to do IO to retrieve and run an embedding function.
This commit is contained in:
@@ -1,25 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from typing import List, Union
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq, FTS
|
||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||
from lancedb.db import AsyncConnection
|
||||
from lancedb.embeddings.base import TextEmbeddingFunction
|
||||
from lancedb.embeddings.registry import get_registry, register
|
||||
from lancedb.index import FTS, IvfPq
|
||||
import lancedb.pydantic
|
||||
import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import (
|
||||
AsyncFTSQuery,
|
||||
AsyncHybridQuery,
|
||||
AsyncQueryBase,
|
||||
AsyncVectorQuery,
|
||||
LanceVectorQueryBuilder,
|
||||
Query,
|
||||
)
|
||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
from utils import exception_output
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -716,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path):
|
||||
tbl = await db.create_table("test", df)
|
||||
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_search_auto(mem_db_async: AsyncConnection):
|
||||
nrows = 1000
|
||||
data = pa.table(
|
||||
{
|
||||
"text": [str(i) for i in range(nrows)],
|
||||
}
|
||||
)
|
||||
|
||||
@register("test2")
|
||||
class TestEmbedding(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return 4
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
vec = np.array([float(text) / 1000] * self.ndims())
|
||||
embeddings.append(vec)
|
||||
return embeddings
|
||||
|
||||
registry = get_registry()
|
||||
func = registry.get("test2").create()
|
||||
|
||||
class TestModel(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
tbl = await mem_db_async.create_table("test", data, schema=TestModel)
|
||||
|
||||
funcs = await tbl.embedding_functions()
|
||||
assert len(funcs) == 1
|
||||
|
||||
# No FTS or vector index
|
||||
# Search for vector -> vector query
|
||||
q = [0.1] * 4
|
||||
query = await tbl.search(q)
|
||||
assert isinstance(query, AsyncVectorQuery)
|
||||
|
||||
# Search for string -> vector query
|
||||
query = await tbl.search("0.1")
|
||||
assert isinstance(query, AsyncVectorQuery)
|
||||
|
||||
await tbl.create_index("text", config=FTS())
|
||||
|
||||
query = await tbl.search("0.1")
|
||||
assert isinstance(query, AsyncHybridQuery)
|
||||
|
||||
data_with_vecs = await tbl.to_arrow()
|
||||
data_with_vecs = data_with_vecs.replace_schema_metadata(None)
|
||||
tbl2 = await mem_db_async.create_table("test2", data_with_vecs)
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match=(
|
||||
"Cannot perform full text search unless an INVERTED index has "
|
||||
"been created"
|
||||
),
|
||||
):
|
||||
query = await (await tbl2.search("0.1")).to_arrow()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_search_specified(mem_db_async: AsyncConnection):
|
||||
nrows, ndims = 1000, 16
|
||||
data = pa.table(
|
||||
{
|
||||
"text": [str(i) for i in range(nrows)],
|
||||
"vector": pa.FixedSizeListArray.from_arrays(
|
||||
pc.random(nrows * ndims).cast(pa.float32()), ndims
|
||||
),
|
||||
}
|
||||
)
|
||||
table = await mem_db_async.create_table("test", data)
|
||||
await table.create_index("text", config=FTS())
|
||||
|
||||
# Validate that specifying fts, vector or hybrid gets the right query.
|
||||
q = [0.1] * ndims
|
||||
query = await table.search(q, query_type="vector")
|
||||
assert isinstance(query, AsyncVectorQuery)
|
||||
|
||||
query = await table.search("0.1", query_type="fts")
|
||||
assert isinstance(query, AsyncFTSQuery)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
|
||||
await table.search("0.1", query_type="foo")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Column 'vector' has no registered embedding function"
|
||||
) as e:
|
||||
await table.search("0.1", query_type="vector")
|
||||
|
||||
assert "No embedding functions are registered for any columns" in exception_output(
|
||||
e
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user