diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 6b0b6b6e..84ec462f 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -66,3 +66,33 @@ jobs: run: black --check --diff --no-color --quiet . - name: Run tests run: pytest -m "not slow" -x -v --durations=30 tests + pydantic1x: + timeout-minutes: 30 + runs-on: "ubuntu-22.04" + defaults: + run: + shell: bash + working-directory: python + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + lfs: true + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.9 + - name: Install lancedb + run: | + pip install "pydantic<2" + pip install -e .[tests] + pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985 + pip install pytest pytest-mock black isort + - name: Black + run: black --check --diff --no-color --quiet . + - name: isort + run: isort --check --diff --quiet . + - name: Run tests + run: pytest -m "not slow" -x -v --durations=30 tests + # - name: doctest + # run: pytest --doctest-modules lancedb \ No newline at end of file diff --git a/docs/src/python/python.md b/docs/src/python/python.md index f11caa87..a72f9bd6 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -28,14 +28,6 @@ pip install lancedb ::: lancedb.embeddings.with_embeddings -::: lancedb.embeddings.functions.EmbeddingFunctionRegistry - -::: lancedb.embeddings.functions.EmbeddingFunctionModel - -::: lancedb.embeddings.functions.TextEmbeddingFunctionModel - -::: lancedb.embeddings.functions.SentenceTransformerEmbeddingFunction - ## Context ::: lancedb.context.contextualize diff --git a/python/lancedb/conftest.py b/python/lancedb/conftest.py index 716c6de8..a88e967f 100644 --- a/python/lancedb/conftest.py +++ b/python/lancedb/conftest.py @@ -36,6 +36,5 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction): emb /= np.linalg.norm(emb) return emb - @property def ndims(self): return 10 diff --git a/python/lancedb/embeddings/functions.py b/python/lancedb/embeddings/functions.py index e8683695..4fb763c5 100644 --- a/python/lancedb/embeddings/functions.py +++ b/python/lancedb/embeddings/functions.py @@ -20,19 +20,37 @@ import urllib.error import urllib.parse as urlparse import urllib.request from abc import ABC, abstractmethod -from functools import cached_property from typing import Dict, List, Optional, Union import numpy as np import pyarrow as pa from cachetools import cached -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr class EmbeddingFunctionRegistry: """ This is a singleton class used to register embedding functions - and fetch them by name. It also handles serializing and deserializing + and fetch them by name. It also handles serializing and deserializing. + You can implement your own embedding function by subclassing EmbeddingFunction + or TextEmbeddingFunction and registering it with the registry. + + Examples + -------- + >>> registry = EmbeddingFunctionRegistry.get_instance() + >>> @registry.register("my-embedding-function") + ... class MyEmbeddingFunction(EmbeddingFunction): + ... def ndims(self) -> int: + ... return 128 + ... + ... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]: + ... return self.compute_source_embeddings(query, *args, **kwargs) + ... + ... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: + ... return [np.random.rand(self.ndims()) for _ in range(len(texts))] + ... + >>> registry.get("my-embedding-function") + """ @classmethod @@ -130,7 +148,7 @@ class EmbeddingFunctionRegistry: name = getattr( func, "__embedding_function_registry_alias__", func.__class__.__name__ ) - json_data = func.model_dump() + json_data = func.safe_model_dump() return { "name": name, "model": json_data, @@ -166,13 +184,16 @@ class EmbeddingFunction(BaseModel, ABC): """ An ABC for embedding functions. - The API has two methods: + All concrete embedding functions must implement the following: 1. compute_query_embeddings() which takes a query and returns a list of embeddings 2. get_source_embeddings() which returns a list of embeddings for the source column For text data, the two will be the same. For multi-modal data, the source column might be images and the vector column might be text. + 3. ndims method which returns the number of dimensions of the vector column """ + _ndims: int = PrivateAttr() + @classmethod def create(cls, **kwargs): """ @@ -225,7 +246,13 @@ class EmbeddingFunction(BaseModel, ABC): except ImportError: raise ImportError(f"Please install {mitigation or module}") - @property + def safe_model_dump(self): + from ..pydantic import PYDANTIC_VERSION + + if PYDANTIC_VERSION.major < 2: + return dict(self) + return self.model_dump() + @abstractmethod def ndims(self): """ @@ -235,14 +262,14 @@ class EmbeddingFunction(BaseModel, ABC): def SourceField(self, **kwargs): """ - Return a pydantic Field that can automatically indicate + Creates a pydantic Field that can automatically annotate the source column for this embedding function """ return Field(json_schema_extra={"source_column_for": self}, **kwargs) def VectorField(self, **kwargs): """ - Return a pydantic Field that can automatically indicate + Creates a pydantic Field that can automatically annotate the target vector column for this embedding function """ return Field(json_schema_extra={"vector_column_for": self}, **kwargs) @@ -250,8 +277,9 @@ class EmbeddingFunction(BaseModel, ABC): class EmbeddingFunctionConfig(BaseModel): """ - This is a dataclass that holds the embedding function - and source column for a vector column + This model encapsulates the configuration for a embedding function + in a lancedb table. It holds the embedding function, the source column, + and the vector column """ vector_column: str @@ -281,6 +309,7 @@ class TextEmbeddingFunction(EmbeddingFunction): pass +# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8 register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name) @@ -296,6 +325,10 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction): device: str = "cpu" normalize: bool = True + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._ndims = None + @property def embedding_model(self): """ @@ -305,9 +338,10 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction): """ return self.__class__.get_embedding_model(self.name, self.device) - @cached_property def ndims(self): - return len(self.generate_embeddings(["foo"])[0]) + if self._ndims is None: + self._ndims = len(self.generate_embeddings("foo")[0]) + return self._ndims def generate_embeddings( self, texts: Union[List[str], np.ndarray] @@ -359,7 +393,6 @@ class OpenAIEmbeddings(TextEmbeddingFunction): name: str = "text-embedding-ada-002" - @property def ndims(self): # TODO don't hardcode this return 1536 @@ -395,6 +428,9 @@ class OpenClipEmbeddings(EmbeddingFunction): device: str = "cpu" batch_size: int = 64 normalize: bool = True + _model = PrivateAttr() + _preprocess = PrivateAttr() + _tokenizer = PrivateAttr() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -405,10 +441,12 @@ class OpenClipEmbeddings(EmbeddingFunction): model.to(self.device) self._model, self._preprocess = model, preprocess self._tokenizer = open_clip.get_tokenizer(self.name) + self._ndims = None - @cached_property def ndims(self): - return self.generate_text_embeddings("foo").shape[0] + if self._ndims is None: + self._ndims = self.generate_text_embeddings("foo").shape[0] + return self._ndims def compute_query_embeddings( self, query: Union[str, "PIL.Image.Image"], *args, **kwargs diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 46b882e7..958b8a83 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -323,14 +323,14 @@ class LanceModel(pydantic.BaseModel): vec_and_function = [] for name, field_info in cls.safe_get_fields().items(): - func = (field_info.json_schema_extra or {}).get("vector_column_for") + func = get_extras(field_info, "vector_column_for") if func is not None: vec_and_function.append([name, func]) configs = [] for vec, func in vec_and_function: for source, field_info in cls.safe_get_fields().items(): - src_func = (field_info.json_schema_extra or {}).get("source_column_for") + src_func = get_extras(field_info, "source_column_for") if src_func == func: configs.append( EmbeddingFunctionConfig( @@ -338,3 +338,12 @@ class LanceModel(pydantic.BaseModel): ) ) return configs + + +def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any: + """ + Get the extra metadata from a Pydantic FieldInfo. + """ + if PYDANTIC_VERSION.major >= 2: + return (field_info.json_schema_extra or {}).get(key) + return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key) diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 86c1ae05..bae8848a 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -136,11 +136,9 @@ def test_ingest_iterator(tmp_path): def run_tests(schema): db = lancedb.connect(tmp_path) tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite") - tbl.to_pandas() assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0 assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0 - tbl_len = len(tbl) tbl.add(make_batches()) assert tbl_len == 50 diff --git a/python/tests/test_embeddings_slow.py b/python/tests/test_embeddings_slow.py index 8960f587..92692d8c 100644 --- a/python/tests/test_embeddings_slow.py +++ b/python/tests/test_embeddings_slow.py @@ -35,7 +35,7 @@ def test_sentence_transformer(alias, tmp_path): class Words(LanceModel): text: str = func.SourceField() - vector: Vector(func.ndims) = func.VectorField() + vector: Vector(func.ndims()) = func.VectorField() table = db.create_table("words", schema=Words) table.add( @@ -75,8 +75,8 @@ def test_openclip(tmp_path): label: str image_uri: str = func.SourceField() image_bytes: bytes = func.SourceField() - vector: Vector(func.ndims) = func.VectorField() - vec_from_bytes: Vector(func.ndims) = func.VectorField() + vector: Vector(func.ndims()) = func.VectorField() + vec_from_bytes: Vector(func.ndims()) = func.VectorField() table = db.create_table("images", schema=Images) labels = ["cat", "cat", "dog", "dog", "horse", "horse"] diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 92129e9c..c1655601 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -385,7 +385,7 @@ def test_add_with_embedding_function(db): class MyTable(LanceModel): text: str = emb.SourceField() - vector: Vector(emb.ndims) = emb.VectorField() + vector: Vector(emb.ndims()) = emb.VectorField() table = LanceTable.create(db, "my_table", schema=MyTable) diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/database.rs index 92c30d72..3bc57d1d 100644 --- a/rust/vectordb/src/database.rs +++ b/rust/vectordb/src/database.rs @@ -231,6 +231,7 @@ impl Database { #[cfg(test)] mod tests { use std::fs::create_dir_all; + use tempfile::tempdir; use crate::database::Database;