Qian/make vector col optional (#950)

remote SDK tests were completed through lancedb_integtest
This commit is contained in:
QianZhu
2024-02-12 16:35:44 -08:00
committed by Weston Pace
parent 88205aba64
commit 7afcfca10d
8 changed files with 154 additions and 22 deletions

View File

@@ -102,9 +102,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
source_instruction: str = "represent the document for retrieval"
query_instruction: str = (
"represent the document for retrieving the most similar documents"
)
query_instruction: (
str
) = "represent the document for retrieving the most similar documents"
@weak_lru(maxsize=1)
def ndims(self):

View File

@@ -24,7 +24,7 @@ import pyarrow as pa
import pydantic
from . import __version__
from .common import VEC, VECTOR_COLUMN_NAME
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import safe_import_pandas
@@ -75,7 +75,7 @@ class Query(pydantic.BaseModel):
tuning advice.
"""
vector_column: str = VECTOR_COLUMN_NAME
vector_column: Optional[str] = None
# vector to search for
vector: Union[List[float], List[List[float]]]
@@ -403,7 +403,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self,
table: "Table",
query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str = VECTOR_COLUMN_NAME,
vector_column: str,
):
super().__init__(table)
self._query = query

View File

@@ -24,7 +24,7 @@ from lancedb.merge import LanceMergeInsertBuilder
from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data
from ..util import value_to_sql
from ..util import inf_vector_column_query, value_to_sql
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection
@@ -198,7 +198,9 @@ class RemoteTable(Table):
)
def search(
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
self,
query: Union[VEC, str],
vector_column_name: Optional[str] = None,
) -> LanceVectorQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
@@ -217,7 +219,7 @@ class RemoteTable(Table):
... ]
>>> table = db.create_table("my_table", data) # doctest: +SKIP
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query, vector_column_name="vector") # doctest: +SKIP
>>> (table.search(query) # doctest: +SKIP
... .where("original_width > 1000", prefilter=True) # doctest: +SKIP
... .select(["caption", "original_width"]) # doctest: +SKIP
... .limit(2) # doctest: +SKIP
@@ -236,9 +238,14 @@ class RemoteTable(Table):
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str
vector_column_name: str, optional
The name of the vector column to search.
*default "vector"*
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
Returns
-------
@@ -253,6 +260,8 @@ class RemoteTable(Table):
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None:
vector_column_name = inf_vector_column_query(self.schema)
return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table:

View File

@@ -36,6 +36,7 @@ from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .util import (
fs_from_uri,
inf_vector_column_query,
join_uri,
safe_import_pandas,
safe_import_polars,
@@ -412,7 +413,7 @@ class Table(ABC):
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
@@ -432,7 +433,7 @@ class Table(ABC):
... ]
>>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query, vector_column_name="vector")
>>> (table.search(query)
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"])
... .limit(2)
@@ -451,11 +452,16 @@ class Table(ABC):
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str
vector_column_name: str, optional
The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
*default "vector"*
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
query_type: str
*default "auto"*.
Acceptable types are: "vector", "fts", "hybrid", or "auto"
@@ -1188,7 +1194,7 @@ class LanceTable(Table):
def search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
@@ -1206,7 +1212,7 @@ class LanceTable(Table):
... ]
>>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query, vector_column_name="vector")
>>> (table.search(query)
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"])
... .limit(2)
@@ -1225,8 +1231,17 @@ class LanceTable(Table):
- If None then the select/[where][sql]/limit clauses are applied
to filter the table
vector_column_name: str, default "vector"
vector_column_name: str, optional
The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
*default "vector"*
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
@@ -1244,6 +1259,8 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
if vector_column_name is None and query is not None:
vector_column_name = inf_vector_column_query(self.schema)
return LanceQueryBuilder.create(
self, query, query_type, vector_column_name=vector_column_name
)
@@ -1427,6 +1444,7 @@ class LanceTable(Table):
def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance()
return ds.to_table(
columns=query.columns,
filter=query.filter,

View File

@@ -20,6 +20,7 @@ from typing import Tuple, Union
from urllib.parse import urlparse
import numpy as np
import pyarrow as pa
import pyarrow.fs as pa_fs
@@ -152,6 +153,44 @@ def safe_import_polars():
return None
def inf_vector_column_query(schema: pa.Schema) -> str:
"""
Get the vector column name
Parameters
----------
schema : pa.Schema
The schema of the vector column.
Returns
-------
str: the vector column name.
"""
vector_col_name = ""
vector_col_count = 0
for field_name in schema.names:
field = schema.field(field_name)
if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
field.type.value_type
):
vector_col_count += 1
if vector_col_count > 1:
raise ValueError(
"Schema has more than one vector column. "
"Please specify the vector column name "
"for vector search"
)
break
elif vector_col_count == 1:
vector_col_name = field_name
if vector_col_count == 0:
raise ValueError(
"There is no vector column in the data. "
"Please specify the vector column name for vector search"
)
return vector_col_name
@singledispatch
def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type")