mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 05:42:58 +00:00
stuff
This commit is contained in:
@@ -149,7 +149,6 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
import pdb; pdb.set_trace()
|
||||
torch = attempt_import_or_raise("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
|
||||
@@ -27,7 +27,9 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
_GenericAlias,
|
||||
@@ -40,6 +42,7 @@ import semver
|
||||
from lance.arrow import (
|
||||
EncodedImageType,
|
||||
)
|
||||
from lance.util import _check_huggingface
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import core_schema
|
||||
|
||||
@@ -192,6 +195,9 @@ def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
elif getattr(py_type, "__origin__", None) in (list, tuple):
|
||||
child = py_type.__args__[0]
|
||||
return pa.list_(_py_type_to_arrow_type(child, field))
|
||||
elif _safe_is_huggingface_image():
|
||||
import datasets
|
||||
|
||||
raise TypeError(
|
||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}."
|
||||
)
|
||||
@@ -289,19 +295,18 @@ def EncodedImage():
|
||||
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
return [
|
||||
_pydantic_to_field(name, field) for name, field in model.__fields__.items()
|
||||
]
|
||||
|
||||
def _safe_get_fields(model: pydantic.BaseModel):
|
||||
return model.__fields__
|
||||
else:
|
||||
def _safe_get_fields(model: pydantic.BaseModel):
|
||||
return model.model_fields
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
return [
|
||||
_pydantic_to_field(name, field)
|
||||
for name, field in model.model_fields.items()
|
||||
]
|
||||
|
||||
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
|
||||
return [
|
||||
_pydantic_to_field(name, field)
|
||||
for name, field in _safe_get_fields(model).items()
|
||||
]
|
||||
|
||||
|
||||
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||
@@ -440,13 +445,7 @@ class LanceModel(pydantic.BaseModel):
|
||||
"""
|
||||
Get the field names of this model.
|
||||
"""
|
||||
return list(cls.safe_get_fields().keys())
|
||||
|
||||
@classmethod
|
||||
def safe_get_fields(cls):
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return cls.__fields__
|
||||
return cls.model_fields
|
||||
return list(_safe_get_fields(cls).keys())
|
||||
|
||||
@classmethod
|
||||
def parse_embedding_functions(cls) -> List["EmbeddingFunctionConfig"]:
|
||||
@@ -456,14 +455,16 @@ class LanceModel(pydantic.BaseModel):
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
|
||||
vec_and_function = []
|
||||
for name, field_info in cls.safe_get_fields().items():
|
||||
func = get_extras(field_info, "vector_column_for")
|
||||
def get_vector_column(name, field_info):
|
||||
fun = get_extras(field_info, "vector_column_for")
|
||||
if func is not None:
|
||||
vec_and_function.append([name, func])
|
||||
visit_fields(_safe_get_fields(cls).items(), get_vector_column)
|
||||
|
||||
configs = []
|
||||
# find the source columns for each one
|
||||
for vec, func in vec_and_function:
|
||||
for source, field_info in cls.safe_get_fields().items():
|
||||
def get_source_column(source, field_info):
|
||||
src_func = get_extras(field_info, "source_column_for")
|
||||
if src_func is func:
|
||||
# note we can't use == here since the function is a pydantic
|
||||
@@ -476,20 +477,48 @@ class LanceModel(pydantic.BaseModel):
|
||||
source_column=source, vector_column=vec, function=func
|
||||
)
|
||||
)
|
||||
visit_fields(_safe_get_fields(cls).items(), get_source_column)
|
||||
return configs
|
||||
|
||||
|
||||
def visit_fields(
    fields: Iterable[Tuple[str, FieldInfo]],
    visitor: Callable[[str, FieldInfo], Any],
) -> None:
    """
    Visit all the leaf fields in a Pydantic model.

    A field whose annotation is a ``pydantic.BaseModel`` subclass is not
    handed to ``visitor`` itself. Its own fields are visited instead, and
    their names are prefixed with the parent field name, joined by a dot
    (for example ``"parent.child"``).

    Parameters
    ----------
    fields : Iterable[Tuple(str, FieldInfo)]
        The ``(name, field)`` pairs to visit.
    visitor : Callable[[str, FieldInfo], Any]
        The visitor function, called as ``visitor(name, field_info)`` for
        every leaf field. Its return value is ignored.
    """
    for name, field_info in fields:
        # Field objects without an ``annotation`` attribute are treated as
        # leaves instead of raising AttributeError.
        annotation = getattr(field_info, "annotation", None)
        is_nested_model = isinstance(annotation, type) and issubclass(
            annotation, pydantic.BaseModel
        )
        if is_nested_model:
            # Recurse into the nested model; the wrapped visitor sees
            # sub-field names qualified with this field's name.
            visit_fields(
                _safe_get_fields(annotation).items(),
                _add_prefix(visitor, name),
            )
        else:
            visitor(name, field_info)
|
||||
|
||||
|
||||
def get_extras(field_info: FieldInfo, key: str) -> Any:
    """
    Look up ``key`` in the extra metadata attached to a Pydantic field.

    On Pydantic >= 2 the value is read from ``field_info.json_schema_extra``;
    on older versions it is read from the ``"json_schema_extra"`` entry of
    ``field_info.field_info.extra``. A missing or empty mapping is treated
    as an empty dict.
    """
    if PYDANTIC_VERSION.major < 2:
        legacy_extra = field_info.field_info.extra or {}
        return legacy_extra.get("json_schema_extra", {}).get(key)

    schema_extra = field_info.json_schema_extra or {}
    return schema_extra.get(key)
|
||||
def _add_prefix(
    visitor: Callable[[str, FieldInfo], Any], prefix: str
) -> Callable[[str, FieldInfo], Any]:
    """
    Wrap ``visitor`` so every field name it receives is qualified as
    ``"<prefix>.<name>"``. The wrapper returns whatever ``visitor`` returns.
    """

    def _visit_qualified(name: str, field: FieldInfo):
        qualified_name = f"{prefix}.{name}"
        return visitor(qualified_name, field)

    return _visit_qualified
|
||||
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
def get_extras(field_info: FieldInfo, key: str) -> Any:
|
||||
"""
|
||||
Get the extra metadata from a Pydantic FieldInfo.
|
||||
"""
|
||||
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
|
||||
|
||||
|
||||
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
@@ -498,6 +527,13 @@ if PYDANTIC_VERSION.major < 2:
|
||||
|
||||
else:
|
||||
|
||||
def get_extras(field_info: FieldInfo, key: str) -> Any:
|
||||
"""
|
||||
Get the extra metadata from a Pydantic FieldInfo.
|
||||
"""
|
||||
return (field_info.json_schema_extra or {}).get(key)
|
||||
|
||||
|
||||
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
|
||||
@@ -75,6 +75,7 @@ def _sanitize_data(
|
||||
on_bad_vectors: str,
|
||||
fill_value: Any,
|
||||
):
|
||||
import pdb; pdb.set_trace()
|
||||
if _check_for_hugging_face(data):
|
||||
# Huggingface datasets
|
||||
import datasets
|
||||
@@ -83,7 +84,6 @@ def _sanitize_data(
|
||||
if schema is None:
|
||||
schema = data.features.arrow_schema
|
||||
data = data.data.to_batches()
|
||||
import pdb; pdb.set_trace()
|
||||
elif isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
|
||||
Reference in New Issue
Block a user