Compare commits

...

2 Commits

Author SHA1 Message Date
qzhu
de14120bbe debug 2024-03-14 16:23:41 -07:00
qzhu
fa342e7df4 init debug 2024-03-14 15:56:50 -07:00
3 changed files with 6 additions and 4 deletions

View File

@@ -271,7 +271,8 @@ class LanceQueryBuilder(ABC):
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
"""
raise NotImplementedError
# raise NotImplementedError
self.to_arrow()
def to_list(self) -> List[dict]:
"""
@@ -434,12 +435,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._vector_column = vector_column
self._prefilter = False
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
Parameters
----------
metric: "L2" or "cosine"
metric: "L2" or "cosine" or "dot"
The distance metric to use. By default "L2" is used.
Returns

View File

@@ -296,6 +296,7 @@ class RemoteTable(Table):
return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table:
print("query metric", query.metric)
if (
query.vector is not None
and len(query.vector) > 0

View File

@@ -1522,7 +1522,7 @@ class LanceTable(Table):
def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance()
print("metric:", query.metric)
return ds.to_table(
columns=query.columns,
filter=query.filter,