mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-14 07:42:58 +00:00
multi-modal embedding-function (#484)
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
|
||||
import pyarrow as pa
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lancedb.embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
||||
from .embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction
|
||||
|
||||
# import lancedb so we don't have to in every example
|
||||
|
||||
@@ -22,17 +22,20 @@ def doctest_setup(monkeypatch, tmpdir):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
|
||||
|
||||
@registry.register()
|
||||
class MockEmbeddingFunction(EmbeddingFunctionModel):
|
||||
def __call__(self, data):
|
||||
if isinstance(data, str):
|
||||
data = [data]
|
||||
elif isinstance(data, pa.ChunkedArray):
|
||||
data = data.combine_chunks().to_pylist()
|
||||
elif isinstance(data, pa.Array):
|
||||
data = data.to_pylist()
|
||||
@registry.register("test")
|
||||
class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
Return the hash of the first 10 characters
|
||||
"""
|
||||
|
||||
return [self.embed(row) for row in data]
|
||||
def generate_embeddings(self, texts):
|
||||
return [self._compute_one_embedding(row) for row in texts]
|
||||
|
||||
def embed(self, row):
|
||||
return [float(hash(c)) for c in row[:10]]
|
||||
def _compute_one_embedding(self, row):
|
||||
emb = np.array([float(hash(c)) for c in row[:10]])
|
||||
emb /= np.linalg.norm(emb)
|
||||
return emb
|
||||
|
||||
@property
|
||||
def ndims(self):
|
||||
return 10
|
||||
|
||||
@@ -22,7 +22,7 @@ import pyarrow as pa
|
||||
from pyarrow import fs
|
||||
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionModel
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .table import LanceTable, Table
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
||||
@@ -290,7 +290,7 @@ class LanceDBConnection(DBConnection):
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionModel]] = None,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
|
||||
|
||||
@@ -13,10 +13,12 @@
|
||||
|
||||
|
||||
from .functions import (
|
||||
REGISTRY,
|
||||
EmbeddingFunctionModel,
|
||||
EmbeddingFunction,
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
TextEmbeddingFunctionModel,
|
||||
OpenAIEmbeddings,
|
||||
OpenClipEmbeddings,
|
||||
SentenceTransformerEmbeddings,
|
||||
TextEmbeddingFunction,
|
||||
)
|
||||
from .utils import with_embeddings
|
||||
|
||||
@@ -10,14 +10,23 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import concurrent.futures
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import urllib.error
|
||||
import urllib.parse as urlparse
|
||||
import urllib.request
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
from functools import cached_property
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from cachetools import cached
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EmbeddingFunctionRegistry:
|
||||
@@ -28,25 +37,33 @@ class EmbeddingFunctionRegistry:
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
return REGISTRY
|
||||
return __REGISTRY__
|
||||
|
||||
def __init__(self):
|
||||
self._functions = {}
|
||||
|
||||
def register(self):
|
||||
def register(self, alias: str = None):
|
||||
"""
|
||||
This creates a decorator that can be used to register
|
||||
an EmbeddingFunctionModel.
|
||||
an EmbeddingFunction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
alias : Optional[str]
|
||||
a human friendly name for the embedding function. If not
|
||||
provided, the class name will be used.
|
||||
"""
|
||||
|
||||
# This is a decorator for a class that inherits from BaseModel
|
||||
# It adds the class to the registry
|
||||
def decorator(cls):
|
||||
if not issubclass(cls, EmbeddingFunctionModel):
|
||||
raise TypeError("Must be a subclass of EmbeddingFunctionModel")
|
||||
if not issubclass(cls, EmbeddingFunction):
|
||||
raise TypeError("Must be a subclass of EmbeddingFunction")
|
||||
if cls.__name__ in self._functions:
|
||||
raise KeyError(f"{cls.__name__} was already registered")
|
||||
self._functions[cls.__name__] = cls
|
||||
key = alias or cls.__name__
|
||||
self._functions[key] = cls
|
||||
cls.__embedding_function_registry_alias__ = alias
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
@@ -57,13 +74,22 @@ class EmbeddingFunctionRegistry:
|
||||
"""
|
||||
self._functions = {}
|
||||
|
||||
def load(self, name: str):
|
||||
def get(self, name: str):
|
||||
"""
|
||||
Fetch an embedding function class by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the embedding function to fetch
|
||||
Either the alias or the class name if no alias was provided
|
||||
during registration
|
||||
"""
|
||||
return self._functions[name]
|
||||
|
||||
def parse_functions(self, metadata: Optional[dict]) -> dict:
|
||||
def parse_functions(
|
||||
self, metadata: Optional[Dict[bytes, bytes]]
|
||||
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
||||
"""
|
||||
Parse the metadata from an arrow table and
|
||||
return a mapping of the vector column to the
|
||||
@@ -71,9 +97,9 @@ class EmbeddingFunctionRegistry:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metadata : Optional[dict]
|
||||
metadata : Optional[Dict[bytes, bytes]]
|
||||
The metadata from an arrow table. Note that
|
||||
the keys and values are bytes.
|
||||
the keys and values are bytes (pyarrow api)
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -86,68 +112,91 @@ class EmbeddingFunctionRegistry:
|
||||
return {}
|
||||
serialized = metadata[b"embedding_functions"]
|
||||
raw_list = json.loads(serialized.decode("utf-8"))
|
||||
functions = {}
|
||||
for obj in raw_list:
|
||||
model = self.load(obj["schema"]["title"])
|
||||
functions[obj["model"]["vector_column"]] = model(**obj["model"])
|
||||
return functions
|
||||
return {
|
||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||
vector_column=obj["vector_column"],
|
||||
source_column=obj["source_column"],
|
||||
function=self.get(obj["name"])(**obj["model"]),
|
||||
)
|
||||
for obj in raw_list
|
||||
}
|
||||
|
||||
def function_to_metadata(self, func):
|
||||
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
|
||||
"""
|
||||
Convert the given embedding function and source / vector column configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
schema = func.model_json_schema()
|
||||
func = conf.function
|
||||
name = getattr(
|
||||
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
||||
)
|
||||
json_data = func.model_dump()
|
||||
return {
|
||||
"schema": schema,
|
||||
"name": name,
|
||||
"model": json_data,
|
||||
"source_column": conf.source_column,
|
||||
"vector_column": conf.vector_column,
|
||||
}
|
||||
|
||||
def get_table_metadata(self, func_list):
|
||||
"""
|
||||
Convert a list of embedding functions and source / vector column configs
|
||||
Convert a list of embedding functions and source / vector configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
if func_list is None or len(func_list) == 0:
|
||||
return None
|
||||
json_data = [self.function_to_metadata(func) for func in func_list]
|
||||
# Note that metadata dictionary values must be bytes so we need to json dump then utf8 encode
|
||||
# Note that metadata dictionary values must be bytes
|
||||
# so we need to json dump then utf8 encode
|
||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||
return {"embedding_functions": metadata}
|
||||
|
||||
|
||||
REGISTRY = EmbeddingFunctionRegistry()
|
||||
|
||||
|
||||
class EmbeddingFunctionModel(BaseModel, ABC):
|
||||
"""
|
||||
A callable ABC for embedding functions
|
||||
"""
|
||||
|
||||
source_column: Optional[str]
|
||||
vector_column: str
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> List[np.array]:
|
||||
pass
|
||||
# Global instance
|
||||
__REGISTRY__ = EmbeddingFunctionRegistry()
|
||||
|
||||
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
IMAGES = Union[
|
||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||
]
|
||||
|
||||
|
||||
class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
||||
class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
A callable ABC for embedding functions that take text as input
|
||||
An ABC for embedding functions.
|
||||
|
||||
The API has two methods:
|
||||
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||
2. get_source_embeddings() which returns a list of embeddings for the source column
|
||||
For text data, the two will be the same. For multi-modal data, the source column
|
||||
might be images and the vector column might be text.
|
||||
"""
|
||||
|
||||
def __call__(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
"""
|
||||
Create an instance of the embedding function
|
||||
"""
|
||||
return cls(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database
|
||||
"""
|
||||
pass
|
||||
|
||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function. This is called
|
||||
before generate_embeddings() and is useful for stripping
|
||||
whitespace, lowercasing, etc.
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
@@ -157,6 +206,71 @@ class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
||||
texts = texts.combine_chunks().to_pylist()
|
||||
return texts
|
||||
|
||||
@classmethod
|
||||
def safe_import(cls, module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : str
|
||||
The name of the module to import
|
||||
mitigation : Optional[str]
|
||||
The package(s) to install to mitigate the error.
|
||||
If not provided then the module name will be used.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(module)
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ndims(self):
|
||||
"""
|
||||
Return the dimensions of the vector column
|
||||
"""
|
||||
pass
|
||||
|
||||
def SourceField(self, **kwargs):
|
||||
"""
|
||||
Return a pydantic Field that can automatically indicate
|
||||
the source column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||
|
||||
def VectorField(self, **kwargs):
|
||||
"""
|
||||
Return a pydantic Field that can automatically indicate
|
||||
the target vector column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
This is a dataclass that holds the embedding function
|
||||
and source column for a vector column
|
||||
"""
|
||||
|
||||
vector_column: str
|
||||
source_column: str
|
||||
function: EmbeddingFunction
|
||||
|
||||
|
||||
class TextEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
A callable ABC for embedding functions that take text as input
|
||||
"""
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
@@ -167,15 +281,20 @@ class TextEmbeddingFunctionModel(EmbeddingFunctionModel):
|
||||
pass
|
||||
|
||||
|
||||
@REGISTRY.register()
|
||||
class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
||||
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
||||
|
||||
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the sentence-transformers library
|
||||
|
||||
https://huggingface.co/sentence-transformers
|
||||
"""
|
||||
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = False
|
||||
normalize: bool = True
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
@@ -186,6 +305,10 @@ class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
|
||||
@cached_property
|
||||
def ndims(self):
|
||||
return len(self.generate_embeddings(["foo"])[0])
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
@@ -220,9 +343,197 @@ class SentenceTransformerEmbeddingFunction(TextEmbeddingFunctionModel):
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
sentence_transformers = cls.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||
|
||||
return SentenceTransformer(name, device=device)
|
||||
except ImportError:
|
||||
raise ValueError("Please install sentence_transformers")
|
||||
|
||||
@register("openai")
|
||||
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenAI API
|
||||
|
||||
https://platform.openai.com/docs/guides/embeddings
|
||||
"""
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
|
||||
@property
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return 1536
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
openai = self.safe_import("openai")
|
||||
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||
return [v["embedding"] for v in rs]
|
||||
|
||||
|
||||
@register("open-clip")
|
||||
class OpenClipEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenClip API
|
||||
For multi-modal text-to-image search
|
||||
|
||||
https://github.com/mlfoundations/open_clip
|
||||
"""
|
||||
|
||||
name: str = "ViT-B-32"
|
||||
pretrained: str = "laion2b_s34b_b79k"
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
model.to(self.device)
|
||||
self._model, self._preprocess = model, preprocess
|
||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||
|
||||
@cached_property
|
||||
def ndims(self):
|
||||
return self.generate_text_embeddings("foo").shape[0]
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : Union[str, PIL.Image.Image]
|
||||
The query to embed. A query can be either text or an image.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
with torch.no_grad():
|
||||
text_features = self._model.encode_text(text.to(self.device))
|
||||
if self.normalize:
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
return text_features.cpu().numpy().squeeze()
|
||||
|
||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(images, (str, bytes)):
|
||||
images = [images]
|
||||
elif isinstance(images, pa.Array):
|
||||
images = images.to_pylist()
|
||||
elif isinstance(images, pa.ChunkedArray):
|
||||
images = images.combine_chunks().to_pylist()
|
||||
return images
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, images: IMAGES, *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given images
|
||||
"""
|
||||
images = self.sanitize_input(images)
|
||||
embeddings = []
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
j = min(i + self.batch_size, len(images))
|
||||
batch = images[i:j]
|
||||
embeddings.extend(self._parallel_get(batch))
|
||||
return embeddings
|
||||
|
||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||
"""
|
||||
Issue concurrent requests to retrieve the image data
|
||||
"""
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
for image in images
|
||||
]
|
||||
return [future.result() for future in futures]
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate the embedding for a single image
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : Union[str, bytes, PIL.Image.Image]
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
|
||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||
"""
|
||||
encode a single image tensor and optionally normalize the output
|
||||
"""
|
||||
image_features = self._model.encode_image(image_tensor)
|
||||
if self.normalize:
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features.cpu().numpy().squeeze()
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
URL to download from
|
||||
"""
|
||||
try:
|
||||
with urllib.request.urlopen(url) as conn:
|
||||
return conn.read()
|
||||
except (socket.gaierror, urllib.error.URLError) as err:
|
||||
raise ConnectionError("could not download {} due to {}".format(url, err))
|
||||
|
||||
@@ -26,6 +26,8 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import semver
|
||||
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
@@ -290,13 +292,49 @@ class LanceModel(pydantic.BaseModel):
|
||||
"""
|
||||
Get the Arrow Schema for this model.
|
||||
"""
|
||||
return pydantic_to_schema(cls)
|
||||
schema = pydantic_to_schema(cls)
|
||||
functions = cls.parse_embedding_functions()
|
||||
if len(functions) > 0:
|
||||
metadata = EmbeddingFunctionRegistry.get_instance().get_table_metadata(
|
||||
functions
|
||||
)
|
||||
schema = schema.with_metadata(metadata)
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def field_names(cls) -> List[str]:
|
||||
"""
|
||||
Get the field names of this model.
|
||||
"""
|
||||
return list(cls.safe_get_fields().keys())
|
||||
|
||||
@classmethod
|
||||
def safe_get_fields(cls):
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return list(cls.__fields__.keys())
|
||||
return list(cls.model_fields.keys())
|
||||
return cls.__fields__
|
||||
return cls.model_fields
|
||||
|
||||
@classmethod
|
||||
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
|
||||
"""
|
||||
Parse the embedding functions from this model.
|
||||
"""
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
|
||||
vec_and_function = []
|
||||
for name, field_info in cls.safe_get_fields().items():
|
||||
func = (field_info.json_schema_extra or {}).get("vector_column_for")
|
||||
if func is not None:
|
||||
vec_and_function.append([name, func])
|
||||
|
||||
configs = []
|
||||
for vec, func in vec_and_function:
|
||||
for source, field_info in cls.safe_get_fields().items():
|
||||
src_func = (field_info.json_schema_extra or {}).get("source_column_for")
|
||||
if src_func == func:
|
||||
configs.append(
|
||||
EmbeddingFunctionConfig(
|
||||
source_column=source, vector_column=vec, function=func
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
@@ -60,13 +60,15 @@ class LanceQueryBuilder(ABC):
|
||||
def create(
|
||||
cls,
|
||||
table: "lancedb.table.Table",
|
||||
query: Optional[Union[np.ndarray, str]],
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
) -> LanceQueryBuilder:
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# convert "auto" query_type to "vector" or "fts"
|
||||
# and convert the query to vector if needed
|
||||
query, query_type = cls._resolve_query(
|
||||
table, query, query_type, vector_column_name
|
||||
)
|
||||
@@ -90,30 +92,27 @@ class LanceQueryBuilder(ABC):
|
||||
# otherwise raise TypeError
|
||||
if query_type == "fts":
|
||||
if not isinstance(query, str):
|
||||
raise TypeError(
|
||||
f"Query type is 'fts' but query is not a string: {type(query)}"
|
||||
)
|
||||
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
||||
return query, query_type
|
||||
elif query_type == "vector":
|
||||
# If query_type is vector, then query must be a list or np.ndarray.
|
||||
# otherwise raise TypeError
|
||||
if not isinstance(query, (list, np.ndarray)):
|
||||
raise TypeError(
|
||||
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
|
||||
)
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
else:
|
||||
msg = f"No embedding function for {vector_column_name}"
|
||||
raise ValueError(msg)
|
||||
return query, query_type
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
elif isinstance(query, str):
|
||||
func = table.embedding_functions.get(vector_column_name, None)
|
||||
if func is not None:
|
||||
query = func(query)[0]
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
query = conf.function.compute_query_embeddings(query)[0]
|
||||
return query, "vector"
|
||||
else:
|
||||
return query, "fts"
|
||||
else:
|
||||
raise TypeError("Query must be a list, np.ndarray, or str")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
||||
@@ -238,7 +237,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
table: "lancedb.table.Table",
|
||||
query: Union[np.ndarray, list],
|
||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||
vector_column: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
super().__init__(table)
|
||||
|
||||
@@ -18,10 +18,9 @@ from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.common import DATA
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.table import Table, _sanitize_data
|
||||
|
||||
from ..common import DATA
|
||||
from ..db import DBConnection
|
||||
from ..table import Table, _sanitize_data
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
|
||||
|
||||
@@ -28,7 +28,8 @@ from lance.dataset import ReaderLike
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
from .embeddings.functions import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
@@ -81,15 +82,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
||||
vector column to the table.
|
||||
"""
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
for vector_col, func in functions.items():
|
||||
if vector_col not in data.column_names:
|
||||
col_data = func(data[func.source_column])
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
if vector_column not in data.column_names:
|
||||
col_data = func.compute_source_embeddings(data[conf.source_column])
|
||||
if schema is not None:
|
||||
dtype = schema.field(vector_col).type
|
||||
dtype = schema.field(vector_column).type
|
||||
else:
|
||||
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
||||
data = data.append_column(
|
||||
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
|
||||
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -230,7 +232,7 @@ class Table(ABC):
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str]] = None,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -239,7 +241,7 @@ class Table(ABC):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str, list, np.ndarray, default None
|
||||
query: str, list, np.ndarray, PIL.Image.Image, default None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
@@ -249,6 +251,8 @@ class Table(ABC):
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
|
||||
@@ -524,6 +528,9 @@ class LanceTable(Table):
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""Add data to the table.
|
||||
If vector columns are missing and the table
|
||||
has embedding functions, then the vector columns
|
||||
are automatically computed and added.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -617,12 +624,6 @@ class LanceTable(Table):
|
||||
)
|
||||
self._reset_dataset()
|
||||
|
||||
def _get_embedding_function_for_source_col(self, column_name: str):
|
||||
for k, v in self.embedding_functions.items():
|
||||
if v.source_column == column_name:
|
||||
return v
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
"""
|
||||
@@ -640,7 +641,7 @@ class LanceTable(Table):
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str]] = None,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -649,7 +650,7 @@ class LanceTable(Table):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str, list, np.ndarray, or None
|
||||
query: str, list, np.ndarray, a PIL Image or None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
@@ -658,9 +659,11 @@ class LanceTable(Table):
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If the query is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If the query is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
table has embedding functions, else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -684,7 +687,7 @@ class LanceTable(Table):
|
||||
mode="create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: List[EmbeddingFunctionModel] = None,
|
||||
embedding_functions: List[EmbeddingFunctionConfig] = None,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
@@ -727,10 +730,16 @@ class LanceTable(Table):
|
||||
"""
|
||||
tbl = LanceTable(db, name)
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
# note that it's possible this contains
|
||||
# embedding function metadata already
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
metadata = None
|
||||
if embedding_functions is not None:
|
||||
# If we passed in embedding functions explicitly
|
||||
# then we'll override any schema metadata that
|
||||
# may was implicitly specified by the LanceModel schema
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
|
||||
@@ -44,9 +44,11 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio"]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
||||
dev = ["ruff", "pre-commit", "black"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
@@ -54,3 +56,10 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers"
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"asyncio"
|
||||
]
|
||||
@@ -16,8 +16,12 @@ import lance
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry, with_embeddings
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.embeddings import (
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
with_embeddings,
|
||||
)
|
||||
|
||||
|
||||
def mock_embed_func(input_data):
|
||||
@@ -54,8 +58,12 @@ def test_embedding_function(tmp_path):
|
||||
"vector": [np.random.randn(10), np.random.randn(10)],
|
||||
}
|
||||
)
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
metadata = registry.get_table_metadata([func])
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
table = table.replace_schema_metadata(metadata)
|
||||
|
||||
# Write it to disk
|
||||
@@ -65,14 +73,13 @@ def test_embedding_function(tmp_path):
|
||||
ds = lance.dataset(tmp_path / "test.lance")
|
||||
|
||||
# can we get the serialized version back out?
|
||||
functions = registry.parse_functions(ds.schema.metadata)
|
||||
configs = registry.parse_functions(ds.schema.metadata)
|
||||
|
||||
func = functions["vector"]
|
||||
actual = func("hello world")
|
||||
conf = configs["vector"]
|
||||
func = conf.function
|
||||
actual = func.compute_query_embeddings("hello world")
|
||||
|
||||
# We create an instance
|
||||
expected_func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
# And we make sure we can call it
|
||||
expected = expected_func("hello world")
|
||||
expected = func.compute_query_embeddings("hello world")
|
||||
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
125
python/tests/test_embeddings_slow.py
Normal file
125
python/tests/test_embeddings_slow.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
import lancedb
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
# These are integration tests for embedding functions.
|
||||
# They are slow because they require downloading models
|
||||
# or connection to external api
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||
def test_sentence_transformer(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = registry.get(alias).create()
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims) = func.VectorField()
|
||||
|
||||
table = db.create_table("words", schema=Words)
|
||||
table.add(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"text": [
|
||||
"hello world",
|
||||
"goodbye world",
|
||||
"fizz",
|
||||
"buzz",
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
|
||||
vec = func.compute_query_embeddings(query)[0]
|
||||
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
||||
assert actual.text == expected.text
|
||||
assert actual.text == "hello world"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_openclip(tmp_path):
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = registry.get("open-clip").create()
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
vector: Vector(func.ndims) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims) = func.VectorField()
|
||||
|
||||
table = db.create_table("images", schema=Images)
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
# get each uri as bytes
|
||||
image_bytes = [requests.get(uri).content for uri in uris]
|
||||
table.add(
|
||||
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
|
||||
)
|
||||
|
||||
# text search
|
||||
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
|
||||
assert actual.label == "dog"
|
||||
frombytes = (
|
||||
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == frombytes.label
|
||||
assert np.allclose(actual.vector, frombytes.vector)
|
||||
|
||||
# image search
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
query_image = Image.open(io.BytesIO(image_bytes))
|
||||
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
|
||||
assert actual.label == "dog"
|
||||
other = (
|
||||
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == other.label
|
||||
|
||||
arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
|
||||
assert np.allclose(
|
||||
arrow_table["vector"].combine_chunks().values.to_numpy(),
|
||||
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
|
||||
)
|
||||
@@ -22,8 +22,9 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -356,20 +357,23 @@ def test_create_with_embedding_function(db):
|
||||
text: str
|
||||
vector: Vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
func = MockTextEmbeddingFunction()
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts, "vector": func(texts)})
|
||||
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
|
||||
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text", vector_column="vector", function=func
|
||||
)
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
embedding_functions=[conf],
|
||||
)
|
||||
table.add(df)
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
query_vector = func.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
@@ -377,17 +381,13 @@ def test_create_with_embedding_function(db):
|
||||
|
||||
|
||||
def test_add_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: Vector(10)
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
embedding_functions=[func],
|
||||
)
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims) = emb.VectorField()
|
||||
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
df = pd.DataFrame({"text": texts})
|
||||
@@ -397,7 +397,7 @@ def test_add_with_embedding_function(db):
|
||||
table.add([{"text": t} for t in texts])
|
||||
|
||||
query_str = "hi how are you?"
|
||||
query_vector = func(query_str)[0]
|
||||
query_vector = emb.compute_query_embeddings(query_str)[0]
|
||||
expected = table.search(query_vector).limit(2).to_arrow()
|
||||
|
||||
actual = table.search(query_str).limit(2).to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user