diff --git a/python/python/lancedb/embeddings/open_clip.py b/python/python/lancedb/embeddings/open_clip.py index 3df1d1fb..811bfacc 100644 --- a/python/python/lancedb/embeddings/open_clip.py +++ b/python/python/lancedb/embeddings/open_clip.py @@ -149,7 +149,6 @@ class OpenClipEmbeddings(EmbeddingFunction): The image to embed. If the image is a str, it is treated as a uri. If the image is bytes, it is treated as the raw image bytes. """ - import pdb; pdb.set_trace() torch = attempt_import_or_raise("torch") # TODO handle retry and errors for https image = self._to_pil(image) diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 696ea7ef..66300401 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -27,7 +27,9 @@ from typing import ( Callable, Dict, Generator, + Iterable, List, + Tuple, Type, Union, _GenericAlias, @@ -40,6 +42,7 @@ import semver from lance.arrow import ( EncodedImageType, ) +from lance.util import _check_huggingface from pydantic.fields import FieldInfo from pydantic_core import core_schema @@ -192,6 +195,9 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType: elif getattr(py_type, "__origin__", None) in (list, tuple): child = py_type.__args__[0] return pa.list_(_py_type_to_arrow_type(child, field)) + elif _safe_is_huggingface_image(): + import datasets + raise TypeError( f"Converting Pydantic type to Arrow Type: unsupported type {py_type}." ) @@ -289,19 +295,18 @@ def EncodedImage(): if PYDANTIC_VERSION.major < 2: - - def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]: - return [ - _pydantic_to_field(name, field) for name, field in model.__fields__.items() - ] - + def _safe_get_fields(model: pydantic.BaseModel): + return model.__fields__ else: + def _safe_get_fields(model: pydantic.BaseModel): + return model.model_fields - def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]: - return [ - _pydantic_to_field(name, field) - for name, field in model.model_fields.items() - ] + +def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]: + return [ + _pydantic_to_field(name, field) + for name, field in _safe_get_fields(model).items() + ] def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: @@ -440,13 +445,7 @@ class LanceModel(pydantic.BaseModel): """ Get the field names of this model. """ - return list(cls.safe_get_fields().keys()) - - @classmethod - def safe_get_fields(cls): - if PYDANTIC_VERSION.major < 2: - return cls.__fields__ - return cls.model_fields + return list(_safe_get_fields(cls).keys()) @classmethod def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]: @@ -456,14 +455,16 @@ class LanceModel(pydantic.BaseModel): from .embeddings import EmbeddingFunctionConfig vec_and_function = [] - for name, field_info in cls.safe_get_fields().items(): - func = get_extras(field_info, "vector_column_for") + def get_vector_column(name, field_info): + fun = get_extras(field_info, "vector_column_for") if func is not None: vec_and_function.append([name, func]) + visit_fields(_safe_get_fields(cls).items(), get_vector_column) configs = [] + # find the source columns for each one for vec, func in vec_and_function: - for source, field_info in cls.safe_get_fields().items(): + def get_source_column(source, field_info): src_func = get_extras(field_info, "source_column_for") if src_func is func: # note we can't use == here since the function is a pydantic @@ -476,20 +477,48 @@ class LanceModel(pydantic.BaseModel): source_column=source, vector_column=vec, function=func ) ) + visit_fields(_safe_get_fields(cls).items(), get_source_column) return configs + + +def visit_fields(fields: Iterable[Tuple[str, FieldInfo]], + visitor: Callable[[str, FieldInfo], Any]): + """ + Visit all the leaf fields in a Pydantic model. + + Parameters + ---------- + fields : Iterable[Tuple(str, FieldInfo)] + The fields to visit. + visitor : Callable[[str, FieldInfo], Any] + The visitor function. + """ + for name, field_info in fields: + # if the field is a pydantic model then + # visit all subfields + if (isinstance(getattr(field_info, "annotation"), type) + and issubclass(field_info.annotation, pydantic.BaseModel)): + visit_fields(_safe_get_fields(field_info.annotation).items(), + _add_prefix(visitor, name)) + else: + visitor(name, field_info) -def get_extras(field_info: FieldInfo, key: str) -> Any: - """ - Get the extra metadata from a Pydantic FieldInfo. - """ - if PYDANTIC_VERSION.major >= 2: - return (field_info.json_schema_extra or {}).get(key) - return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key) +def _add_prefix(visitor: Callable[[str, FieldInfo], Any], prefix: str) -> Callable[[str, FieldInfo], Any]: + def prefixed_visitor(name: str, field: FieldInfo): + return visitor(f"{prefix}.{name}", field) + return prefixed_visitor if PYDANTIC_VERSION.major < 2: + def get_extras(field_info: FieldInfo, key: str) -> Any: + """ + Get the extra metadata from a Pydantic FieldInfo. + """ + return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key) + + def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]: """ Convert a Pydantic model to a dictionary. @@ -498,6 +527,13 @@ if PYDANTIC_VERSION.major < 2: else: + def get_extras(field_info: FieldInfo, key: str) -> Any: + """ + Get the extra metadata from a Pydantic FieldInfo. + """ + return (field_info.json_schema_extra or {}).get(key) + + def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]: """ Convert a Pydantic model to a dictionary. diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index c4297480..cfef5960 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -75,6 +75,7 @@ def _sanitize_data( on_bad_vectors: str, fill_value: Any, ): + import pdb; pdb.set_trace() if _check_for_hugging_face(data): # Huggingface datasets import datasets @@ -83,7 +84,6 @@ def _sanitize_data( if schema is None: schema = data.features.arrow_schema data = data.data.to_batches() - import pdb; pdb.set_trace() elif isinstance(data, list): # convert to list of dict if data is a bunch of LanceModels if isinstance(data[0], LanceModel):