mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
ci: fix docs build (#496)
python/python.md contains typos in the class references --------- Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
@@ -36,6 +36,5 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
emb /= np.linalg.norm(emb)
|
||||
return emb
|
||||
|
||||
@property
|
||||
def ndims(self):
|
||||
return 10
|
||||
|
||||
@@ -20,19 +20,37 @@ import urllib.error
|
||||
import urllib.parse as urlparse
|
||||
import urllib.request
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from cachetools import cached
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
|
||||
class EmbeddingFunctionRegistry:
|
||||
"""
|
||||
This is a singleton class used to register embedding functions
|
||||
and fetch them by name. It also handles serializing and deserializing
|
||||
and fetch them by name. It also handles serializing and deserializing.
|
||||
You can implement your own embedding function by subclassing EmbeddingFunction
|
||||
or TextEmbeddingFunction and registering it with the registry.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
||||
>>> @registry.register("my-embedding-function")
|
||||
... class MyEmbeddingFunction(EmbeddingFunction):
|
||||
... def ndims(self) -> int:
|
||||
... return 128
|
||||
...
|
||||
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
... return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
...
|
||||
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
||||
...
|
||||
>>> registry.get("my-embedding-function")
|
||||
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -130,7 +148,7 @@ class EmbeddingFunctionRegistry:
|
||||
name = getattr(
|
||||
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
||||
)
|
||||
json_data = func.model_dump()
|
||||
json_data = func.safe_model_dump()
|
||||
return {
|
||||
"name": name,
|
||||
"model": json_data,
|
||||
@@ -166,13 +184,16 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
An ABC for embedding functions.
|
||||
|
||||
The API has two methods:
|
||||
All concrete embedding functions must implement the following:
|
||||
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||
2. get_source_embeddings() which returns a list of embeddings for the source column
|
||||
For text data, the two will be the same. For multi-modal data, the source column
|
||||
might be images and the vector column might be text.
|
||||
3. ndims method which returns the number of dimensions of the vector column
|
||||
"""
|
||||
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
"""
|
||||
@@ -225,7 +246,13 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
@property
|
||||
def safe_model_dump(self):
|
||||
from ..pydantic import PYDANTIC_VERSION
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return dict(self)
|
||||
return self.model_dump()
|
||||
|
||||
@abstractmethod
|
||||
def ndims(self):
|
||||
"""
|
||||
@@ -235,14 +262,14 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
|
||||
def SourceField(self, **kwargs):
|
||||
"""
|
||||
Return a pydantic Field that can automatically indicate
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the source column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||
|
||||
def VectorField(self, **kwargs):
|
||||
"""
|
||||
Return a pydantic Field that can automatically indicate
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the target vector column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||
@@ -250,8 +277,9 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
This is a dataclass that holds the embedding function
|
||||
and source column for a vector column
|
||||
This model encapsulates the configuration for a embedding function
|
||||
in a lancedb table. It holds the embedding function, the source column,
|
||||
and the vector column
|
||||
"""
|
||||
|
||||
vector_column: str
|
||||
@@ -281,6 +309,7 @@ class TextEmbeddingFunction(EmbeddingFunction):
|
||||
pass
|
||||
|
||||
|
||||
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
|
||||
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
||||
|
||||
|
||||
@@ -296,6 +325,10 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
device: str = "cpu"
|
||||
normalize: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._ndims = None
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""
|
||||
@@ -305,9 +338,10 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
|
||||
@cached_property
|
||||
def ndims(self):
|
||||
return len(self.generate_embeddings(["foo"])[0])
|
||||
if self._ndims is None:
|
||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||
return self._ndims
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
@@ -359,7 +393,6 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
|
||||
@property
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return 1536
|
||||
@@ -395,6 +428,9 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
_model = PrivateAttr()
|
||||
_preprocess = PrivateAttr()
|
||||
_tokenizer = PrivateAttr()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -405,10 +441,12 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
model.to(self.device)
|
||||
self._model, self._preprocess = model, preprocess
|
||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||
self._ndims = None
|
||||
|
||||
@cached_property
|
||||
def ndims(self):
|
||||
return self.generate_text_embeddings("foo").shape[0]
|
||||
if self._ndims is None:
|
||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||
return self._ndims
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
|
||||
@@ -323,14 +323,14 @@ class LanceModel(pydantic.BaseModel):
|
||||
|
||||
vec_and_function = []
|
||||
for name, field_info in cls.safe_get_fields().items():
|
||||
func = (field_info.json_schema_extra or {}).get("vector_column_for")
|
||||
func = get_extras(field_info, "vector_column_for")
|
||||
if func is not None:
|
||||
vec_and_function.append([name, func])
|
||||
|
||||
configs = []
|
||||
for vec, func in vec_and_function:
|
||||
for source, field_info in cls.safe_get_fields().items():
|
||||
src_func = (field_info.json_schema_extra or {}).get("source_column_for")
|
||||
src_func = get_extras(field_info, "source_column_for")
|
||||
if src_func == func:
|
||||
configs.append(
|
||||
EmbeddingFunctionConfig(
|
||||
@@ -338,3 +338,12 @@ class LanceModel(pydantic.BaseModel):
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
|
||||
"""
|
||||
Get the extra metadata from a Pydantic FieldInfo.
|
||||
"""
|
||||
if PYDANTIC_VERSION.major >= 2:
|
||||
return (field_info.json_schema_extra or {}).get(key)
|
||||
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||
|
||||
@@ -136,11 +136,9 @@ def test_ingest_iterator(tmp_path):
|
||||
def run_tests(schema):
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
|
||||
|
||||
tbl.to_pandas()
|
||||
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
|
||||
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
|
||||
|
||||
tbl_len = len(tbl)
|
||||
tbl.add(make_batches())
|
||||
assert tbl_len == 50
|
||||
|
||||
@@ -35,7 +35,7 @@ def test_sentence_transformer(alias, tmp_path):
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims) = func.VectorField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("words", schema=Words)
|
||||
table.add(
|
||||
@@ -75,8 +75,8 @@ def test_openclip(tmp_path):
|
||||
label: str
|
||||
image_uri: str = func.SourceField()
|
||||
image_bytes: bytes = func.SourceField()
|
||||
vector: Vector(func.ndims) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims) = func.VectorField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("images", schema=Images)
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
|
||||
@@ -385,7 +385,7 @@ def test_add_with_embedding_function(db):
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
vector: Vector(emb.ndims) = emb.VectorField()
|
||||
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user