diff --git a/docs/src/python/polars_arrow.md b/docs/src/python/polars_arrow.md index 9d6d8c33..51a281a3 100644 --- a/docs/src/python/polars_arrow.md +++ b/docs/src/python/polars_arrow.md @@ -9,23 +9,50 @@ LanceDB supports [Polars](https://github.com/pola-rs/polars), a blazingly fast D First, we connect to a LanceDB database. +=== "Sync API" + + ```py + --8<-- "python/python/tests/docs/test_python.py:import-lancedb" + --8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb" + ``` + +=== "Async API" + + ```py + --8<-- "python/python/tests/docs/test_python.py:import-lancedb" + --8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb_async" + ``` -```py ---8<-- "python/python/tests/docs/test_python.py:import-lancedb" ---8<-- "python/python/tests/docs/test_python.py:connect_to_lancedb" -``` We can load a Polars `DataFrame` to LanceDB directly. -```py ---8<-- "python/python/tests/docs/test_python.py:import-polars" ---8<-- "python/python/tests/docs/test_python.py:create_table_polars" -``` +=== "Sync API" + + ```py + --8<-- "python/python/tests/docs/test_python.py:import-polars" + --8<-- "python/python/tests/docs/test_python.py:create_table_polars" + ``` + +=== "Async API" + + ```py + --8<-- "python/python/tests/docs/test_python.py:import-polars" + --8<-- "python/python/tests/docs/test_python.py:create_table_polars_async" + ``` + We can now perform similarity search via the LanceDB Python API. -```py ---8<-- "python/python/tests/docs/test_python.py:vector_search_polars" -``` +=== "Sync API" + + ```py + --8<-- "python/python/tests/docs/test_python.py:vector_search_polars" + ``` + +=== "Async API" + + ```py + --8<-- "python/python/tests/docs/test_python.py:vector_search_polars_async" + ``` In addition to the selected columns, LanceDB also returns a vector and also the `_distance` column which is the distance between the query @@ -112,4 +139,3 @@ The reason it's beneficial to not convert the LanceDB Table to a DataFrame is because the table can potentially be way larger than memory, and Polars LazyFrames allow us to work with such larger-than-memory datasets by not loading it into memory all at once. - diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 48513e7f..a752deaa 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -3,6 +3,7 @@ from __future__ import annotations +import asyncio import inspect import warnings from abc import ABC, abstractmethod @@ -30,6 +31,7 @@ from .dependencies import _check_for_pandas import pyarrow as pa import pyarrow.compute as pc import pyarrow.fs as pa_fs +import numpy as np from lance import LanceDataset from lance.dependencies import _check_for_hugging_face @@ -39,6 +41,8 @@ from .index import BTree, IvfFlat, IvfPq, Bitmap, LabelList, HnswPq, HnswSq, FTS from .merge import LanceMergeInsertBuilder from .pydantic import LanceModel, model_to_dict from .query import ( + AsyncFTSQuery, + AsyncHybridQuery, AsyncQuery, AsyncVectorQuery, LanceEmptyQueryBuilder, @@ -2702,6 +2706,19 @@ class AsyncTable: """ return await self._inner.schema() + async def embedding_functions(self) -> Dict[str, EmbeddingFunctionConfig]: + """ + Get the embedding functions for the table + + Returns + ------- + funcs: Dict[str, EmbeddingFunctionConfig] + A mapping of the vector column to the embedding function + or empty dict if not configured. + """ + schema = await self.schema() + return EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata) + async def count_rows(self, filter: Optional[str] = None) -> int: """ Count the number of rows in the table. @@ -2931,6 +2948,234 @@ class AsyncTable: return LanceMergeInsertBuilder(self, on) + @overload + async def search( + self, + query: Optional[Union[str]] = None, + vector_column_name: Optional[str] = None, + query_type: Literal["auto"] = ..., + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> Union[AsyncHybridQuery | AsyncFTSQuery | AsyncVectorQuery]: ... + + @overload + async def search( + self, + query: Optional[Union[str]] = None, + vector_column_name: Optional[str] = None, + query_type: Literal["hybrid"] = ..., + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> AsyncHybridQuery: ... + + @overload + async def search( + self, + query: Optional[Union[VEC, "PIL.Image.Image", Tuple]] = None, + vector_column_name: Optional[str] = None, + query_type: Literal["auto"] = ..., + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> AsyncVectorQuery: ... + + @overload + async def search( + self, + query: Optional[str] = None, + vector_column_name: Optional[str] = None, + query_type: Literal["fts"] = ..., + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> AsyncFTSQuery: ... + + @overload + async def search( + self, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, + vector_column_name: Optional[str] = None, + query_type: Literal["vector"] = ..., + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> AsyncVectorQuery: ... + + async def search( + self, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, + vector_column_name: Optional[str] = None, + query_type: QueryType = "auto", + ordering_field_name: Optional[str] = None, + fts_columns: Optional[Union[str, List[str]]] = None, + ) -> AsyncQuery: + """Create a search query to find the nearest neighbors + of the given query vector. We currently support [vector search][search] + and [full-text search][experimental-full-text-search]. + + All query options are defined in [AsyncQuery][lancedb.query.AsyncQuery]. + + Parameters + ---------- + query: list/np.ndarray/str/PIL.Image.Image, default None + The targetted vector to search for. + + - *default None*. + Acceptable types are: list, np.ndarray, PIL.Image.Image + + - If None then the select/where/limit clauses are applied to filter + the table + vector_column_name: str, optional + The name of the vector column to search. + + The vector column needs to be a pyarrow fixed size list type + + - If not specified then the vector column is inferred from + the table schema + + - If the table has multiple vector columns then the *vector_column_name* + needs to be specified. Otherwise, an error is raised. + query_type: str + *default "auto"*. + Acceptable types are: "vector", "fts", "hybrid", or "auto" + + - If "auto" then the query type is inferred from the query; + + - If `query` is a list/np.ndarray then the query type is + "vector"; + + - If `query` is a PIL.Image.Image then either do vector search, + or raise an error if no corresponding embedding function is found. + + - If `query` is a string, then the query type is "vector" if the + table has embedding functions else the query type is "fts" + + Returns + ------- + LanceQueryBuilder + A query builder object representing the query. + """ + + def is_embedding(query): + return isinstance(query, (list, np.ndarray, pa.Array, pa.ChunkedArray)) + + async def get_embedding_func( + vector_column_name: Optional[str], + query_type: QueryType, + query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]], + ) -> Tuple[str, EmbeddingFunctionConfig]: + schema = await self.schema() + vector_column_name = infer_vector_column_name( + schema=schema, + query_type=query_type, + query=query, + vector_column_name=vector_column_name, + ) + funcs = EmbeddingFunctionRegistry.get_instance().parse_functions( + schema.metadata + ) + func = funcs.get(vector_column_name) + if func is None: + error = ValueError( + f"Column '{vector_column_name}' has no registered " + "embedding function." + ) + if len(funcs) > 0: + add_note( + error, + "Embedding functions are registered for columns: " + f"{list(funcs.keys())}", + ) + else: + add_note( + error, "No embedding functions are registered for any columns." + ) + raise error + return vector_column_name, func + + async def make_embedding(embedding, query): + if embedding is not None: + loop = asyncio.get_running_loop() + # This function is likely to block, since it either calls an expensive + # function or makes an HTTP request to an embeddings REST API. + return ( + await loop.run_in_executor( + None, + embedding.function.compute_query_embeddings_with_retry, + query, + ) + )[0] + else: + return None + + if query_type == "auto": + # Infer the query type. + if is_embedding(query): + vector_query = query + query_type = "vector" + elif isinstance(query, str): + try: + ( + indices, + (vector_column_name, embedding_conf), + ) = await asyncio.gather( + self.list_indices(), + get_embedding_func(vector_column_name, "auto", query), + ) + except ValueError as e: + if "Column" in str( + e + ) and "has no registered embedding function" in str(e): + # If the column has no registered embedding function, + # then it's an FTS query. + query_type = "fts" + else: + raise e + else: + if embedding_conf is not None: + vector_query = await make_embedding(embedding_conf, query) + if any( + i.columns[0] == embedding_conf.source_column + and i.index_type == "FTS" + for i in indices + ): + query_type = "hybrid" + else: + query_type = "vector" + else: + query_type = "fts" + else: + # it's an image or something else embeddable. + query_type = "vector" + elif query_type == "vector": + if is_embedding(query): + vector_query = query + else: + vector_column_name, embedding_conf = await get_embedding_func( + vector_column_name, query_type, query + ) + vector_query = await make_embedding(embedding_conf, query) + elif query_type == "hybrid": + if is_embedding(query): + raise ValueError("Hybrid search requires a text query") + else: + vector_column_name, embedding_conf = await get_embedding_func( + vector_column_name, query_type, query + ) + vector_query = await make_embedding(embedding_conf, query) + + if query_type == "vector": + builder = self.query().nearest_to(vector_query) + if vector_column_name: + builder = builder.column(vector_column_name) + return builder + elif query_type == "fts": + return self.query().nearest_to_text(query, columns=fts_columns or []) + elif query_type == "hybrid": + builder = self.query().nearest_to(vector_query) + if vector_column_name: + builder = builder.column(vector_column_name) + return builder.nearest_to_text(query, columns=fts_columns or []) + else: + raise ValueError(f"Unknown query type: '{query_type}'") + def vector_search( self, query_vector: Union[VEC, Tuple], diff --git a/python/python/tests/docs/test_binary_vector.py b/python/python/tests/docs/test_binary_vector.py index ed03cadd..5691bbf6 100644 --- a/python/python/tests/docs/test_binary_vector.py +++ b/python/python/tests/docs/test_binary_vector.py @@ -75,6 +75,6 @@ async def test_binary_vector_async(): query = np.random.randint(0, 2, size=256) packed_query = np.packbits(query) - await tbl.query().nearest_to(packed_query).distance_type("hamming").to_arrow() + await (await tbl.search(packed_query)).distance_type("hamming").to_arrow() # --8<-- [end:async_binary_vector] await db.drop_table("my_binary_vectors") diff --git a/python/python/tests/docs/test_distance_range.py b/python/python/tests/docs/test_distance_range.py index a405c682..26f2ca4c 100644 --- a/python/python/tests/docs/test_distance_range.py +++ b/python/python/tests/docs/test_distance_range.py @@ -53,13 +53,13 @@ async def test_binary_vector_async(): query = np.random.random(256) # Search for the vectors within the range of [0.1, 0.5) - await tbl.query().nearest_to(query).distance_range(0.1, 0.5).to_arrow() + await (await tbl.search(query)).distance_range(0.1, 0.5).to_arrow() # Search for the vectors with the distance less than 0.5 - await tbl.query().nearest_to(query).distance_range(upper_bound=0.5).to_arrow() + await (await tbl.search(query)).distance_range(upper_bound=0.5).to_arrow() # Search for the vectors with the distance greater or equal to 0.1 - await tbl.query().nearest_to(query).distance_range(lower_bound=0.1).to_arrow() + await (await tbl.search(query)).distance_range(lower_bound=0.1).to_arrow() # --8<-- [end:async_distance_range] await db.drop_table("my_table") diff --git a/python/python/tests/docs/test_embeddings_optional.py b/python/python/tests/docs/test_embeddings_optional.py index 89eababb..5197a88a 100644 --- a/python/python/tests/docs/test_embeddings_optional.py +++ b/python/python/tests/docs/test_embeddings_optional.py @@ -28,3 +28,24 @@ def test_embeddings_openai(): actual = table.search(query).limit(1).to_pydantic(Words)[0] print(actual.text) # --8<-- [end:openai_embeddings] + + +@pytest.mark.slow +@pytest.mark.asyncio +async def test_embeddings_openai_async(): + uri = "memory://" + # --8<-- [start:async_openai_embeddings] + db = await lancedb.connect_async(uri) + func = get_registry().get("openai").create(name="text-embedding-ada-002") + + class Words(LanceModel): + text: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() + + table = await db.create_table("words", schema=Words, mode="overwrite") + await table.add([{"text": "hello world"}, {"text": "goodbye world"}]) + + query = "greetings" + actual = await (await table.search(query)).limit(1).to_pydantic(Words)[0] + print(actual.text) + # --8<-- [end:async_openai_embeddings] diff --git a/python/python/tests/docs/test_guide_index.py b/python/python/tests/docs/test_guide_index.py index c9dd5a86..77729493 100644 --- a/python/python/tests/docs/test_guide_index.py +++ b/python/python/tests/docs/test_guide_index.py @@ -72,8 +72,7 @@ async def test_ann_index_async(): # --8<-- [end:create_ann_index_async] # --8<-- [start:vector_search_async] await ( - async_tbl.query() - .nearest_to(np.random.random((32))) + (await async_tbl.search(np.random.random((32)))) .limit(2) .nprobes(20) .refine_factor(10) @@ -82,18 +81,14 @@ async def test_ann_index_async(): # --8<-- [end:vector_search_async] # --8<-- [start:vector_search_async_with_filter] await ( - async_tbl.query() - .nearest_to(np.random.random((32))) + (await async_tbl.search(np.random.random((32)))) .where("item != 'item 1141'") .to_pandas() ) # --8<-- [end:vector_search_async_with_filter] # --8<-- [start:vector_search_async_with_select] await ( - async_tbl.query() - .nearest_to(np.random.random((32))) - .select(["vector"]) - .to_pandas() + (await async_tbl.search(np.random.random((32)))).select(["vector"]).to_pandas() ) # --8<-- [end:vector_search_async_with_select] @@ -164,7 +159,7 @@ async def test_scalar_index_async(): {"book_id": 3, "vector": [5.0, 6]}, ] async_tbl = await async_db.create_table("book_with_embeddings_async", data) - (await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas()) + (await (await async_tbl.search([1, 2])).where("book_id != 3").to_pandas()) # --8<-- [end:vector_search_with_scalar_index_async] # --8<-- [start:update_scalar_index_async] await async_tbl.add([{"vector": [7, 8], "book_id": 4}]) diff --git a/python/python/tests/docs/test_python.py b/python/python/tests/docs/test_python.py index a9dfbf95..d8f6a732 100644 --- a/python/python/tests/docs/test_python.py +++ b/python/python/tests/docs/test_python.py @@ -126,19 +126,17 @@ async def test_pandas_and_pyarrow_async(): query_vector = [100, 100] # Pandas DataFrame - df = await async_tbl.query().nearest_to(query_vector).limit(1).to_pandas() + df = await (await async_tbl.search(query_vector)).limit(1).to_pandas() print(df) # --8<-- [end:vector_search_async] # --8<-- [start:vector_search_with_filter_async] # Apply the filter via LanceDB - results = ( - await async_tbl.query().nearest_to([100, 100]).where("price < 15").to_pandas() - ) + results = await (await async_tbl.search([100, 100])).where("price < 15").to_pandas() assert len(results) == 1 assert results["item"].iloc[0] == "foo" # Apply the filter via Pandas - df = results = await async_tbl.query().nearest_to([100, 100]).to_pandas() + df = results = await (await async_tbl.search([100, 100])).to_pandas() results = df[df.price < 15] assert len(results) == 1 assert results["item"].iloc[0] == "foo" @@ -188,3 +186,26 @@ def test_polars(): # --8<-- [start:print_table_lazyform] print(ldf.first().collect()) # --8<-- [end:print_table_lazyform] + + +@pytest.mark.asyncio +async def test_polars_async(): + uri = "data/sample-lancedb" + db = await lancedb.connect_async(uri) + + # --8<-- [start:create_table_polars_async] + data = pl.DataFrame( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["foo", "bar"], + "price": [10.0, 20.0], + } + ) + table = await db.create_table("pl_table_async", data=data) + # --8<-- [end:create_table_polars_async] + # --8<-- [start:vector_search_polars_async] + query = [3.0, 4.0] + result = await (await table.search(query)).limit(1).to_polars() + print(result) + print(type(result)) + # --8<-- [end:vector_search_polars_async] diff --git a/python/python/tests/docs/test_search.py b/python/python/tests/docs/test_search.py index a3f4c60a..fe276fe4 100644 --- a/python/python/tests/docs/test_search.py +++ b/python/python/tests/docs/test_search.py @@ -117,12 +117,11 @@ async def test_vector_search_async(): for i, row in enumerate(np.random.random((10_000, 1536)).astype("float32")) ] async_tbl = await async_db.create_table("vector_search_async", data=data) - (await async_tbl.query().nearest_to(np.random.random((1536))).limit(10).to_list()) + (await (await async_tbl.search(np.random.random((1536)))).limit(10).to_list()) # --8<-- [end:exhaustive_search_async] # --8<-- [start:exhaustive_search_async_cosine] ( - await async_tbl.query() - .nearest_to(np.random.random((1536))) + await (await async_tbl.search(np.random.random((1536)))) .distance_type("cosine") .limit(10) .to_list() @@ -145,13 +144,13 @@ async def test_vector_search_async(): async_tbl = await async_db.create_table("documents_async", data=data) # --8<-- [end:create_table_async_with_nested_schema] # --8<-- [start:search_result_async_as_pyarrow] - await async_tbl.query().nearest_to(np.random.randn(1536)).to_arrow() + await (await async_tbl.search(np.random.randn(1536))).to_arrow() # --8<-- [end:search_result_async_as_pyarrow] # --8<-- [start:search_result_async_as_pandas] - await async_tbl.query().nearest_to(np.random.randn(1536)).to_pandas() + await (await async_tbl.search(np.random.randn(1536))).to_pandas() # --8<-- [end:search_result_async_as_pandas] # --8<-- [start:search_result_async_as_list] - await async_tbl.query().nearest_to(np.random.randn(1536)).to_list() + await (await async_tbl.search(np.random.randn(1536))).to_list() # --8<-- [end:search_result_async_as_list] @@ -219,9 +218,7 @@ async def test_fts_native_async(): # async API uses our native FTS algorithm await async_tbl.create_index("text", config=FTS()) - await ( - async_tbl.query().nearest_to_text("puppy").select(["text"]).limit(10).to_list() - ) + await (await async_tbl.search("puppy")).select(["text"]).limit(10).to_list() # [{'text': 'Frodo was a happy puppy', '_score': 0.6931471824645996}] # ... # --8<-- [end:basic_fts_async] @@ -235,18 +232,11 @@ async def test_fts_native_async(): ) # --8<-- [end:fts_config_folding_async] # --8<-- [start:fts_prefiltering_async] - await ( - async_tbl.query() - .nearest_to_text("puppy") - .limit(10) - .where("text='foo'") - .to_list() - ) + await (await async_tbl.search("puppy")).limit(10).where("text='foo'").to_list() # --8<-- [end:fts_prefiltering_async] # --8<-- [start:fts_postfiltering_async] await ( - async_tbl.query() - .nearest_to_text("puppy") + (await async_tbl.search("puppy")) .limit(10) .where("text='foo'") .postfilter() @@ -347,14 +337,8 @@ async def test_hybrid_search_async(): # Create a fts index before the hybrid search await async_tbl.create_index("text", config=FTS()) text_query = "flower moon" - vector_query = embeddings.compute_query_embeddings(text_query)[0] # hybrid search with default re-ranker - await ( - async_tbl.query() - .nearest_to(vector_query) - .nearest_to_text(text_query) - .to_pandas() - ) + await (await async_tbl.search("flower moon", query_type="hybrid")).to_pandas() # --8<-- [end:basic_hybrid_search_async] # --8<-- [start:hybrid_search_pass_vector_text_async] vector_query = [0.1, 0.2, 0.3, 0.4, 0.5] diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 1cdd23d1..47ad6f05 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -1,25 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The LanceDB Authors +from typing import List, Union import unittest.mock as mock from datetime import timedelta from pathlib import Path import lancedb -from lancedb.index import IvfPq, FTS -from lancedb.rerankers.cross_encoder import CrossEncoderReranker +from lancedb.db import AsyncConnection +from lancedb.embeddings.base import TextEmbeddingFunction +from lancedb.embeddings.registry import get_registry, register +from lancedb.index import FTS, IvfPq +import lancedb.pydantic import numpy as np import pandas.testing as tm import pyarrow as pa +import pyarrow.compute as pc import pytest import pytest_asyncio from lancedb.pydantic import LanceModel, Vector from lancedb.query import ( + AsyncFTSQuery, + AsyncHybridQuery, AsyncQueryBase, + AsyncVectorQuery, LanceVectorQueryBuilder, Query, ) +from lancedb.rerankers.cross_encoder import CrossEncoderReranker from lancedb.table import AsyncTable, LanceTable +from utils import exception_output @pytest.fixture(scope="module") @@ -716,3 +726,101 @@ async def test_query_with_f16(tmp_path: Path): tbl = await db.create_table("test", df) results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas() assert len(results) == 2 + + +@pytest.mark.asyncio +async def test_query_search_auto(mem_db_async: AsyncConnection): + nrows = 1000 + data = pa.table( + { + "text": [str(i) for i in range(nrows)], + } + ) + + @register("test2") + class TestEmbedding(TextEmbeddingFunction): + def ndims(self): + return 4 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + embeddings = [] + for text in texts: + vec = np.array([float(text) / 1000] * self.ndims()) + embeddings.append(vec) + return embeddings + + registry = get_registry() + func = registry.get("test2").create() + + class TestModel(LanceModel): + text: str = func.SourceField() + vector: Vector(func.ndims()) = func.VectorField() + + tbl = await mem_db_async.create_table("test", data, schema=TestModel) + + funcs = await tbl.embedding_functions() + assert len(funcs) == 1 + + # No FTS or vector index + # Search for vector -> vector query + q = [0.1] * 4 + query = await tbl.search(q) + assert isinstance(query, AsyncVectorQuery) + + # Search for string -> vector query + query = await tbl.search("0.1") + assert isinstance(query, AsyncVectorQuery) + + await tbl.create_index("text", config=FTS()) + + query = await tbl.search("0.1") + assert isinstance(query, AsyncHybridQuery) + + data_with_vecs = await tbl.to_arrow() + data_with_vecs = data_with_vecs.replace_schema_metadata(None) + tbl2 = await mem_db_async.create_table("test2", data_with_vecs) + with pytest.raises( + Exception, + match=( + "Cannot perform full text search unless an INVERTED index has " + "been created" + ), + ): + query = await (await tbl2.search("0.1")).to_arrow() + + +@pytest.mark.asyncio +async def test_query_search_specified(mem_db_async: AsyncConnection): + nrows, ndims = 1000, 16 + data = pa.table( + { + "text": [str(i) for i in range(nrows)], + "vector": pa.FixedSizeListArray.from_arrays( + pc.random(nrows * ndims).cast(pa.float32()), ndims + ), + } + ) + table = await mem_db_async.create_table("test", data) + await table.create_index("text", config=FTS()) + + # Validate that specifying fts, vector or hybrid gets the right query. + q = [0.1] * ndims + query = await table.search(q, query_type="vector") + assert isinstance(query, AsyncVectorQuery) + + query = await table.search("0.1", query_type="fts") + assert isinstance(query, AsyncFTSQuery) + + with pytest.raises(ValueError, match="Unknown query type: 'foo'"): + await table.search("0.1", query_type="foo") + + with pytest.raises( + ValueError, match="Column 'vector' has no registered embedding function" + ) as e: + await table.search("0.1", query_type="vector") + + assert "No embedding functions are registered for any columns" in exception_output( + e + ) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 2d962cf0..34590259 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -338,6 +338,7 @@ def test_query_sync_empty_query(): "filter": "true", "vector": [], "columns": ["id"], + "prefilter": False, "version": None, } @@ -412,6 +413,7 @@ def test_query_sync_fts(): "columns": [], }, "k": 10, + "prefilter": True, "vector": [], "version": None, } @@ -429,6 +431,7 @@ def test_query_sync_fts(): }, "k": 42, "vector": [], + "prefilter": True, "with_row_id": True, "version": None, } @@ -455,6 +458,7 @@ def test_query_sync_hybrid(): }, "k": 42, "vector": [], + "prefilter": True, "with_row_id": True, "version": None, }