This commit is contained in:
Chang She
2024-03-12 18:45:21 -07:00
parent c1dfad675a
commit 86c9bc0d2d
3 changed files with 65 additions and 30 deletions

View File

@@ -149,7 +149,6 @@ class OpenClipEmbeddings(EmbeddingFunction):
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
import pdb; pdb.set_trace()
torch = attempt_import_or_raise("torch")
# TODO handle retry and errors for https
image = self._to_pil(image)

View File

@@ -27,7 +27,9 @@ from typing import (
Callable,
Dict,
Generator,
Iterable,
List,
Tuple,
Type,
Union,
_GenericAlias,
@@ -40,6 +42,7 @@ import semver
from lance.arrow import (
EncodedImageType,
)
from lance.util import _check_huggingface
from pydantic.fields import FieldInfo
from pydantic_core import core_schema
@@ -192,6 +195,9 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
elif getattr(py_type, "__origin__", None) in (list, tuple):
child = py_type.__args__[0]
return pa.list_(_py_type_to_arrow_type(child, field))
elif _safe_is_huggingface_image():
import datasets
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
)
@@ -289,19 +295,18 @@ def EncodedImage():
if PYDANTIC_VERSION.major < 2:
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
]
def _safe_get_fields(model: pydantic.BaseModel):
return model.__fields__
else:
def _safe_get_fields(model: pydantic.BaseModel):
return model.model_fields
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
return [
_pydantic_to_field(name, field)
for name, field in model.model_fields.items()
]
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
    """
    Convert every field of a Pydantic model with ``_pydantic_to_field``.

    The fields are read through ``_safe_get_fields`` and converted one by one,
    keeping the iteration order of the returned mapping.
    """
    converted = []
    for field_name, field_info in _safe_get_fields(model).items():
        converted.append(_pydantic_to_field(field_name, field_info))
    return converted
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
@@ -440,13 +445,7 @@ class LanceModel(pydantic.BaseModel):
"""
Get the field names of this model.
"""
return list(cls.safe_get_fields().keys())
@classmethod
def safe_get_fields(cls):
if PYDANTIC_VERSION.major < 2:
return cls.__fields__
return cls.model_fields
return list(_safe_get_fields(cls).keys())
@classmethod
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
@@ -456,14 +455,16 @@ class LanceModel(pydantic.BaseModel):
from .embeddings import EmbeddingFunctionConfig
vec_and_function = []
for name, field_info in cls.safe_get_fields().items():
func = get_extras(field_info, "vector_column_for")
def get_vector_column(name, field_info):
fun = get_extras(field_info, "vector_column_for")
if func is not None:
vec_and_function.append([name, func])
visit_fields(_safe_get_fields(cls).items(), get_vector_column)
configs = []
# find the source columns for each one
for vec, func in vec_and_function:
for source, field_info in cls.safe_get_fields().items():
def get_source_column(source, field_info):
src_func = get_extras(field_info, "source_column_for")
if src_func is func:
# note we can't use == here since the function is a pydantic
@@ -476,20 +477,48 @@ class LanceModel(pydantic.BaseModel):
source_column=source, vector_column=vec, function=func
)
)
visit_fields(_safe_get_fields(cls).items(), get_source_column)
return configs
def visit_fields(
    fields: Iterable[Tuple[str, FieldInfo]],
    visitor: Callable[[str, FieldInfo], Any],
):
    """
    Visit all the leaf fields in a Pydantic model.

    A field whose annotation is a ``pydantic.BaseModel`` subclass is not handed
    to ``visitor`` itself. Its sub-fields are visited recursively instead, with
    the visitor wrapped by ``_add_prefix`` so that nested names are reported
    as ``parent.child``.

    Parameters
    ----------
    fields : Iterable[Tuple(str, FieldInfo)]
        The fields to visit.
    visitor : Callable[[str, FieldInfo], Any]
        The visitor function. Its return value is ignored.
    """
    for name, field_info in fields:
        # Not every field object is guaranteed to expose ``annotation``;
        # fall back to None so such fields are treated as leaves instead
        # of raising AttributeError.
        annotation = getattr(field_info, "annotation", None)
        # issubclass() requires a class, so the isinstance() guard must
        # come first (annotations may be typing constructs, not classes).
        if isinstance(annotation, type) and issubclass(
            annotation, pydantic.BaseModel
        ):
            # nested model: descend and visit all of its subfields
            visit_fields(
                _safe_get_fields(annotation).items(),
                _add_prefix(visitor, name),
            )
        else:
            visitor(name, field_info)
def get_extras(field_info: FieldInfo, key: str) -> Any:
    """
    Look up ``key`` in the extra metadata attached to a Pydantic field.

    For Pydantic major version 2 and above the metadata is read from
    ``field_info.json_schema_extra``. For older versions it is read from the
    ``"json_schema_extra"`` entry of ``field_info.field_info.extra``.
    Returns ``None`` when the key is not present.
    """
    if PYDANTIC_VERSION.major < 2:
        legacy_extra = field_info.field_info.extra or {}
        return legacy_extra.get("json_schema_extra", {}).get(key)
    schema_extra = field_info.json_schema_extra or {}
    return schema_extra.get(key)
def _add_prefix(visitor: Callable[[str, FieldInfo], Any], prefix: str) -> Callable[[str, FieldInfo], Any]:
def prefixed_visitor(name: str, field: FieldInfo):
return visitor(f"{prefix}.{name}", field)
return prefixed_visitor
if PYDANTIC_VERSION.major < 2:
def get_extras(field_info: FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
@@ -498,6 +527,13 @@ if PYDANTIC_VERSION.major < 2:
else:
def get_extras(field_info: FieldInfo, key: str) -> Any:
"""
Get the extra metadata from a Pydantic FieldInfo.
"""
return (field_info.json_schema_extra or {}).get(key)
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.

View File

@@ -75,6 +75,7 @@ def _sanitize_data(
on_bad_vectors: str,
fill_value: Any,
):
import pdb; pdb.set_trace()
if _check_for_hugging_face(data):
# Huggingface datasets
import datasets
@@ -83,7 +84,6 @@ def _sanitize_data(
if schema is None:
schema = data.features.arrow_schema
data = data.data.to_batches()
import pdb; pdb.set_trace()
elif isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):