Compare commits

...

2 Commits

Author SHA1 Message Date
qzhu
de14120bbe debug 2024-03-14 16:23:41 -07:00
qzhu
fa342e7df4 init debug 2024-03-14 15:56:50 -07:00
3 changed files with 6 additions and 4 deletions

View File

@@ -271,7 +271,8 @@ class LanceQueryBuilder(ABC):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vectors. vector and the returned vectors.
""" """
raise NotImplementedError # raise NotImplementedError
self.to_arrow()
def to_list(self) -> List[dict]: def to_list(self) -> List[dict]:
""" """
@@ -434,12 +435,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._vector_column = vector_column self._vector_column = vector_column
self._prefilter = False self._prefilter = False
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use. """Set the distance metric to use.
Parameters Parameters
---------- ----------
metric: "L2" or "cosine" metric: "L2" or "cosine" or "dot"
The distance metric to use. By default "L2" is used. The distance metric to use. By default "L2" is used.
Returns Returns

View File

@@ -296,6 +296,7 @@ class RemoteTable(Table):
return LanceVectorQueryBuilder(self, query, vector_column_name) return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
print("query metric", query.metric)
if ( if (
query.vector is not None query.vector is not None
and len(query.vector) > 0 and len(query.vector) > 0

View File

@@ -1522,7 +1522,7 @@ class LanceTable(Table):
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance() ds = self.to_lance()
print("metric:", query.metric)
return ds.to_table( return ds.to_table(
columns=query.columns, columns=query.columns,
filter=query.filter, filter=query.filter,