mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-07 14:20:39 +00:00
feat(python): add search() method to async API (#2049)
Reviving #1966. Closes #1938 The `search()` method can apply embeddings for the user. This simplifies hybrid search, so instead of writing: ```python vector_query = embeddings.compute_query_embeddings("flower moon")[0] await ( async_tbl.query() .nearest_to(vector_query) .nearest_to_text("flower moon") .to_pandas() ) ``` You can write: ```python await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas() ``` Unfortunately, we had to do a double-await here because `search()` needs to be async. This is because it often needs to do IO to retrieve and run an embedding function.
This commit is contained in:
@@ -75,6 +75,6 @@ async def test_binary_vector_async():
|
||||
|
||||
query = np.random.randint(0, 2, size=256)
|
||||
packed_query = np.packbits(query)
|
||||
await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow()
|
||||
await (await tbl.search(packed_query)).distance_type("hamming").to_arrow()
|
||||
# --8<-- [end:async_binary_vector]
|
||||
await db.drop_table("my_binary_vectors")
|
||||
|
||||
@@ -53,13 +53,13 @@ async def test_binary_vector_async():
|
||||
query = np.random.random(256)
|
||||
|
||||
# Search for the vectors within the range of [0.1, 0.5)
|
||||
await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow()
|
||||
await (await tbl.search(query)).distance_range(0.1, 0.5).to_arrow()
|
||||
|
||||
# Search for the vectors with the distance less than 0.5
|
||||
await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow()
|
||||
await (await tbl.search(query)).distance_range(upper_bound=0.5).to_arrow()
|
||||
|
||||
# Search for the vectors with the distance greater or equal to 0.1
|
||||
await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow()
|
||||
await (await tbl.search(query)).distance_range(lower_bound=0.1).to_arrow()
|
||||
|
||||
# --8<-- [end:async_distance_range]
|
||||
await db.drop_table("my_table")
|
||||
|
||||
@@ -28,3 +28,24 @@ def test_embeddings_openai():
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
print(actual.text)
|
||||
# --8<-- [end:openai_embeddings]
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.asyncio
|
||||
async def test_embeddings_openai_async():
|
||||
uri = "memory://"
|
||||
# --8<-- [start:async_openai_embeddings]
|
||||
db = await lancedb.connect_async(uri)
|
||||
func = get_registry().get("openai").create(name="text-embedding-ada-002")
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = await db.create_table("words", schema=Words, mode="overwrite")
|
||||
await table.add([{"text": "hello world"}, {"text": "goodbye world"}])
|
||||
|
||||
query = "greetings"
|
||||
actual = await (await table.search(query)).limit(1).to_pydantic(Words)[0]
|
||||
print(actual.text)
|
||||
# --8<-- [end:async_openai_embeddings]
|
||||
|
||||
@@ -72,8 +72,7 @@ async def test_ann_index_async():
|
||||
# --8<-- [end:create_ann_index_async]
|
||||
# --8<-- [start:vector_search_async]
|
||||
await (
|
||||
async_tbl.query()
|
||||
.nearest_to(np.random.random((32)))
|
||||
(await async_tbl.search(np.random.random((32))))
|
||||
.limit(2)
|
||||
.nprobes(20)
|
||||
.refine_factor(10)
|
||||
@@ -82,18 +81,14 @@ async def test_ann_index_async():
|
||||
# --8<-- [end:vector_search_async]
|
||||
# --8<-- [start:vector_search_async_with_filter]
|
||||
await (
|
||||
async_tbl.query()
|
||||
.nearest_to(np.random.random((32)))
|
||||
(await async_tbl.search(np.random.random((32))))
|
||||
.where("item != 'item 1141'")
|
||||
.to_pandas()
|
||||
)
|
||||
# --8<-- [end:vector_search_async_with_filter]
|
||||
# --8<-- [start:vector_search_async_with_select]
|
||||
await (
|
||||
async_tbl.query()
|
||||
.nearest_to(np.random.random((32)))
|
||||
.select(["vector"])
|
||||
.to_pandas()
|
||||
(await async_tbl.search(np.random.random((32)))).select(["vector"]).to_pandas()
|
||||
)
|
||||
# --8<-- [end:vector_search_async_with_select]
|
||||
|
||||
@@ -164,7 +159,7 @@ async def test_scalar_index_async():
|
||||
{"book_id": 3, "vector": [5.0, 6]},
|
||||
]
|
||||
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
|
||||
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
|
||||
(await (await async_tbl.search([1, 2])).where("book_id != 3").to_pandas())
|
||||
# --8<-- [end:vector_search_with_scalar_index_async]
|
||||
# --8<-- [start:update_scalar_index_async]
|
||||
await async_tbl.add([{"vector": [7, 8], "book_id": 4}])
|
||||
|
||||
@@ -126,19 +126,17 @@ async def test_pandas_and_pyarrow_async():
|
||||
|
||||
query_vector = [100, 100]
|
||||
# Pandas DataFrame
|
||||
df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas()
|
||||
df = await (await async_tbl.search(query_vector)).limit(1).to_pandas()
|
||||
print(df)
|
||||
# --8<-- [end:vector_search_async]
|
||||
# --8<-- [start:vector_search_with_filter_async]
|
||||
# Apply the filter via LanceDB
|
||||
results = (
|
||||
await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas()
|
||||
)
|
||||
results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas()
|
||||
assert len(results) == 1
|
||||
assert results["item"].iloc[0] == "foo"
|
||||
|
||||
# Apply the filter via Pandas
|
||||
df = results = await async_tbl.query().nearest_to([100, 100]).to_pandas()
|
||||
df = results = await (await async_tbl.search([100, 100])).to_pandas()
|
||||
results = df[df.price < 15]
|
||||
assert len(results) == 1
|
||||
assert results["item"].iloc[0] == "foo"
|
||||
@@ -188,3 +186,26 @@ def test_polars():
|
||||
# --8<-- [start:print_table_lazyform]
|
||||
print(ldf.first().collect())
|
||||
# --8<-- [end:print_table_lazyform]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_polars_async():
|
||||
uri = "data/sample-lancedb"
|
||||
db = await lancedb.connect_async(uri)
|
||||
|
||||
# --8<-- [start:create_table_polars_async]
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
table = await db.create_table("pl_table_async", data=data)
|
||||
# --8<-- [end:create_table_polars_async]
|
||||
# --8<-- [start:vector_search_polars_async]
|
||||
query = [3.0, 4.0]
|
||||
result = await (await table.search(query)).limit(1).to_polars()
|
||||
print(result)
|
||||
print(type(result))
|
||||
# --8<-- [end:vector_search_polars_async]
|
||||
|
||||
@@ -117,12 +117,11 @@ async def test_vector_search_async():
|
||||
for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32"))
|
||||
]
|
||||
async_tbl = await async_db.create_table("vector_search_async", data=data)
|
||||
(await async_tbl.query().nearest_to(np.random.random((1536))).limit(10).to_list())
|
||||
(await (await async_tbl.search(np.random.random((1536)))).limit(10).to_list())
|
||||
# --8<-- [end:exhaustive_search_async]
|
||||
# --8<-- [start:exhaustive_search_async_cosine]
|
||||
(
|
||||
await async_tbl.query()
|
||||
.nearest_to(np.random.random((1536)))
|
||||
await (await async_tbl.search(np.random.random((1536))))
|
||||
.distance_type("cosine")
|
||||
.limit(10)
|
||||
.to_list()
|
||||
@@ -145,13 +144,13 @@ async def test_vector_search_async():
|
||||
async_tbl = await async_db.create_table("documents_async", data=data)
|
||||
# --8<-- [end:create_table_async_with_nested_schema]
|
||||
# --8<-- [start:search_result_async_as_pyarrow]
|
||||
await async_tbl.query().nearest_to(np.random.randn(1536)).to_arrow()
|
||||
await (await async_tbl.search(np.random.randn(1536))).to_arrow()
|
||||
# --8<-- [end:search_result_async_as_pyarrow]
|
||||
# --8<-- [start:search_result_async_as_pandas]
|
||||
await async_tbl.query().nearest_to(np.random.randn(1536)).to_pandas()
|
||||
await (await async_tbl.search(np.random.randn(1536))).to_pandas()
|
||||
# --8<-- [end:search_result_async_as_pandas]
|
||||
# --8<-- [start:search_result_async_as_list]
|
||||
await async_tbl.query().nearest_to(np.random.randn(1536)).to_list()
|
||||
await (await async_tbl.search(np.random.randn(1536))).to_list()
|
||||
# --8<-- [end:search_result_async_as_list]
|
||||
|
||||
|
||||
@@ -219,9 +218,7 @@ async def test_fts_native_async():
|
||||
|
||||
# async API uses our native FTS algorithm
|
||||
await async_tbl.create_index("text", config=FTS())
|
||||
await (
|
||||
async_tbl.query().nearest_to_text("puppy").select(["text"]).limit(10).to_list()
|
||||
)
|
||||
await (await async_tbl.search("puppy")).select(["text"]).limit(10).to_list()
|
||||
# [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}]
|
||||
# ...
|
||||
# --8<-- [end:basic_fts_async]
|
||||
@@ -235,18 +232,11 @@ async def test_fts_native_async():
|
||||
)
|
||||
# --8<-- [end:fts_config_folding_async]
|
||||
# --8<-- [start:fts_prefiltering_async]
|
||||
await (
|
||||
async_tbl.query()
|
||||
.nearest_to_text("puppy")
|
||||
.limit(10)
|
||||
.where("text='foo'")
|
||||
.to_list()
|
||||
)
|
||||
await (await async_tbl.search("puppy")).limit(10).where("text='foo'").to_list()
|
||||
# --8<-- [end:fts_prefiltering_async]
|
||||
# --8<-- [start:fts_postfiltering_async]
|
||||
await (
|
||||
async_tbl.query()
|
||||
.nearest_to_text("puppy")
|
||||
(await async_tbl.search("puppy"))
|
||||
.limit(10)
|
||||
.where("text='foo'")
|
||||
.postfilter()
|
||||
@@ -347,14 +337,8 @@ async def test_hybrid_search_async():
|
||||
# Create a fts index before the hybrid search
|
||||
await async_tbl.create_index("text", config=FTS())
|
||||
text_query = "flower moon"
|
||||
vector_query = embeddings.compute_query_embeddings(text_query)[0]
|
||||
# hybrid search with default re-ranker
|
||||
await (
|
||||
async_tbl.query()
|
||||
.nearest_to(vector_query)
|
||||
.nearest_to_text(text_query)
|
||||
.to_pandas()
|
||||
)
|
||||
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
|
||||
# --8<-- [end:basic_hybrid_search_async]
|
||||
# --8<-- [start:hybrid_search_pass_vector_text_async]
|
||||
vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
@@ -1,25 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from typing import List, Union
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq, FTS
|
||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||
from lancedb.db import AsyncConnection
|
||||
from lancedb.embeddings.base import TextEmbeddingFunction
|
||||
from lancedb.embeddings.registry import get_registry, register
|
||||
from lancedb.index import FTS, IvfPq
|
||||
import lancedb.pydantic
|
||||
import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import (
|
||||
AsyncFTSQuery,
|
||||
AsyncHybridQuery,
|
||||
AsyncQueryBase,
|
||||
AsyncVectorQuery,
|
||||
LanceVectorQueryBuilder,
|
||||
Query,
|
||||
)
|
||||
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
|
||||
from lancedb.table import AsyncTable, LanceTable
|
||||
from utils import exception_output
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -716,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path):
|
||||
tbl = await db.create_table("test", df)
|
||||
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_search_auto(mem_db_async: AsyncConnection):
|
||||
nrows = 1000
|
||||
data = pa.table(
|
||||
{
|
||||
"text": [str(i) for i in range(nrows)],
|
||||
}
|
||||
)
|
||||
|
||||
@register("test2")
|
||||
class TestEmbedding(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return 4
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
vec = np.array([float(text) / 1000] * self.ndims())
|
||||
embeddings.append(vec)
|
||||
return embeddings
|
||||
|
||||
registry = get_registry()
|
||||
func = registry.get("test2").create()
|
||||
|
||||
class TestModel(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
tbl = await mem_db_async.create_table("test", data, schema=TestModel)
|
||||
|
||||
funcs = await tbl.embedding_functions()
|
||||
assert len(funcs) == 1
|
||||
|
||||
# No FTS or vector index
|
||||
# Search for vector -> vector query
|
||||
q = [0.1] * 4
|
||||
query = await tbl.search(q)
|
||||
assert isinstance(query, AsyncVectorQuery)
|
||||
|
||||
# Search for string -> vector query
|
||||
query = await tbl.search("0.1")
|
||||
assert isinstance(query, AsyncVectorQuery)
|
||||
|
||||
await tbl.create_index("text", config=FTS())
|
||||
|
||||
query = await tbl.search("0.1")
|
||||
assert isinstance(query, AsyncHybridQuery)
|
||||
|
||||
data_with_vecs = await tbl.to_arrow()
|
||||
data_with_vecs = data_with_vecs.replace_schema_metadata(None)
|
||||
tbl2 = await mem_db_async.create_table("test2", data_with_vecs)
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match=(
|
||||
"Cannot perform full text search unless an INVERTED index has "
|
||||
"been created"
|
||||
),
|
||||
):
|
||||
query = await (await tbl2.search("0.1")).to_arrow()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_search_specified(mem_db_async: AsyncConnection):
|
||||
nrows, ndims = 1000, 16
|
||||
data = pa.table(
|
||||
{
|
||||
"text": [str(i) for i in range(nrows)],
|
||||
"vector": pa.FixedSizeListArray.from_arrays(
|
||||
pc.random(nrows * ndims).cast(pa.float32()), ndims
|
||||
),
|
||||
}
|
||||
)
|
||||
table = await mem_db_async.create_table("test", data)
|
||||
await table.create_index("text", config=FTS())
|
||||
|
||||
# Validate that specifying fts, vector or hybrid gets the right query.
|
||||
q = [0.1] * ndims
|
||||
query = await table.search(q, query_type="vector")
|
||||
assert isinstance(query, AsyncVectorQuery)
|
||||
|
||||
query = await table.search("0.1", query_type="fts")
|
||||
assert isinstance(query, AsyncFTSQuery)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
|
||||
await table.search("0.1", query_type="foo")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Column 'vector' has no registered embedding function"
|
||||
) as e:
|
||||
await table.search("0.1", query_type="vector")
|
||||
|
||||
assert "No embedding functions are registered for any columns" in exception_output(
|
||||
e
|
||||
)
|
||||
|
||||
@@ -338,6 +338,7 @@ def test_query_sync_empty_query():
|
||||
"filter": "true",
|
||||
"vector": [],
|
||||
"columns": ["id"],
|
||||
"prefilter": False,
|
||||
"version": None,
|
||||
}
|
||||
|
||||
@@ -412,6 +413,7 @@ def test_query_sync_fts():
|
||||
"columns": [],
|
||||
},
|
||||
"k": 10,
|
||||
"prefilter": True,
|
||||
"vector": [],
|
||||
"version": None,
|
||||
}
|
||||
@@ -429,6 +431,7 @@ def test_query_sync_fts():
|
||||
},
|
||||
"k": 42,
|
||||
"vector": [],
|
||||
"prefilter": True,
|
||||
"with_row_id": True,
|
||||
"version": None,
|
||||
}
|
||||
@@ -455,6 +458,7 @@ def test_query_sync_hybrid():
|
||||
},
|
||||
"k": 42,
|
||||
"vector": [],
|
||||
"prefilter": True,
|
||||
"with_row_id": True,
|
||||
"version": None,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user