Compare commits


1 Commit

Author: Chang She
SHA1: 96a7c1ab42
Date: 2024-02-17 10:29:50 -08:00

feat(python): add Tensor pydantic type

- [x] Can be used to declare data model
- [ ] Can be used to ingest data
5 changed files with 257 additions and 6 deletions
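
Before the file-by-file diff, a minimal sketch of what the first checked box enables: declaring a data model with a tensor column. Names here (ImageModel) are illustrative; the sketch assumes only the APIs this diff introduces and pydantic v2.

import numpy as np
from lancedb.pydantic import LanceModel, Tensor, pydantic_to_schema

class ImageModel(LanceModel):
    # A 3-channel 64x64 image stored as a fixed-shape float32 tensor.
    image: Tensor((64, 64, 3))

# Nested lists validate against the declared shape and come back as
# numpy arrays; the Arrow schema maps the field to a fixed-shape
# tensor extension type.
row = ImageModel(image=np.zeros((64, 64, 3)).tolist())
print(pydantic_to_schema(ImageModel))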

View File: lancedb/pydantic.py

@@ -27,6 +27,7 @@ from typing import (
    Dict,
    Generator,
    List,
    Tuple,
    Type,
    Union,
    _GenericAlias,
@@ -37,6 +38,11 @@ import pyarrow as pa
import pydantic
import semver

from lancedb.util import safe_import_tf, safe_import_torch

torch = safe_import_torch()
tf = safe_import_tf()

PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
    from pydantic_core import CoreSchema, core_schema
@@ -79,9 +85,6 @@ def Vector(
) -> Type[FixedSizeListMixin]:
    """Pydantic Vector Type.

    !!! warning
        Experimental feature.

    Parameters
    ----------
    dim : int
@@ -155,6 +158,142 @@ def Vector(
    return FixedSizeList
class FixedShapeTensorMixin(ABC):
    @staticmethod
    @abstractmethod
    def shape() -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def value_arrow_type() -> pa.DataType:
        raise NotImplementedError


def Tensor(
    shape: Tuple[int, ...], value_type: pa.DataType = pa.float32()
) -> Type[FixedShapeTensorMixin]:
    """Pydantic Tensor Type.

    !!! warning
        Experimental feature.

    Parameters
    ----------
    shape : tuple of int
        The shape of the tensor
    value_type : pyarrow.DataType, optional
        The value type of the tensor, by default pa.float32()

    Examples
    --------

    >>> import pydantic
    >>> import pyarrow as pa
    >>> from lancedb.pydantic import LanceModel, Tensor, Vector, pydantic_to_schema
    ...
    >>> class MyModel(LanceModel):
    ...     id: int
    ...     url: str
    ...     tensor: Tensor((3, 3))
    ...     embedding: Vector(768)
    >>> schema = pydantic_to_schema(MyModel)
    >>> assert schema == pa.schema([
    ...     pa.field("id", pa.int64(), False),
    ...     pa.field("url", pa.utf8(), False),
    ...     pa.field("tensor", pa.fixed_shape_tensor(pa.float32(), (3, 3)), False),
    ...     pa.field("embedding", pa.list_(pa.float32(), 768), False)
    ... ])
    """
    # TODO: make a public parameterized type.
    class FixedShapeTensor(FixedShapeTensorMixin):
        def __repr__(self):
            return f"FixedShapeTensor(shape={shape})"

        @staticmethod
        def shape() -> Tuple[int, ...]:
            return shape

        @staticmethod
        def value_arrow_type() -> pa.DataType:
            return value_type

        @classmethod
        def __get_pydantic_core_schema__(
            cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
        ) -> CoreSchema:
            return core_schema.no_info_after_validator_function(
                np.asarray,
                nested_schema(shape, core_schema.float_schema()),
            )

        @classmethod
        def __get_validators__(cls) -> Generator[Callable, None, None]:
            yield cls.validate

        # For pydantic v1
        @classmethod
        def validate(cls, v):
            if isinstance(v, list):
                v = cls._validate_list(v, shape)
            elif isinstance(v, np.ndarray):
                v = cls._validate_ndarray(v, shape)
            elif torch is not None and isinstance(v, torch.Tensor):
                v = cls._validate_torch(v, shape)
            elif tf is not None and isinstance(v, tf.Tensor):
                v = cls._validate_tf(v, shape)
            else:
                raise TypeError(
                    "A list of numbers, numpy.ndarray, torch.Tensor, "
                    f"or tf.Tensor is needed but got {type(v)} instead."
                )
            return np.asarray(v)

        @classmethod
        def _validate_list(cls, v, shape):
            v = np.asarray(v)
            return cls._validate_ndarray(v, shape)

        @classmethod
        def _validate_ndarray(cls, v, shape):
            if v.shape != shape:
                raise ValueError(f"Invalid shape {v.shape}, expected {shape}")
            return v

        @classmethod
        def _validate_torch(cls, v, shape):
            v = v.detach().cpu().numpy()
            return cls._validate_ndarray(v, shape)

        @classmethod
        def _validate_tf(cls, v, shape):
            v = v.numpy()
            return cls._validate_ndarray(v, shape)

        if PYDANTIC_VERSION < (2, 0):

            @classmethod
            def __modify_schema__(cls, field_schema: Dict[str, Any], field):
                if field and field.sub_fields:
                    type_with_potential_subtype = f"np.ndarray[{field.sub_fields[0]}]"
                else:
                    type_with_potential_subtype = "np.ndarray"
                field_schema.update({"type": type_with_potential_subtype})

    return FixedShapeTensor
def nested_schema(shape, items_schema):
    if len(shape) == 0:
        return items_schema
    else:
        return core_schema.list_schema(
            min_length=shape[0],
            max_length=shape[0],
            items_schema=nested_schema(shape[1:], items_schema),
        )
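
Concretely, for a (2, 3) shape the recursion above bottoms out in a float schema wrapped in two exact-length list schemas. The hand-unrolled equivalent, as a sketch using pydantic-core directly (the name `manual` is illustrative):

from pydantic_core import core_schema

# nested_schema((2, 3), core_schema.float_schema()) is equivalent to:
manual = core_schema.list_schema(
    min_length=2,
    max_length=2,
    items_schema=core_schema.list_schema(
        min_length=3,
        max_length=3,
        items_schema=core_schema.float_schema(),
    ),
)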
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
    """Convert a field with native Python type to Arrow data type.
@@ -230,6 +369,10 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
        return pa.struct(fields)
    elif issubclass(field.annotation, FixedSizeListMixin):
        return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
    elif issubclass(field.annotation, FixedShapeTensorMixin):
        return pa.fixed_shape_tensor(
            field.annotation.value_arrow_type(), field.annotation.shape()
        )
    return _py_type_to_arrow_type(field.annotation, field)
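
Taken together, the pydantic.py changes give a Tensor field two faces: shape-checked validation at model construction and a pa.fixed_shape_tensor type in the Arrow schema. A small sketch of both, assuming pydantic v2 and a pyarrow with fixed-shape tensor support (the model name `Frame` is illustrative):

import pyarrow as pa
import pydantic
from lancedb.pydantic import LanceModel, Tensor, pydantic_to_schema

class Frame(LanceModel):
    tensor: Tensor((2, 2))

# Arrow face: the field maps to a fixed-shape tensor extension type.
schema = pydantic_to_schema(Frame)
assert schema.field("tensor").type == pa.fixed_shape_tensor(pa.float32(), (2, 2))

# Validation face: a 2x3 payload violates the nested list schema.
try:
    Frame(tensor=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
except pydantic.ValidationError as exc:
    print(exc.error_count(), "shape errors")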

View File: lancedb/table.py

@@ -1568,7 +1568,7 @@ def _sanitize_schema(
            # is a vector column. This is definitely a bit hacky.
            likely_vector_col = (
                pa.types.is_fixed_size_list(field.type)
-               and pa.types.is_float32(field.type.value_type)
+               and pa.types.is_floating(field.type.value_type)
                and field.type.list_size >= 10
            )
            is_default_vector_col = field.name == VECTOR_COLUMN_NAME
@@ -1581,6 +1581,11 @@ def _sanitize_schema(
                    on_bad_vectors=on_bad_vectors,
                    fill_value=fill_value,
                )
            is_tensor_type = isinstance(field.type, pa.FixedShapeTensorType)
            if is_tensor_type and field.name in data.column_names:
                data = _sanitize_tensor_column(data, column_name=field.name)
        return pa.Table.from_arrays(
            [data[name] for name in schema.names], schema=schema
        )
@@ -1649,6 +1654,31 @@ def _sanitize_vector_column(
    return data


def _sanitize_tensor_column(data: pa.Table, column_name: str) -> pa.Table:
    """
    Ensure that the tensor column exists and has a fixed-shape tensor type.

    Parameters
    ----------
    data: pa.Table
        The table to sanitize.
    column_name: str
        The name of the tensor column.
    """
    # ChunkedArray is annoying to work with, so we combine chunks here
    tensor_arr = data[column_name].combine_chunks()
    typ = data[column_name].type
    if not isinstance(typ, pa.FixedShapeTensorType):
        raise TypeError(f"Unsupported tensor column type: {typ}")
    tensor_arr = ensure_tensor(tensor_arr)
    data = data.set_column(
        data.column_names.index(column_name), column_name, tensor_arr
    )
    return data
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
    values = vec_arr.values
    if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
@@ -1661,6 +1691,11 @@ def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
    return vec_arr


def ensure_tensor(tensor_arr) -> pa.FixedShapeTensorArray:
    # Ingestion is not wired up yet (the commit's second checkbox is
    # unticked); pass the array through unchanged for now.
    return tensor_arr
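
For reference, a hypothetical completion of ensure_tensor, not part of this commit, might cast mismatched storage values to the tensor type's declared value type, mirroring what ensure_fixed_size_list does for vectors:

import math

import pyarrow as pa

def ensure_tensor_sketch(tensor_arr: pa.ExtensionArray) -> pa.ExtensionArray:
    """Hypothetical sketch: cast a FixedShapeTensorArray's flat storage
    values to the declared value type when they differ."""
    typ = tensor_arr.type                    # pa.FixedShapeTensorType
    storage = tensor_arr.storage             # backing FixedSizeListArray
    if storage.type.value_type != typ.value_type:
        list_size = math.prod(typ.shape)     # flat values per tensor
        storage = storage.cast(pa.list_(typ.value_type, list_size))
        tensor_arr = pa.ExtensionArray.from_storage(typ, storage)
    return tensor_arr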
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
    """Sanitize jagged vectors."""
    if on_bad_vectors == "error":

View File: lancedb/util.py

@@ -153,6 +153,24 @@ def safe_import_polars():
        return None


def safe_import_torch():
    try:
        import torch

        return torch
    except ImportError:
        return None


def safe_import_tf():
    try:
        import tensorflow as tf

        return tf
    except ImportError:
        return None
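
These helpers follow the same contract as safe_import_polars above: return the module when importable, otherwise None. Call sites then guard on the sentinel before touching the optional dependency, exactly as the validate() method in pydantic.py does. For example (an illustrative helper, not from the diff):

from lancedb.util import safe_import_tf, safe_import_torch

torch = safe_import_torch()
tf = safe_import_tf()

def kind_of(v) -> str:
    # Each isinstance check is only reachable when the optional
    # package actually imported.
    if torch is not None and isinstance(v, torch.Tensor):
        return "torch.Tensor"
    if tf is not None and isinstance(v, tf.Tensor):
        return "tf.Tensor"
    return type(v).__name__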
def inf_vector_column_query(schema: pa.Schema) -> str:
    """
    Get the vector column name

View File: tests/test_pydantic.py

@@ -22,7 +22,13 @@ import pydantic
import pytest
from pydantic import Field

-from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
+from lancedb.pydantic import (
+    PYDANTIC_VERSION,
+    LanceModel,
+    Tensor,
+    Vector,
+    pydantic_to_schema,
+)


@pytest.mark.skipif(
@@ -244,3 +250,37 @@ def test_lance_model():
    t = TestModel()
    assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
def test_tensor():
    class TestModel(LanceModel):
        tensor: Tensor((3, 3))

    schema = pydantic_to_schema(TestModel)
    assert schema == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["tensor"]

    if PYDANTIC_VERSION >= (2,):
        json_schema = TestModel.model_json_schema()
    else:
        json_schema = TestModel.schema()

    assert json_schema == {
        "properties": {
            "tensor": {
                "items": {
                    "items": {"type": "number"},
                    "maxItems": 3,
                    "minItems": 3,
                    "type": "array",
                },
                "maxItems": 3,
                "minItems": 3,
                "title": "Tensor",
                "type": "array",
            }
        },
        "required": ["tensor"],
        "title": "TestModel",
        "type": "object",
    }
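
The expected JSON schema falls straight out of nested_schema: each axis of the shape contributes one array level with minItems == maxItems equal to that axis's length. A quick sanity check with a non-square shape (assuming pydantic v2; the model name `NonSquare` is illustrative):

from lancedb.pydantic import LanceModel, Tensor

class NonSquare(LanceModel):
    tensor: Tensor((2, 4))

js = NonSquare.model_json_schema()["properties"]["tensor"]
assert js["minItems"] == js["maxItems"] == 2                    # outer axis
assert js["items"]["minItems"] == js["items"]["maxItems"] == 4  # inner axis
assert js["items"]["items"] == {"type": "number"}               # leaf values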

View File: tests/test_table.py

@@ -31,7 +31,7 @@ import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
-from lancedb.pydantic import LanceModel, Vector
+from lancedb.pydantic import LanceModel, Tensor, Vector
from lancedb.table import LanceTable
@@ -898,3 +898,18 @@ def test_restore_consistency(tmp_path):
    table.add([{"id": 2}])
    assert table_fixed.version == table.version - 1
    assert table_ref_latest.version == table.version


def test_tensor_type(tmp_path):
    # create a model with a tensor column
    class MyTable(LanceModel):
        tensor: Tensor((256, 256, 3))

    db = lancedb.connect(tmp_path)
    table = LanceTable.create(db, "my_table", schema=MyTable)

    tensor = np.random.rand(256, 256, 3)
    table.add([{"tensor": tensor}, {"tensor": tensor.tolist()}])

    result = table.search().limit(2).to_pandas()
    assert np.allclose(result.tensor[0], result.tensor[1])
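
Pandas materializes each fixed-shape tensor cell as a nested array, which is what makes the np.allclose comparison above work. Staying in Arrow, one could also pull the whole column back as a single ndarray; a sketch continuing from the `table` in the test above, assuming pyarrow >= 12's FixedShapeTensorArray.to_numpy_ndarray:

import numpy as np

arrow_table = table.to_arrow()
tensor_col = arrow_table["tensor"].combine_chunks()

# One (num_rows, 256, 256, 3) block instead of row-by-row objects.
batch = tensor_col.to_numpy_ndarray()
assert batch.shape == (2, 256, 256, 3)
assert np.allclose(batch[0], batch[1])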