mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 11:52:57 +00:00
Qian/make vector col optional (#950)
remote SDK tests were completed through lancedb_integtest
This commit is contained in:
@@ -102,9 +102,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
|||||||
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||||
|
|
||||||
source_instruction: str = "represent the document for retrieval"
|
source_instruction: str = "represent the document for retrieval"
|
||||||
query_instruction: str = (
|
query_instruction: (
|
||||||
"represent the document for retrieving the most similar documents"
|
str
|
||||||
)
|
) = "represent the document for retrieving the most similar documents"
|
||||||
|
|
||||||
@weak_lru(maxsize=1)
|
@weak_lru(maxsize=1)
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ import pyarrow as pa
|
|||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .common import VEC, VECTOR_COLUMN_NAME
|
from .common import VEC
|
||||||
from .rerankers.base import Reranker
|
from .rerankers.base import Reranker
|
||||||
from .rerankers.linear_combination import LinearCombinationReranker
|
from .rerankers.linear_combination import LinearCombinationReranker
|
||||||
from .util import safe_import_pandas
|
from .util import safe_import_pandas
|
||||||
@@ -75,7 +75,7 @@ class Query(pydantic.BaseModel):
|
|||||||
tuning advice.
|
tuning advice.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vector_column: str = VECTOR_COLUMN_NAME
|
vector_column: Optional[str] = None
|
||||||
|
|
||||||
# vector to search for
|
# vector to search for
|
||||||
vector: Union[List[float], List[List[float]]]
|
vector: Union[List[float], List[List[float]]]
|
||||||
@@ -403,7 +403,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
self,
|
self,
|
||||||
table: "Table",
|
table: "Table",
|
||||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||||
vector_column: str = VECTOR_COLUMN_NAME,
|
vector_column: str,
|
||||||
):
|
):
|
||||||
super().__init__(table)
|
super().__init__(table)
|
||||||
self._query = query
|
self._query = query
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from lancedb.merge import LanceMergeInsertBuilder
|
|||||||
|
|
||||||
from ..query import LanceVectorQueryBuilder
|
from ..query import LanceVectorQueryBuilder
|
||||||
from ..table import Query, Table, _sanitize_data
|
from ..table import Query, Table, _sanitize_data
|
||||||
from ..util import value_to_sql
|
from ..util import inf_vector_column_query, value_to_sql
|
||||||
from .arrow import to_ipc_binary
|
from .arrow import to_ipc_binary
|
||||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
from .client import ARROW_STREAM_CONTENT_TYPE
|
||||||
from .db import RemoteDBConnection
|
from .db import RemoteDBConnection
|
||||||
@@ -198,7 +198,9 @@ class RemoteTable(Table):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
|
self,
|
||||||
|
query: Union[VEC, str],
|
||||||
|
vector_column_name: Optional[str] = None,
|
||||||
) -> LanceVectorQueryBuilder:
|
) -> LanceVectorQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
of the given query vector. We currently support [vector search][search]
|
of the given query vector. We currently support [vector search][search]
|
||||||
@@ -217,7 +219,7 @@ class RemoteTable(Table):
|
|||||||
... ]
|
... ]
|
||||||
>>> table = db.create_table("my_table", data) # doctest: +SKIP
|
>>> table = db.create_table("my_table", data) # doctest: +SKIP
|
||||||
>>> query = [0.4, 1.4, 2.4]
|
>>> query = [0.4, 1.4, 2.4]
|
||||||
>>> (table.search(query, vector_column_name="vector") # doctest: +SKIP
|
>>> (table.search(query) # doctest: +SKIP
|
||||||
... .where("original_width > 1000", prefilter=True) # doctest: +SKIP
|
... .where("original_width > 1000", prefilter=True) # doctest: +SKIP
|
||||||
... .select(["caption", "original_width"]) # doctest: +SKIP
|
... .select(["caption", "original_width"]) # doctest: +SKIP
|
||||||
... .limit(2) # doctest: +SKIP
|
... .limit(2) # doctest: +SKIP
|
||||||
@@ -236,9 +238,14 @@ class RemoteTable(Table):
|
|||||||
|
|
||||||
- If None then the select/where/limit clauses are applied to filter
|
- If None then the select/where/limit clauses are applied to filter
|
||||||
the table
|
the table
|
||||||
vector_column_name: str
|
vector_column_name: str, optional
|
||||||
The name of the vector column to search.
|
The name of the vector column to search.
|
||||||
*default "vector"*
|
|
||||||
|
- If not specified then the vector column is inferred from
|
||||||
|
the table schema
|
||||||
|
|
||||||
|
- If the table has multiple vector columns then the *vector_column_name*
|
||||||
|
needs to be specified. Otherwise, an error is raised.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -253,6 +260,8 @@ class RemoteTable(Table):
|
|||||||
- and also the "_distance" column which is the distance between the query
|
- and also the "_distance" column which is the distance between the query
|
||||||
vector and the returned vector.
|
vector and the returned vector.
|
||||||
"""
|
"""
|
||||||
|
if vector_column_name is None:
|
||||||
|
vector_column_name = inf_vector_column_query(self.schema)
|
||||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from .pydantic import LanceModel, model_to_dict
|
|||||||
from .query import LanceQueryBuilder, Query
|
from .query import LanceQueryBuilder, Query
|
||||||
from .util import (
|
from .util import (
|
||||||
fs_from_uri,
|
fs_from_uri,
|
||||||
|
inf_vector_column_query,
|
||||||
join_uri,
|
join_uri,
|
||||||
safe_import_pandas,
|
safe_import_pandas,
|
||||||
safe_import_polars,
|
safe_import_polars,
|
||||||
@@ -413,7 +414,7 @@ class Table(ABC):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
@@ -433,7 +434,7 @@ class Table(ABC):
|
|||||||
... ]
|
... ]
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
>>> query = [0.4, 1.4, 2.4]
|
>>> query = [0.4, 1.4, 2.4]
|
||||||
>>> (table.search(query, vector_column_name="vector")
|
>>> (table.search(query)
|
||||||
... .where("original_width > 1000", prefilter=True)
|
... .where("original_width > 1000", prefilter=True)
|
||||||
... .select(["caption", "original_width"])
|
... .select(["caption", "original_width"])
|
||||||
... .limit(2)
|
... .limit(2)
|
||||||
@@ -452,11 +453,16 @@ class Table(ABC):
|
|||||||
|
|
||||||
- If None then the select/where/limit clauses are applied to filter
|
- If None then the select/where/limit clauses are applied to filter
|
||||||
the table
|
the table
|
||||||
vector_column_name: str
|
vector_column_name: str, optional
|
||||||
The name of the vector column to search.
|
The name of the vector column to search.
|
||||||
|
|
||||||
The vector column needs to be a pyarrow fixed size list type
|
The vector column needs to be a pyarrow fixed size list type
|
||||||
*default "vector"*
|
|
||||||
|
- If not specified then the vector column is inferred from
|
||||||
|
the table schema
|
||||||
|
|
||||||
|
- If the table has multiple vector columns then the *vector_column_name*
|
||||||
|
needs to be specified. Otherwise, an error is raised.
|
||||||
query_type: str
|
query_type: str
|
||||||
*default "auto"*.
|
*default "auto"*.
|
||||||
Acceptable types are: "vector", "fts", "hybrid", or "auto"
|
Acceptable types are: "vector", "fts", "hybrid", or "auto"
|
||||||
@@ -1193,7 +1199,7 @@ class LanceTable(Table):
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: Optional[str] = None,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
@@ -1211,7 +1217,7 @@ class LanceTable(Table):
|
|||||||
... ]
|
... ]
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
>>> query = [0.4, 1.4, 2.4]
|
>>> query = [0.4, 1.4, 2.4]
|
||||||
>>> (table.search(query, vector_column_name="vector")
|
>>> (table.search(query)
|
||||||
... .where("original_width > 1000", prefilter=True)
|
... .where("original_width > 1000", prefilter=True)
|
||||||
... .select(["caption", "original_width"])
|
... .select(["caption", "original_width"])
|
||||||
... .limit(2)
|
... .limit(2)
|
||||||
@@ -1230,8 +1236,17 @@ class LanceTable(Table):
|
|||||||
|
|
||||||
- If None then the select/[where][sql]/limit clauses are applied
|
- If None then the select/[where][sql]/limit clauses are applied
|
||||||
to filter the table
|
to filter the table
|
||||||
vector_column_name: str, default "vector"
|
vector_column_name: str, optional
|
||||||
The name of the vector column to search.
|
The name of the vector column to search.
|
||||||
|
|
||||||
|
The vector column needs to be a pyarrow fixed size list type
|
||||||
|
*default None*
|
||||||
|
|
||||||
|
- If not specified then the vector column is inferred from
|
||||||
|
the table schema
|
||||||
|
|
||||||
|
- If the table has multiple vector columns then the *vector_column_name*
|
||||||
|
needs to be specified. Otherwise, an error is raised.
|
||||||
query_type: str, default "auto"
|
query_type: str, default "auto"
|
||||||
"vector", "fts", or "auto"
|
"vector", "fts", or "auto"
|
||||||
If "auto" then the query type is inferred from the query;
|
If "auto" then the query type is inferred from the query;
|
||||||
@@ -1249,6 +1264,8 @@ class LanceTable(Table):
|
|||||||
and also the "_distance" column which is the distance between the query
|
and also the "_distance" column which is the distance between the query
|
||||||
vector and the returned vector.
|
vector and the returned vector.
|
||||||
"""
|
"""
|
||||||
|
if vector_column_name is None and query is not None:
|
||||||
|
vector_column_name = inf_vector_column_query(self.schema)
|
||||||
register_event("search_table")
|
register_event("search_table")
|
||||||
return LanceQueryBuilder.create(
|
return LanceQueryBuilder.create(
|
||||||
self, query, query_type, vector_column_name=vector_column_name
|
self, query, query_type, vector_column_name=vector_column_name
|
||||||
@@ -1435,6 +1452,7 @@ class LanceTable(Table):
|
|||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
ds = self.to_lance()
|
ds = self.to_lance()
|
||||||
|
|
||||||
return ds.to_table(
|
return ds.to_table(
|
||||||
columns=query.columns,
|
columns=query.columns,
|
||||||
filter=query.filter,
|
filter=query.filter,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import Tuple, Union
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
import pyarrow.fs as pa_fs
|
import pyarrow.fs as pa_fs
|
||||||
|
|
||||||
|
|
||||||
@@ -152,6 +153,44 @@ def safe_import_polars():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def inf_vector_column_query(schema: pa.Schema) -> str:
    """
    Infer the name of the vector column from a table schema.

    A field is treated as a vector column when its type is a fixed size
    list whose value type is floating point.

    Parameters
    ----------
    schema : pa.Schema
        The schema of the table to search.

    Returns
    -------
    str: the name of the only vector column in the schema.

    Raises
    ------
    ValueError
        If the schema has no vector column, or has more than one, in which
        case the caller must name the column explicitly.
    """
    # Iterate the fields themselves rather than looking each one up by name,
    # so the check does not depend on field names being unique.
    vector_col_names = [
        field.name
        for field in schema
        if pa.types.is_fixed_size_list(field.type)
        and pa.types.is_floating(field.type.value_type)
    ]
    if not vector_col_names:
        raise ValueError(
            "There is no vector column in the data. "
            "Please specify the vector column name for vector search"
        )
    if len(vector_col_names) > 1:
        raise ValueError(
            "Schema has more than one vector column. "
            "Please specify the vector column name "
            "for vector search"
        )
    return vector_col_names[0]
|
||||||
|
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def value_to_sql(value):
|
def value_to_sql(value):
|
||||||
raise NotImplementedError("SQL conversion is not implemented for this type")
|
raise NotImplementedError("SQL conversion is not implemented for this type")
|
||||||
|
|||||||
@@ -69,10 +69,14 @@ def test_basic_text_embeddings(alias, tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
query = "greetings"
|
query = "greetings"
|
||||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
actual = (
|
||||||
|
table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
|
||||||
|
)
|
||||||
|
|
||||||
vec = func.compute_query_embeddings(query)[0]
|
vec = func.compute_query_embeddings(query)[0]
|
||||||
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
expected = (
|
||||||
|
table.search(vec, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
|
||||||
|
)
|
||||||
assert actual.text == expected.text
|
assert actual.text == expected.text
|
||||||
assert actual.text == "hello world"
|
assert actual.text == "hello world"
|
||||||
assert not np.allclose(actual.vector, actual.vector2)
|
assert not np.allclose(actual.vector, actual.vector2)
|
||||||
@@ -116,7 +120,11 @@ def test_openclip(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# text search
|
# text search
|
||||||
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
|
actual = (
|
||||||
|
table.search("man's best friend", vector_column_name="vector")
|
||||||
|
.limit(1)
|
||||||
|
.to_pydantic(Images)[0]
|
||||||
|
)
|
||||||
assert actual.label == "dog"
|
assert actual.label == "dog"
|
||||||
frombytes = (
|
frombytes = (
|
||||||
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||||
@@ -130,7 +138,11 @@ def test_openclip(tmp_path):
|
|||||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||||
image_bytes = requests.get(query_image_uri).content
|
image_bytes = requests.get(query_image_uri).content
|
||||||
query_image = Image.open(io.BytesIO(image_bytes))
|
query_image = Image.open(io.BytesIO(image_bytes))
|
||||||
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
|
actual = (
|
||||||
|
table.search(query_image, vector_column_name="vector")
|
||||||
|
.limit(1)
|
||||||
|
.to_pydantic(Images)[0]
|
||||||
|
)
|
||||||
assert actual.label == "dog"
|
assert actual.label == "dog"
|
||||||
other = (
|
other = (
|
||||||
table.search(query_image, vector_column_name="vec_from_bytes")
|
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||||
|
|||||||
@@ -38,4 +38,5 @@ def test_remote_db():
|
|||||||
setattr(conn, "_client", FakeLanceDBClient())
|
setattr(conn, "_client", FakeLanceDBClient())
|
||||||
|
|
||||||
table = conn["test"]
|
table = conn["test"]
|
||||||
|
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||||
table.search([1.0, 2.0]).to_pandas()
|
table.search([1.0, 2.0]).to_pandas()
|
||||||
|
|||||||
@@ -710,6 +710,59 @@ def test_empty_query(db):
|
|||||||
assert len(df) == 100
|
assert len(df) == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_with_schema_inf_single_vector(db):
    """Searching without a column name matches searching the only vector column."""

    class MyTable(LanceModel):
        text: str
        vector_col: Vector(10)

    table = LanceTable.create(db, "my_table", schema=MyTable)

    # Two rows, each with a random 10-dimensional embedding.
    rows = pd.DataFrame(
        [
            {"vector_col": np.random.randn(10), "text": label}
            for label in ("foo", "bar")
        ]
    )
    table.add(rows)

    query_vec = np.random.randn(10)
    explicit_hit = (
        table.search(query_vec, vector_column_name="vector_col").limit(1).to_pandas()
    )
    inferred_hit = table.search(query_vec).limit(1).to_pandas()

    # The same query against the same single column must return the same row.
    assert explicit_hit["text"].iloc[0] == inferred_hit["text"].iloc[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_with_schema_inf_multiple_vector(db):
    """Searching without a column name raises when two vector columns exist."""

    class MyTable(LanceModel):
        text: str
        vector1: Vector(10)
        vector2: Vector(10)

    table = LanceTable.create(db, "my_table", schema=MyTable)

    first = np.random.randn(10)
    second = np.random.randn(10)
    table.add(
        pd.DataFrame(
            [
                {"vector1": first, "vector2": second, "text": "foo"},
                {"vector1": second, "vector2": first, "text": "bar"},
            ]
        )
    )

    query_vec = np.random.randn(10)
    # No vector_column_name is given, so the column cannot be chosen.
    with pytest.raises(ValueError):
        table.search(query_vec).limit(1).to_pandas()
|
||||||
|
|
||||||
|
|
||||||
def test_compact_cleanup(db):
|
def test_compact_cleanup(db):
|
||||||
table = LanceTable.create(
|
table = LanceTable.create(
|
||||||
db,
|
db,
|
||||||
|
|||||||
Reference in New Issue
Block a user