Mirror of https://github.com/lancedb/lancedb.git, synced 2025-12-23 05:19:58 +00:00
feat(python): add search() method to async API (#2049)
Reviving #1966. Closes #1938

The `search()` method can apply embeddings for the user. This simplifies hybrid search, so instead of writing:

```python
vector_query = embeddings.compute_query_embeddings("flower moon")[0]
await (
    async_tbl.query()
    .nearest_to(vector_query)
    .nearest_to_text("flower moon")
    .to_pandas()
)
```

You can write:

```python
await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
```

Unfortunately, we had to use a double-await here because `search()` needs to be async: it often has to do IO to retrieve and run an embedding function.
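For quick reference, a minimal usage sketch (the `async_tbl` name follows the snippet above; the vector dimensionality, the `limit(5)` calls, and the assumption that an embedding function and an FTS index are configured are all illustrative):

```python
import numpy as np

# Passing a vector skips the embedding step and yields an AsyncVectorQuery.
rows = await (await async_tbl.search(np.random.random(32))).limit(5).to_pandas()

# Passing a string lets search() pick "fts", "vector", or "hybrid" automatically,
# depending on the table's embedding functions and FTS indices.
rows = await (await async_tbl.search("flower moon")).limit(5).to_pandas()

# The query type can also be forced explicitly.
rows = await (await async_tbl.search("flower moon", query_type="fts")).to_pandas()
```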
@@ -9,23 +9,50 @@ LanceDB supports [Polars](https://github.com/pola-rs/polars), a blazingly fast D

First, we connect to a LanceDB database.

=== "Sync API"

    ```py
    --8<-- "python/python/tests/docs/test_python.py:import-lancedb"
    --8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
    ```

=== "Async API"

    ```py
    --8<-- "python/python/tests/docs/test_python.py:import-lancedb"
    --8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb_async"
    ```

```py
--8<-- "python/python/tests/docs/test_python.py:import-lancedb"
--8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb"
```

We can load a Polars `DataFrame` to LanceDB directly.

```py
--8<-- "python/python/tests/docs/test_python.py:import-polars"
--8<-- "python/python/tests/docs/test_python.py:create_table_polars"
```
=== "Sync API"

    ```py
    --8<-- "python/python/tests/docs/test_python.py:import-polars"
    --8<-- "python/python/tests/docs/test_python.py:create_table_polars"
    ```

=== "Async API"

    ```py
    --8<-- "python/python/tests/docs/test_python.py:import-polars"
    --8<-- "python/python/tests/docs/test_python.py:create_table_polars_async"
    ```
We can now perform similarity search via the LanceDB Python API.

```py
--8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
```
=== "Sync API"

    ```py
    --8<-- "python/python/tests/docs/test_python.py:vector_search_polars"
    ```

=== "Async API"

    ```py
    --8<-- "python/python/tests/docs/test_python.py:vector_search_polars_async"
    ```

In addition to the selected columns, LanceDB also returns a vector
and the `_distance` column, which is the distance between the query
@@ -112,4 +139,3 @@ The reason it's beneficial to not convert the LanceDB Table
to a DataFrame is because the table can potentially be way larger
than memory, and Polars LazyFrames allow us to work with such
larger-than-memory datasets by not loading it into memory all at once.
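As a side note, here is a minimal sketch of the lazy pattern that paragraph describes (it assumes a sync `table` such as the Polars-backed table created earlier on this page, and that `to_polars()` hands back a Polars `LazyFrame` rather than a `DataFrame`):

```py
import polars as pl

# Assumes `table` is the sync LanceDB table created above; `to_polars()` is
# expected to return a LazyFrame, so nothing is read until collect().
ldf = table.to_polars()
plan = ldf.filter(pl.col("price") < 15).select(["item", "price"])
print(plan.collect())  # materializes only the filtered, projected rows
```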
@@ -3,6 +3,7 @@

from __future__ import annotations

import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
@@ -30,6 +31,7 @@ from .dependencies import _check_for_pandas
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs as pa_fs
import numpy as np
from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
@@ -39,6 +41,8 @@ from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import (
    AsyncFTSQuery,
    AsyncHybridQuery,
    AsyncQuery,
    AsyncVectorQuery,
    LanceEmptyQueryBuilder,
@@ -2702,6 +2706,19 @@ class AsyncTable:
        """
        return await self._inner.schema()

    async def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]:
        """
        Get the embedding functions for the table

        Returns
        -------
        funcs: Dict[str, EmbeddingFunctionConfig]
            A mapping of the vector column to the embedding function
            or empty dict if not configured.
        """
        schema = await self.schema()
        return EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)

    async def count_rows(self, filter: Optional[str] = None) -> int:
        """
        Count the number of rows in the table.
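A small, hypothetical usage sketch of the `embedding_functions()` accessor added in the hunk above (it assumes `table` is an `AsyncTable` whose schema was created from a model with a registered embedding function):

```python
# Hypothetical sketch; `table` is an AsyncTable with a configured embedding function.
funcs = await table.embedding_functions()
for vector_column, config in funcs.items():
    # Each EmbeddingFunctionConfig ties a source column to the function that
    # embeds it into the vector column.
    print(vector_column, config.source_column, type(config.function).__name__)
```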
@@ -2931,6 +2948,234 @@ class AsyncTable:

        return LanceMergeInsertBuilder(self, on)

    @overload
    async def search(
        self,
        query: Optional[Union[str]] = None,
        vector_column_name: Optional[str] = None,
        query_type: Literal["auto"] = ...,
        ordering_field_name: Optional[str] = None,
        fts_columns: Optional[Union[str, List[str]]] = None,
    ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: ...

    @overload
    async def search(
        self,
        query: Optional[Union[str]] = None,
        vector_column_name: Optional[str] = None,
        query_type: Literal["hybrid"] = ...,
        ordering_field_name: Optional[str] = None,
        fts_columns: Optional[Union[str, List[str]]] = None,
    ) -> AsyncHybridQuery: ...

    @overload
    async def search(
        self,
        query: Optional[Union[VEC, "PIL.Image.Image", Tuple]] = None,
        vector_column_name: Optional[str] = None,
        query_type: Literal["auto"] = ...,
        ordering_field_name: Optional[str] = None,
        fts_columns: Optional[Union[str, List[str]]] = None,
    ) -> AsyncVectorQuery: ...

    @overload
    async def search(
        self,
        query: Optional[str] = None,
        vector_column_name: Optional[str] = None,
        query_type: Literal["fts"] = ...,
        ordering_field_name: Optional[str] = None,
        fts_columns: Optional[Union[str, List[str]]] = None,
    ) -> AsyncFTSQuery: ...

    @overload
    async def search(
        self,
        query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
        vector_column_name: Optional[str] = None,
        query_type: Literal["vector"] = ...,
        ordering_field_name: Optional[str] = None,
        fts_columns: Optional[Union[str, List[str]]] = None,
    ) -> AsyncVectorQuery: ...
    async def search(
        self,
        query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
        vector_column_name: Optional[str] = None,
        query_type: QueryType = "auto",
        ordering_field_name: Optional[str] = None,
        fts_columns: Optional[Union[str, List[str]]] = None,
    ) -> AsyncQuery:
        """Create a search query to find the nearest neighbors
        of the given query vector. We currently support [vector search][search]
        and [full-text search][experimental-full-text-search].

        All query options are defined in [AsyncQuery][lancedb.query.AsyncQuery].

        Parameters
        ----------
        query: list/np.ndarray/str/PIL.Image.Image, default None
            The targeted vector to search for.

            - *default None*.
              Acceptable types are: list, np.ndarray, PIL.Image.Image

            - If None then the select/where/limit clauses are applied to filter
              the table
        vector_column_name: str, optional
            The name of the vector column to search.

            The vector column needs to be a pyarrow fixed size list type.

            - If not specified then the vector column is inferred from
              the table schema

            - If the table has multiple vector columns then the *vector_column_name*
              needs to be specified. Otherwise, an error is raised.
        query_type: str
            *default "auto"*.
            Acceptable types are: "vector", "fts", "hybrid", or "auto"

            - If "auto" then the query type is inferred from the query;

              - If `query` is a list/np.ndarray then the query type is
                "vector";

              - If `query` is a PIL.Image.Image then either do vector search,
                or raise an error if no corresponding embedding function is found.

              - If `query` is a string, then the query type is "vector" if the
                table has embedding functions, else the query type is "fts"

        Returns
        -------
        AsyncQuery
            A query builder object representing the query.
        """

        def is_embedding(query):
            return isinstance(query, (list, np.ndarray, pa.Array, pa.ChunkedArray))

        async def get_embedding_func(
            vector_column_name: Optional[str],
            query_type: QueryType,
            query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]],
        ) -> Tuple[str, EmbeddingFunctionConfig]:
            schema = await self.schema()
            vector_column_name = infer_vector_column_name(
                schema=schema,
                query_type=query_type,
                query=query,
                vector_column_name=vector_column_name,
            )
            funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(
                schema.metadata
            )
            func = funcs.get(vector_column_name)
            if func is None:
                error = ValueError(
                    f"Column '{vector_column_name}' has no registered "
                    "embedding function."
                )
                if len(funcs) > 0:
                    add_note(
                        error,
                        "Embedding functions are registered for columns: "
                        f"{list(funcs.keys())}",
                    )
                else:
                    add_note(
                        error, "No embedding functions are registered for any columns."
                    )
                raise error
            return vector_column_name, func

        async def make_embedding(embedding, query):
            if embedding is not None:
                loop = asyncio.get_running_loop()
                # This function is likely to block, since it either calls an expensive
                # function or makes an HTTP request to an embeddings REST API.
                return (
                    await loop.run_in_executor(
                        None,
                        embedding.function.compute_query_embeddings_with_retry,
                        query,
                    )
                )[0]
            else:
                return None

if query_type == "auto":
|
||||
# Infer the query type.
|
||||
if is_embedding(query):
|
||||
vector_query = query
|
||||
query_type = "vector"
|
||||
elif isinstance(query, str):
|
||||
try:
|
||||
(
|
||||
indices,
|
||||
(vector_column_name, embedding_conf),
|
||||
) = await asyncio.gather(
|
||||
self.list_indices(),
|
||||
get_embedding_func(vector_column_name, "auto", query),
|
||||
)
|
||||
except ValueError as e:
|
||||
if "Column" in str(
|
||||
e
|
||||
) and "has no registered embedding function" in str(e):
|
||||
# If the column has no registered embedding function,
|
||||
# then it's an FTS query.
|
||||
query_type = "fts"
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
if embedding_conf is not None:
|
||||
vector_query = await make_embedding(embedding_conf, query)
|
||||
if any(
|
||||
i.columns[0] == embedding_conf.source_column
|
||||
and i.index_type == "FTS"
|
||||
for i in indices
|
||||
):
|
||||
query_type = "hybrid"
|
||||
else:
|
||||
query_type = "vector"
|
||||
else:
|
||||
query_type = "fts"
|
||||
else:
|
||||
# it's an image or something else embeddable.
|
||||
query_type = "vector"
|
||||
elif query_type == "vector":
|
||||
if is_embedding(query):
|
||||
vector_query = query
|
||||
else:
|
||||
vector_column_name, embedding_conf = await get_embedding_func(
|
||||
vector_column_name, query_type, query
|
||||
)
|
||||
vector_query = await make_embedding(embedding_conf, query)
|
||||
elif query_type == "hybrid":
|
||||
if is_embedding(query):
|
||||
raise ValueError("Hybrid search requires a text query")
|
||||
else:
|
||||
vector_column_name, embedding_conf = await get_embedding_func(
|
||||
vector_column_name, query_type, query
|
||||
)
|
||||
vector_query = await make_embedding(embedding_conf, query)
|
||||
|
||||
if query_type == "vector":
|
||||
builder = self.query().nearest_to(vector_query)
|
||||
if vector_column_name:
|
||||
builder = builder.column(vector_column_name)
|
||||
return builder
|
||||
elif query_type == "fts":
|
||||
return self.query().nearest_to_text(query, columns=fts_columns or [])
|
||||
elif query_type == "hybrid":
|
||||
builder = self.query().nearest_to(vector_query)
|
||||
if vector_column_name:
|
||||
builder = builder.column(vector_column_name)
|
||||
return builder.nearest_to_text(query, columns=fts_columns or [])
|
||||
else:
|
||||
raise ValueError(f"Unknown query type: '{query_type}'")
|
||||
|
||||
def vector_search(
|
||||
self,
|
||||
query_vector: Union[VEC, Tuple],
|
||||
|
||||
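To summarize the dispatch above, a hedged sketch of the builder types one would expect, mirroring the `test_query_search_auto` test further down (it assumes `tbl` has a 4-dimensional embedding function on `vector` and an FTS index on `text`):

```python
q = await tbl.search([0.1] * 4)                # raw vector -> AsyncVectorQuery
q = await tbl.search("0.1")                    # string + embedding fn + FTS index -> AsyncHybridQuery
q = await tbl.search("0.1", query_type="fts")  # forced full-text -> AsyncFTSQuery
```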
@@ -75,6 +75,6 @@ async def test_binary_vector_async():

    query = np.random.randint(0, 2, size=256)
    packed_query = np.packbits(query)
    await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow()
    await (await tbl.search(packed_query)).distance_type("hamming").to_arrow()
    # --8<-- [end:async_binary_vector]
    await db.drop_table("my_binary_vectors")

@@ -53,13 +53,13 @@ async def test_binary_vector_async():
    query = np.random.random(256)

    # Search for the vectors within the range of [0.1, 0.5)
    await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow()
    await (await tbl.search(query)).distance_range(0.1, 0.5).to_arrow()

    # Search for the vectors with the distance less than 0.5
    await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow()
    await (await tbl.search(query)).distance_range(upper_bound=0.5).to_arrow()

    # Search for the vectors with the distance greater or equal to 0.1
    await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow()
    await (await tbl.search(query)).distance_range(lower_bound=0.1).to_arrow()

    # --8<-- [end:async_distance_range]
    await db.drop_table("my_table")

@@ -28,3 +28,24 @@ def test_embeddings_openai():
    actual = table.search(query).limit(1).to_pydantic(Words)[0]
    print(actual.text)
    # --8<-- [end:openai_embeddings]


@pytest.mark.slow
@pytest.mark.asyncio
async def test_embeddings_openai_async():
    uri = "memory://"
    # --8<-- [start:async_openai_embeddings]
    db = await lancedb.connect_async(uri)
    func = get_registry().get("openai").create(name="text-embedding-ada-002")

    class Words(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    table = await db.create_table("words", schema=Words, mode="overwrite")
    await table.add([{"text": "hello world"}, {"text": "goodbye world"}])

    query = "greetings"
    actual = (await (await table.search(query)).limit(1).to_pydantic(Words))[0]
    print(actual.text)
    # --8<-- [end:async_openai_embeddings]

@@ -72,8 +72,7 @@ async def test_ann_index_async():
    # --8<-- [end:create_ann_index_async]
    # --8<-- [start:vector_search_async]
    await (
        async_tbl.query()
        .nearest_to(np.random.random((32)))
        (await async_tbl.search(np.random.random((32))))
        .limit(2)
        .nprobes(20)
        .refine_factor(10)
@@ -82,18 +81,14 @@ async def test_ann_index_async():
    # --8<-- [end:vector_search_async]
    # --8<-- [start:vector_search_async_with_filter]
    await (
        async_tbl.query()
        .nearest_to(np.random.random((32)))
        (await async_tbl.search(np.random.random((32))))
        .where("item != 'item 1141'")
        .to_pandas()
    )
    # --8<-- [end:vector_search_async_with_filter]
    # --8<-- [start:vector_search_async_with_select]
    await (
        async_tbl.query()
        .nearest_to(np.random.random((32)))
        .select(["vector"])
        .to_pandas()
        (await async_tbl.search(np.random.random((32)))).select(["vector"]).to_pandas()
    )
    # --8<-- [end:vector_search_async_with_select]

@@ -164,7 +159,7 @@ async def test_scalar_index_async():
        {"book_id": 3, "vector": [5.0, 6]},
    ]
    async_tbl = await async_db.create_table("book_with_embeddings_async", data)
    (await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
    (await (await async_tbl.search([1, 2])).where("book_id != 3").to_pandas())
    # --8<-- [end:vector_search_with_scalar_index_async]
    # --8<-- [start:update_scalar_index_async]
    await async_tbl.add([{"vector": [7, 8], "book_id": 4}])

@@ -126,19 +126,17 @@ async def test_pandas_and_pyarrow_async():

    query_vector = [100, 100]
    # Pandas DataFrame
    df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas()
    df = await (await async_tbl.search(query_vector)).limit(1).to_pandas()
    print(df)
    # --8<-- [end:vector_search_async]
    # --8<-- [start:vector_search_with_filter_async]
    # Apply the filter via LanceDB
    results = (
        await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas()
    )
    results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas()
    assert len(results) == 1
    assert results["item"].iloc[0] == "foo"

    # Apply the filter via Pandas
    df = results = await async_tbl.query().nearest_to([100, 100]).to_pandas()
    df = results = await (await async_tbl.search([100, 100])).to_pandas()
    results = df[df.price < 15]
    assert len(results) == 1
    assert results["item"].iloc[0] == "foo"
@@ -188,3 +186,26 @@ def test_polars():
    # --8<-- [start:print_table_lazyform]
    print(ldf.first().collect())
    # --8<-- [end:print_table_lazyform]


@pytest.mark.asyncio
async def test_polars_async():
    uri = "data/sample-lancedb"
    db = await lancedb.connect_async(uri)

    # --8<-- [start:create_table_polars_async]
    data = pl.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    table = await db.create_table("pl_table_async", data=data)
    # --8<-- [end:create_table_polars_async]
    # --8<-- [start:vector_search_polars_async]
    query = [3.0, 4.0]
    result = await (await table.search(query)).limit(1).to_polars()
    print(result)
    print(type(result))
    # --8<-- [end:vector_search_polars_async]

@@ -117,12 +117,11 @@ async def test_vector_search_async():
        for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32"))
    ]
    async_tbl = await async_db.create_table("vector_search_async", data=data)
    (await async_tbl.query().nearest_to(np.random.random((1536))).limit(10).to_list())
    (await (await async_tbl.search(np.random.random((1536)))).limit(10).to_list())
    # --8<-- [end:exhaustive_search_async]
    # --8<-- [start:exhaustive_search_async_cosine]
    (
        await async_tbl.query()
        .nearest_to(np.random.random((1536)))
        await (await async_tbl.search(np.random.random((1536))))
        .distance_type("cosine")
        .limit(10)
        .to_list()
@@ -145,13 +144,13 @@ async def test_vector_search_async():
    async_tbl = await async_db.create_table("documents_async", data=data)
    # --8<-- [end:create_table_async_with_nested_schema]
    # --8<-- [start:search_result_async_as_pyarrow]
    await async_tbl.query().nearest_to(np.random.randn(1536)).to_arrow()
    await (await async_tbl.search(np.random.randn(1536))).to_arrow()
    # --8<-- [end:search_result_async_as_pyarrow]
    # --8<-- [start:search_result_async_as_pandas]
    await async_tbl.query().nearest_to(np.random.randn(1536)).to_pandas()
    await (await async_tbl.search(np.random.randn(1536))).to_pandas()
    # --8<-- [end:search_result_async_as_pandas]
    # --8<-- [start:search_result_async_as_list]
    await async_tbl.query().nearest_to(np.random.randn(1536)).to_list()
    await (await async_tbl.search(np.random.randn(1536))).to_list()
    # --8<-- [end:search_result_async_as_list]

@@ -219,9 +218,7 @@ async def test_fts_native_async():

    # async API uses our native FTS algorithm
    await async_tbl.create_index("text", config=FTS())
    await (
        async_tbl.query().nearest_to_text("puppy").select(["text"]).limit(10).to_list()
    )
    await (await async_tbl.search("puppy")).select(["text"]).limit(10).to_list()
    # [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}]
    # ...
    # --8<-- [end:basic_fts_async]
@@ -235,18 +232,11 @@ async def test_fts_native_async():
    )
    # --8<-- [end:fts_config_folding_async]
    # --8<-- [start:fts_prefiltering_async]
    await (
        async_tbl.query()
        .nearest_to_text("puppy")
        .limit(10)
        .where("text='foo'")
        .to_list()
    )
    await (await async_tbl.search("puppy")).limit(10).where("text='foo'").to_list()
    # --8<-- [end:fts_prefiltering_async]
    # --8<-- [start:fts_postfiltering_async]
    await (
        async_tbl.query()
        .nearest_to_text("puppy")
        (await async_tbl.search("puppy"))
        .limit(10)
        .where("text='foo'")
        .postfilter()
@@ -347,14 +337,8 @@ async def test_hybrid_search_async():
    # Create a fts index before the hybrid search
    await async_tbl.create_index("text", config=FTS())
    text_query = "flower moon"
    vector_query = embeddings.compute_query_embeddings(text_query)[0]
    # hybrid search with default re-ranker
    await (
        async_tbl.query()
        .nearest_to(vector_query)
        .nearest_to_text(text_query)
        .to_pandas()
    )
    await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas()
    # --8<-- [end:basic_hybrid_search_async]
    # --8<-- [start:hybrid_search_pass_vector_text_async]
    vector_query = [0.1, 0.2, 0.3, 0.4, 0.5]

@@ -1,25 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

from typing import List, Union
import unittest.mock as mock
from datetime import timedelta
from pathlib import Path

import lancedb
from lancedb.index import IvfPq, FTS
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.db import AsyncConnection
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.index import FTS, IvfPq
import lancedb.pydantic
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pyarrow.compute as pc
import pytest
import pytest_asyncio
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import (
    AsyncFTSQuery,
    AsyncHybridQuery,
    AsyncQueryBase,
    AsyncVectorQuery,
    LanceVectorQueryBuilder,
    Query,
)
from lancedb.rerankers.cross_encoder import CrossEncoderReranker
from lancedb.table import AsyncTable, LanceTable
from utils import exception_output


@pytest.fixture(scope="module")
@@ -716,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path):
    tbl = await db.create_table("test", df)
    results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
    assert len(results) == 2


@pytest.mark.asyncio
async def test_query_search_auto(mem_db_async: AsyncConnection):
    nrows = 1000
    data = pa.table(
        {
            "text": [str(i) for i in range(nrows)],
        }
    )

    @register("test2")
    class TestEmbedding(TextEmbeddingFunction):
        def ndims(self):
            return 4

        def generate_embeddings(
            self, texts: Union[List[str], np.ndarray]
        ) -> List[np.array]:
            embeddings = []
            for text in texts:
                vec = np.array([float(text) / 1000] * self.ndims())
                embeddings.append(vec)
            return embeddings

    registry = get_registry()
    func = registry.get("test2").create()

    class TestModel(LanceModel):
        text: str = func.SourceField()
        vector: Vector(func.ndims()) = func.VectorField()

    tbl = await mem_db_async.create_table("test", data, schema=TestModel)

    funcs = await tbl.embedding_functions()
    assert len(funcs) == 1

    # No FTS or vector index
    # Search for vector -> vector query
    q = [0.1] * 4
    query = await tbl.search(q)
    assert isinstance(query, AsyncVectorQuery)

    # Search for string -> vector query
    query = await tbl.search("0.1")
    assert isinstance(query, AsyncVectorQuery)

    await tbl.create_index("text", config=FTS())

    query = await tbl.search("0.1")
    assert isinstance(query, AsyncHybridQuery)

    data_with_vecs = await tbl.to_arrow()
    data_with_vecs = data_with_vecs.replace_schema_metadata(None)
    tbl2 = await mem_db_async.create_table("test2", data_with_vecs)
    with pytest.raises(
        Exception,
        match=(
            "Cannot perform full text search unless an INVERTED index has "
            "been created"
        ),
    ):
        query = await (await tbl2.search("0.1")).to_arrow()

@pytest.mark.asyncio
async def test_query_search_specified(mem_db_async: AsyncConnection):
    nrows, ndims = 1000, 16
    data = pa.table(
        {
            "text": [str(i) for i in range(nrows)],
            "vector": pa.FixedSizeListArray.from_arrays(
                pc.random(nrows * ndims).cast(pa.float32()), ndims
            ),
        }
    )
    table = await mem_db_async.create_table("test", data)
    await table.create_index("text", config=FTS())

    # Validate that specifying fts, vector or hybrid gets the right query.
    q = [0.1] * ndims
    query = await table.search(q, query_type="vector")
    assert isinstance(query, AsyncVectorQuery)

    query = await table.search("0.1", query_type="fts")
    assert isinstance(query, AsyncFTSQuery)

    with pytest.raises(ValueError, match="Unknown query type: 'foo'"):
        await table.search("0.1", query_type="foo")

    with pytest.raises(
        ValueError, match="Column 'vector' has no registered embedding function"
    ) as e:
        await table.search("0.1", query_type="vector")

    assert "No embedding functions are registered for any columns" in exception_output(
        e
    )

@@ -338,6 +338,7 @@ def test_query_sync_empty_query():
        "filter": "true",
        "vector": [],
        "columns": ["id"],
        "prefilter": False,
        "version": None,
    }

@@ -412,6 +413,7 @@ def test_query_sync_fts():
            "columns": [],
        },
        "k": 10,
        "prefilter": True,
        "vector": [],
        "version": None,
    }
@@ -429,6 +431,7 @@ def test_query_sync_fts():
        },
        "k": 42,
        "vector": [],
        "prefilter": True,
        "with_row_id": True,
        "version": None,
    }
@@ -455,6 +458,7 @@ def test_query_sync_hybrid():
        },
        "k": 42,
        "vector": [],
        "prefilter": True,
        "with_row_id": True,
        "version": None,
    }