feat(python): support embedding functions in remote table (#1405)

This commit is contained in:
Ayush Chaurasia
2024-08-07 20:22:43 +05:30
committed by GitHub
parent a62f661d90
commit e01045692c

View File

@@ -22,8 +22,9 @@ from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry
from ..query import LanceVectorQueryBuilder from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
from ..util import inf_vector_column_query, value_to_sql from ..util import inf_vector_column_query, value_to_sql
from .arrow import to_ipc_binary from .arrow import to_ipc_binary
@@ -58,6 +59,21 @@ class RemoteTable(Table):
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/") resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
return resp["version"] return resp["version"]
@cached_property
def embedding_functions(self) -> dict:
    """
    Embedding functions configured for this table.

    The table's schema metadata is handed to the process-wide
    ``EmbeddingFunctionRegistry`` for parsing. Because this is a
    ``cached_property``, the lookup runs once per instance and the
    result is reused on later accesses.

    Returns
    -------
    funcs: dict
        A mapping of the vector column to the embedding function,
        or an empty dict if none is configured.
    """
    # Resolve the registry first, then read the schema metadata it parses.
    registry = EmbeddingFunctionRegistry.get_instance()
    schema_metadata = self.schema.metadata
    return registry.parse_functions(schema_metadata)
def to_arrow(self) -> pa.Table:
    """Always raise: to_arrow() is not yet supported on LanceDB cloud.

    Raises
    ------
    NotImplementedError
        Unconditionally; the body contains no other code path.
    """
    unsupported_msg = "to_arrow() is not yet supported on LanceDB cloud."
    raise NotImplementedError(unsupported_msg)
@@ -213,7 +229,7 @@ class RemoteTable(Table):
data, _ = _sanitize_data( data, _ = _sanitize_data(
data, data,
self.schema, self.schema,
metadata=None, metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
) )
@@ -293,6 +309,7 @@ class RemoteTable(Table):
""" """
if vector_column_name is None: if vector_column_name is None:
vector_column_name = inf_vector_column_query(self.schema) vector_column_name = inf_vector_column_query(self.schema)
query = LanceQueryBuilder._query_to_vector(self, query, vector_column_name)
return LanceVectorQueryBuilder(self, query, vector_column_name) return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query( def _execute_query(