From 1b990983b32012b7374c820b4541391de00a25e5 Mon Sep 17 00:00:00 2001 From: QianZhu Date: Mon, 12 Feb 2024 16:35:44 -0800 Subject: [PATCH] Qian/make vector col optional (#950) remote SDK tests were completed through lancedb_integtest --- python/lancedb/embeddings/instructor.py | 6 +-- python/lancedb/query.py | 6 +-- python/lancedb/remote/table.py | 19 ++++++--- python/lancedb/table.py | 32 +++++++++++---- python/lancedb/util.py | 39 ++++++++++++++++++ python/tests/test_embeddings_slow.py | 20 ++++++++-- python/tests/test_remote_db.py | 1 + python/tests/test_table.py | 53 +++++++++++++++++++++++++ 8 files changed, 154 insertions(+), 22 deletions(-) diff --git a/python/lancedb/embeddings/instructor.py b/python/lancedb/embeddings/instructor.py index c2058b27..8d3311ec 100644 --- a/python/lancedb/embeddings/instructor.py +++ b/python/lancedb/embeddings/instructor.py @@ -102,9 +102,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction): # convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly source_instruction: str = "represent the document for retrieval" - query_instruction: str = ( - "represent the document for retrieving the most similar documents" - ) + query_instruction: ( + str + ) = "represent the document for retrieving the most similar documents" @weak_lru(maxsize=1) def ndims(self): diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 8ce7b5f2..44809da9 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -24,7 +24,7 @@ import pyarrow as pa import pydantic from . import __version__ -from .common import VEC, VECTOR_COLUMN_NAME +from .common import VEC from .rerankers.base import Reranker from .rerankers.linear_combination import LinearCombinationReranker from .util import safe_import_pandas @@ -75,7 +75,7 @@ class Query(pydantic.BaseModel): tuning advice. """ - vector_column: str = VECTOR_COLUMN_NAME + vector_column: Optional[str] = None # vector to search for vector: Union[List[float], List[List[float]]] @@ -403,7 +403,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): self, table: "Table", query: Union[np.ndarray, list, "PIL.Image.Image"], - vector_column: str = VECTOR_COLUMN_NAME, + vector_column: str, ): super().__init__(table) self._query = query diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index 925690c3..a38e7861 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -24,7 +24,7 @@ from lancedb.merge import LanceMergeInsertBuilder from ..query import LanceVectorQueryBuilder from ..table import Query, Table, _sanitize_data -from ..util import value_to_sql +from ..util import inf_vector_column_query, value_to_sql from .arrow import to_ipc_binary from .client import ARROW_STREAM_CONTENT_TYPE from .db import RemoteDBConnection @@ -198,7 +198,9 @@ class RemoteTable(Table): ) def search( - self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME + self, + query: Union[VEC, str], + vector_column_name: Optional[str] = None, ) -> LanceVectorQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. We currently support [vector search][search] @@ -217,7 +219,7 @@ class RemoteTable(Table): ... ] >>> table = db.create_table("my_table", data) # doctest: +SKIP >>> query = [0.4, 1.4, 2.4] - >>> (table.search(query, vector_column_name="vector") # doctest: +SKIP + >>> (table.search(query) # doctest: +SKIP ... .where("original_width > 1000", prefilter=True) # doctest: +SKIP ... .select(["caption", "original_width"]) # doctest: +SKIP ... .limit(2) # doctest: +SKIP @@ -236,9 +238,14 @@ class RemoteTable(Table): - If None then the select/where/limit clauses are applied to filter the table - vector_column_name: str + vector_column_name: str, optional The name of the vector column to search. - *default "vector"* + + - If not specified then the vector column is inferred from + the table schema + + - If the table has multiple vector columns then the *vector_column_name* + needs to be specified. Otherwise, an error is raised. Returns ------- @@ -253,6 +260,8 @@ class RemoteTable(Table): - and also the "_distance" column which is the distance between the query vector and the returned vector. """ + if vector_column_name is None: + vector_column_name = inf_vector_column_query(self.schema) return LanceVectorQueryBuilder(self, query, vector_column_name) def _execute_query(self, query: Query) -> pa.Table: diff --git a/python/lancedb/table.py b/python/lancedb/table.py index fb790268..04bad713 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -36,6 +36,7 @@ from .pydantic import LanceModel, model_to_dict from .query import LanceQueryBuilder, Query from .util import ( fs_from_uri, + inf_vector_column_query, join_uri, safe_import_pandas, safe_import_polars, @@ -413,7 +414,7 @@ class Table(ABC): def search( self, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, - vector_column_name: str = VECTOR_COLUMN_NAME, + vector_column_name: Optional[str] = None, query_type: str = "auto", ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors @@ -433,7 +434,7 @@ class Table(ABC): ... ] >>> table = db.create_table("my_table", data) >>> query = [0.4, 1.4, 2.4] - >>> (table.search(query, vector_column_name="vector") + >>> (table.search(query) ... .where("original_width > 1000", prefilter=True) ... .select(["caption", "original_width"]) ... .limit(2) @@ -452,11 +453,16 @@ class Table(ABC): - If None then the select/where/limit clauses are applied to filter the table - vector_column_name: str + vector_column_name: str, optional The name of the vector column to search. The vector column needs to be a pyarrow fixed size list type - *default "vector"* + + - If not specified then the vector column is inferred from + the table schema + + - If the table has multiple vector columns then the *vector_column_name* + needs to be specified. Otherwise, an error is raised. query_type: str *default "auto"*. Acceptable types are: "vector", "fts", "hybrid", or "auto" @@ -1193,7 +1199,7 @@ class LanceTable(Table): def search( self, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, - vector_column_name: str = VECTOR_COLUMN_NAME, + vector_column_name: Optional[str] = None, query_type: str = "auto", ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors @@ -1211,7 +1217,7 @@ class LanceTable(Table): ... ] >>> table = db.create_table("my_table", data) >>> query = [0.4, 1.4, 2.4] - >>> (table.search(query, vector_column_name="vector") + >>> (table.search(query) ... .where("original_width > 1000", prefilter=True) ... .select(["caption", "original_width"]) ... .limit(2) @@ -1230,8 +1236,17 @@ class LanceTable(Table): - If None then the select/[where][sql]/limit clauses are applied to filter the table - vector_column_name: str, default "vector" + vector_column_name: str, optional The name of the vector column to search. + + The vector column needs to be a pyarrow fixed size list type + *default "vector"* + + - If not specified then the vector column is inferred from + the table schema + + - If the table has multiple vector columns then the *vector_column_name* + needs to be specified. Otherwise, an error is raised. query_type: str, default "auto" "vector", "fts", or "auto" If "auto" then the query type is inferred from the query; @@ -1249,6 +1264,8 @@ class LanceTable(Table): and also the "_distance" column which is the distance between the query vector and the returned vector. """ + if vector_column_name is None and query is not None: + vector_column_name = inf_vector_column_query(self.schema) register_event("search_table") return LanceQueryBuilder.create( self, query, query_type, vector_column_name=vector_column_name @@ -1435,6 +1452,7 @@ class LanceTable(Table): def _execute_query(self, query: Query) -> pa.Table: ds = self.to_lance() + return ds.to_table( columns=query.columns, filter=query.filter, diff --git a/python/lancedb/util.py b/python/lancedb/util.py index 7eb80ea8..915b660a 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -20,6 +20,7 @@ from typing import Tuple, Union from urllib.parse import urlparse import numpy as np +import pyarrow as pa import pyarrow.fs as pa_fs @@ -152,6 +153,44 @@ def safe_import_polars(): return None +def inf_vector_column_query(schema: pa.Schema) -> str: + """ + Get the vector column name + + Parameters + ---------- + schema : pa.Schema + The schema of the vector column. + + Returns + ------- + str: the vector column name. + """ + vector_col_name = "" + vector_col_count = 0 + for field_name in schema.names: + field = schema.field(field_name) + if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating( + field.type.value_type + ): + vector_col_count += 1 + if vector_col_count > 1: + raise ValueError( + "Schema has more than one vector column. " + "Please specify the vector column name " + "for vector search" + ) + break + elif vector_col_count == 1: + vector_col_name = field_name + if vector_col_count == 0: + raise ValueError( + "There is no vector column in the data. " + "Please specify the vector column name for vector search" + ) + return vector_col_name + + @singledispatch def value_to_sql(value): raise NotImplementedError("SQL conversion is not implemented for this type") diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py index cfdb1247..dff931c1 100644 --- a/python/tests/test_embeddings_slow.py +++ b/python/tests/test_embeddings_slow.py @@ -69,10 +69,14 @@ def test_basic_text_embeddings(alias, tmp_path): ) query = "greetings" - actual = table.search(query).limit(1).to_pydantic(Words)[0] + actual = ( + table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0] + ) vec = func.compute_query_embeddings(query)[0] - expected = table.search(vec).limit(1).to_pydantic(Words)[0] + expected = ( + table.search(vec, vector_column_name="vector").limit(1).to_pydantic(Words)[0] + ) assert actual.text == expected.text assert actual.text == "hello world" assert not np.allclose(actual.vector, actual.vector2) @@ -116,7 +120,11 @@ def test_openclip(tmp_path): ) # text search - actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0] + actual = ( + table.search("man's best friend", vector_column_name="vector") + .limit(1) + .to_pydantic(Images)[0] + ) assert actual.label == "dog" frombytes = ( table.search("man's best friend", vector_column_name="vec_from_bytes") @@ -130,7 +138,11 @@ def test_openclip(tmp_path): query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg" image_bytes = requests.get(query_image_uri).content query_image = Image.open(io.BytesIO(image_bytes)) - actual = table.search(query_image).limit(1).to_pydantic(Images)[0] + actual = ( + table.search(query_image, vector_column_name="vector") + .limit(1) + .to_pydantic(Images)[0] + ) assert actual.label == "dog" other = ( table.search(query_image, vector_column_name="vec_from_bytes") diff --git a/python/tests/test_remote_db.py b/python/tests/test_remote_db.py index bca4451f..f4aff298 100644 --- a/python/tests/test_remote_db.py +++ b/python/tests/test_remote_db.py @@ -38,4 +38,5 @@ def test_remote_db(): setattr(conn, "_client", FakeLanceDBClient()) table = conn["test"] + table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) table.search([1.0, 2.0]).to_pandas() diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 02e8b1f0..3c29ed35 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -710,6 +710,59 @@ def test_empty_query(db): assert len(df) == 100 +def test_search_with_schema_inf_single_vector(db): + class MyTable(LanceModel): + text: str + vector_col: Vector(10) + + table = LanceTable.create( + db, + "my_table", + schema=MyTable, + ) + + v1 = np.random.randn(10) + v2 = np.random.randn(10) + data = [ + {"vector_col": v1, "text": "foo"}, + {"vector_col": v2, "text": "bar"}, + ] + df = pd.DataFrame(data) + table.add(df) + + q = np.random.randn(10) + result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas() + result2 = table.search(q).limit(1).to_pandas() + + assert result1["text"].iloc[0] == result2["text"].iloc[0] + + +def test_search_with_schema_inf_multiple_vector(db): + class MyTable(LanceModel): + text: str + vector1: Vector(10) + vector2: Vector(10) + + table = LanceTable.create( + db, + "my_table", + schema=MyTable, + ) + + v1 = np.random.randn(10) + v2 = np.random.randn(10) + data = [ + {"vector1": v1, "vector2": v2, "text": "foo"}, + {"vector1": v2, "vector2": v1, "text": "bar"}, + ] + df = pd.DataFrame(data) + table.add(df) + + q = np.random.randn(10) + with pytest.raises(ValueError): + table.search(q).limit(1).to_pandas() + + def test_compact_cleanup(db): table = LanceTable.create( db,