ci: fix docs build (#496)

python/python.md contains typos in the class references

---------

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-09-18 16:07:21 -04:00
committed by GitHub
parent bc38abb781
commit c21f9cdda0
9 changed files with 99 additions and 32 deletions

View File

@@ -36,6 +36,5 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
emb /= np.linalg.norm(emb)
return emb
@property
def ndims(self):
return 10

View File

@@ -20,19 +20,37 @@ import urllib.error
import urllib.parse as urlparse
import urllib.request
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Dict, List, Optional, Union
import numpy as np
import pyarrow as pa
from cachetools import cached
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
class EmbeddingFunctionRegistry:
"""
This is a singleton class used to register embedding functions
and fetch them by name. It also handles serializing and deserializing
and fetch them by name. It also handles serializing and deserializing.
You can implement your own embedding function by subclassing EmbeddingFunction
or TextEmbeddingFunction and registering it with the registry.
Examples
--------
>>> registry = EmbeddingFunctionRegistry.get_instance()
>>> @registry.register("my-embedding-function")
... class MyEmbeddingFunction(EmbeddingFunction):
... def ndims(self) -> int:
... return 128
...
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
... return self.compute_source_embeddings(query, *args, **kwargs)
...
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
...
>>> registry.get("my-embedding-function")
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
"""
@classmethod
@@ -130,7 +148,7 @@ class EmbeddingFunctionRegistry:
name = getattr(
func, "__embedding_function_registry_alias__", func.__class__.__name__
)
json_data = func.model_dump()
json_data = func.safe_model_dump()
return {
"name": name,
"model": json_data,
@@ -166,13 +184,16 @@ class EmbeddingFunction(BaseModel, ABC):
"""
An ABC for embedding functions.
The API has two methods:
All concrete embedding functions must implement the following:
1. compute_query_embeddings() which takes a query and returns a list of embeddings
2. compute_source_embeddings() which returns a list of embeddings for the source column
For text data, the two will be the same. For multi-modal data, the source column
might be images and the vector column might be text.
3. ndims method which returns the number of dimensions of the vector column
"""
_ndims: int = PrivateAttr()
@classmethod
def create(cls, **kwargs):
"""
@@ -225,7 +246,13 @@ class EmbeddingFunction(BaseModel, ABC):
except ImportError:
raise ImportError(f"Please install {mitigation or module}")
@property
def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION
if PYDANTIC_VERSION.major < 2:
return dict(self)
return self.model_dump()
@abstractmethod
def ndims(self):
"""
@@ -235,14 +262,14 @@ class EmbeddingFunction(BaseModel, ABC):
def SourceField(self, **kwargs):
"""
Return a pydantic Field that can automatically indicate
Creates a pydantic Field that can automatically annotate
the source column for this embedding function
"""
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
def VectorField(self, **kwargs):
"""
Return a pydantic Field that can automatically indicate
Creates a pydantic Field that can automatically annotate
the target vector column for this embedding function
"""
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
@@ -250,8 +277,9 @@ class EmbeddingFunction(BaseModel, ABC):
class EmbeddingFunctionConfig(BaseModel):
"""
This is a dataclass that holds the embedding function
and source column for a vector column
This model encapsulates the configuration for an embedding function
in a lancedb table. It holds the embedding function, the source column,
and the vector column
"""
vector_column: str
@@ -281,6 +309,7 @@ class TextEmbeddingFunction(EmbeddingFunction):
pass
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
@@ -296,6 +325,10 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
device: str = "cpu"
normalize: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._ndims = None
@property
def embedding_model(self):
"""
@@ -305,9 +338,10 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
"""
return self.__class__.get_embedding_model(self.name, self.device)
@cached_property
def ndims(self):
return len(self.generate_embeddings(["foo"])[0])
if self._ndims is None:
self._ndims = len(self.generate_embeddings("foo")[0])
return self._ndims
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
@@ -359,7 +393,6 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
name: str = "text-embedding-ada-002"
@property
def ndims(self):
# TODO don't hardcode this
return 1536
@@ -395,6 +428,9 @@ class OpenClipEmbeddings(EmbeddingFunction):
device: str = "cpu"
batch_size: int = 64
normalize: bool = True
_model = PrivateAttr()
_preprocess = PrivateAttr()
_tokenizer = PrivateAttr()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -405,10 +441,12 @@ class OpenClipEmbeddings(EmbeddingFunction):
model.to(self.device)
self._model, self._preprocess = model, preprocess
self._tokenizer = open_clip.get_tokenizer(self.name)
self._ndims = None
@cached_property
def ndims(self):
return self.generate_text_embeddings("foo").shape[0]
if self._ndims is None:
self._ndims = self.generate_text_embeddings("foo").shape[0]
return self._ndims
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs

View File

@@ -323,14 +323,14 @@ class LanceModel(pydantic.BaseModel):
vec_and_function = []
for name, field_info in cls.safe_get_fields().items():
func = (field_info.json_schema_extra or {}).get("vector_column_for")
func = get_extras(field_info, "vector_column_for")
if func is not None:
vec_and_function.append([name, func])
configs = []
for vec, func in vec_and_function:
for source, field_info in cls.safe_get_fields().items():
src_func = (field_info.json_schema_extra or {}).get("source_column_for")
src_func = get_extras(field_info, "source_column_for")
if src_func == func:
configs.append(
EmbeddingFunctionConfig(
@@ -338,3 +338,12 @@ class LanceModel(pydantic.BaseModel):
)
)
return configs
def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
    """
    Get the extra metadata from a Pydantic FieldInfo.

    Parameters
    ----------
    field_info
        The field description to inspect. Under pydantic 2 this is read
        through its ``json_schema_extra`` attribute; under pydantic 1 the
        same data is looked up in ``field_info.field_info.extra`` under the
        ``"json_schema_extra"`` key.
    key
        Name of the metadata entry to fetch, e.g. ``"vector_column_for"``.

    Returns
    -------
    Any
        The value stored under ``key``, or None when the field carries no
        extras, the extras are not a dict, or ``key`` is absent.
    """
    if PYDANTIC_VERSION.major >= 2:
        extras = field_info.json_schema_extra
    else:
        # No default here: an explicit ``json_schema_extra=None`` must be
        # treated the same as a missing entry.
        extras = (field_info.field_info.extra or {}).get("json_schema_extra")

    # The extras may be None, and pydantic 2 also accepts a callable for
    # json_schema_extra; neither supports ``.get``.
    if not isinstance(extras, dict):
        return None
    return extras.get(key)

View File

@@ -136,11 +136,9 @@ def test_ingest_iterator(tmp_path):
def run_tests(schema):
db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
tbl_len = len(tbl)
tbl.add(make_batches())
assert tbl_len == 50

View File

@@ -35,7 +35,7 @@ def test_sentence_transformer(alias, tmp_path):
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims) = func.VectorField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("words", schema=Words)
table.add(
@@ -75,8 +75,8 @@ def test_openclip(tmp_path):
label: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
vector: Vector(func.ndims) = func.VectorField()
vec_from_bytes: Vector(func.ndims) = func.VectorField()
vector: Vector(func.ndims()) = func.VectorField()
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]

View File

@@ -385,7 +385,7 @@ def test_add_with_embedding_function(db):
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims) = emb.VectorField()
vector: Vector(emb.ndims()) = emb.VectorField()
table = LanceTable.create(db, "my_table", schema=MyTable)