diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py
index 89bf08d6..490b212f 100644
--- a/python/lancedb/pydantic.py
+++ b/python/lancedb/pydantic.py
@@ -27,6 +27,7 @@ from typing import (
     Dict,
     Generator,
     List,
+    Tuple,
     Type,
     Union,
     _GenericAlias,
@@ -37,6 +38,11 @@ import pyarrow as pa
 import pydantic
 import semver
 
+from lancedb.util import safe_import_tf, safe_import_torch
+
+torch = safe_import_torch()
+tf = safe_import_tf()
+
 PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
 try:
     from pydantic_core import CoreSchema, core_schema
@@ -79,9 +85,6 @@ def Vector(
 ) -> Type[FixedSizeListMixin]:
     """Pydantic Vector Type.
 
-    !!! warning
-        Experimental feature.
-
     Parameters
     ----------
     dim : int
@@ -155,6 +158,145 @@ def Vector(
     return FixedSizeList
 
 
+class FixedShapeTensorMixin(ABC):
+    """Interface implemented by the classes that ``Tensor`` returns."""
+
+    @staticmethod
+    @abstractmethod
+    def shape() -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    @staticmethod
+    @abstractmethod
+    def value_arrow_type() -> pa.DataType:
+        raise NotImplementedError
+
+
+def Tensor(
+    shape: Tuple[int, ...], value_type: pa.DataType = pa.float32()
+) -> Type[FixedShapeTensorMixin]:
+    """Pydantic Tensor Type.
+
+    !!! warning
+        Experimental feature.
+
+    Parameters
+    ----------
+    shape : tuple of int
+        The shape of the tensor
+    value_type : pyarrow.DataType, optional
+        The value type of the tensor, by default pa.float32()
+
+    Examples
+    --------
+
+    >>> import pyarrow as pa
+    >>> from lancedb.pydantic import LanceModel, Tensor, Vector, pydantic_to_schema
+    ...
+    >>> class MyModel(LanceModel):
+    ...     id: int
+    ...     url: str
+    ...     tensor: Tensor((3, 3))
+    ...     embedding: Vector(768)
+    >>> schema = pydantic_to_schema(MyModel)
+    >>> assert schema == pa.schema([
+    ...     pa.field("id", pa.int64(), False),
+    ...     pa.field("url", pa.utf8(), False),
+    ...     pa.field("tensor", pa.fixed_shape_tensor(pa.float32(), (3, 3)), False),
+    ...     pa.field("embedding", pa.list_(pa.float32(), 768), False)
+    ... ])
+    """
+
+    # TODO: make a public parameterized type.
+    class FixedShapeTensor(FixedShapeTensorMixin):
+        def __repr__(self):
+            return f"FixedShapeTensor(shape={shape})"
+
+        @staticmethod
+        def shape() -> Tuple[int, ...]:
+            return shape
+
+        @staticmethod
+        def value_arrow_type() -> pa.DataType:
+            return value_type
+
+        @classmethod
+        def __get_pydantic_core_schema__(
+            cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
+        ) -> CoreSchema:
+            return core_schema.no_info_after_validator_function(
+                np.asarray,
+                nested_schema(shape, core_schema.float_schema()),
+            )
+
+        @classmethod
+        def __get_validators__(cls) -> Generator[Callable, None, None]:
+            yield cls.validate
+
+        # For pydantic v1
+        @classmethod
+        def validate(cls, v):
+            if isinstance(v, list):
+                v = cls._validate_list(v, shape)
+            elif isinstance(v, np.ndarray):
+                v = cls._validate_ndarray(v, shape)
+            elif torch is not None and isinstance(v, torch.Tensor):
+                v = cls._validate_torch(v, shape)
+            elif tf is not None and isinstance(v, tf.Tensor):
+                v = cls._validate_tf(v, shape)
+            else:
+                raise TypeError(
+                    "A list of numbers, numpy.ndarray, torch.Tensor, "
+                    f"or tf.Tensor is needed but got {type(v)} instead."
+                )
+            return np.asarray(v)
+
+        @classmethod
+        def _validate_list(cls, v, shape):
+            v = np.asarray(v)
+            return cls._validate_ndarray(v, shape)
+
+        @classmethod
+        def _validate_ndarray(cls, v, shape):
+            if v.shape != tuple(shape):
+                raise ValueError(f"Invalid shape {v.shape}, expected {shape}")
+            return v
+
+        @classmethod
+        def _validate_torch(cls, v, shape):
+            v = v.detach().cpu().numpy()
+            return cls._validate_ndarray(v, shape)
+
+        @classmethod
+        def _validate_tf(cls, v, shape):
+            v = v.numpy()
+            return cls._validate_ndarray(v, shape)
+
+        if PYDANTIC_VERSION < (2, 0):
+
+            @classmethod
+            def __modify_schema__(cls, field_schema: Dict[str, Any], field):
+                if field and field.sub_fields:
+                    type_with_potential_subtype = f"np.ndarray[{field.sub_fields[0]}]"
+                else:
+                    type_with_potential_subtype = "np.ndarray"
+                field_schema.update({"type": type_with_potential_subtype})
+
+    return FixedShapeTensor
+
+
+def nested_schema(shape, items_schema):
+    """Wrap ``items_schema`` in one fixed-length list schema per dimension."""
+    if len(shape) == 0:
+        return items_schema
+    else:
+        return core_schema.list_schema(
+            min_length=shape[0],
+            max_length=shape[0],
+            items_schema=nested_schema(shape[1:], items_schema),
+        )
+
+
 def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
     """Convert a field with native Python type to Arrow data type.
 
@@ -230,6 +372,10 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
             return pa.struct(fields)
         elif issubclass(field.annotation, FixedSizeListMixin):
             return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
+        elif issubclass(field.annotation, FixedShapeTensorMixin):
+            return pa.fixed_shape_tensor(
+                field.annotation.value_arrow_type(), field.annotation.shape()
+            )
     return _py_type_to_arrow_type(field.annotation, field)
 
 
diff --git a/python/lancedb/table.py b/python/lancedb/table.py
index 04bad713..fbbfe409 100644
--- a/python/lancedb/table.py
+++ b/python/lancedb/table.py
@@ -1568,7 +1568,7 @@ def _sanitize_schema(
             # is a vector column. This is definitely a bit hacky.
             likely_vector_col = (
                 pa.types.is_fixed_size_list(field.type)
-                and pa.types.is_float32(field.type.value_type)
+                and pa.types.is_floating(field.type.value_type)
                 and field.type.list_size >= 10
             )
             is_default_vector_col = field.name == VECTOR_COLUMN_NAME
@@ -1581,6 +1581,11 @@ def _sanitize_schema(
                     on_bad_vectors=on_bad_vectors,
                     fill_value=fill_value,
                 )
+
+            is_tensor_type = isinstance(field.type, pa.FixedShapeTensorType)
+            if is_tensor_type and field.name in data.column_names:
+                data = _sanitize_tensor_column(data, column_name=field.name)
+
         return pa.Table.from_arrays(
             [data[name] for name in schema.names], schema=schema
         )
@@ -1649,6 +1654,31 @@ def _sanitize_vector_column(
     return data
 
 
+def _sanitize_tensor_column(data: pa.Table, column_name: str) -> pa.Table:
+    """
+    Ensure that the tensor column has a fixed-shape tensor type.
+
+    Parameters
+    ----------
+    data: pa.Table
+        The table to sanitize.
+    column_name: str
+        The name of the tensor column.
+    """
+    # ChunkedArray is annoying to work with, so we combine chunks here
+    tensor_arr = data[column_name].combine_chunks()
+    typ = data[column_name].type
+    if not isinstance(typ, pa.FixedShapeTensorType):
+        raise TypeError(f"Unsupported tensor column type: {tensor_arr.type}")
+
+    tensor_arr = ensure_tensor(tensor_arr)
+    data = data.set_column(
+        data.column_names.index(column_name), column_name, tensor_arr
+    )
+
+    return data
+
+
 def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
     values = vec_arr.values
     if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
@@ -1661,6 +1691,13 @@ def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
     return vec_arr
 
 
+def ensure_tensor(tensor_arr) -> pa.FixedShapeTensorArray:
+    """Check that ``tensor_arr`` is a fixed-shape tensor array and return it."""
+    if not isinstance(tensor_arr.type, pa.FixedShapeTensorType):
+        raise TypeError(f"Unsupported tensor column type: {tensor_arr.type}")
+    return tensor_arr
+
+
 def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
     """Sanitize jagged vectors."""
     if on_bad_vectors == "error":
diff --git a/python/lancedb/util.py b/python/lancedb/util.py
index 14f9e530..dcea76d7 100644
--- a/python/lancedb/util.py
+++ b/python/lancedb/util.py
@@ -153,6 +153,26 @@ def safe_import_polars():
         return None
 
 
+def safe_import_torch():
+    """Return the ``torch`` module, or ``None`` if it cannot be imported."""
+    try:
+        import torch
+
+        return torch
+    except ImportError:
+        return None
+
+
+def safe_import_tf():
+    """Return the ``tensorflow`` module, or ``None`` if it cannot be imported."""
+    try:
+        import tensorflow as tf
+
+        return tf
+    except ImportError:
+        return None
+
+
 def inf_vector_column_query(schema: pa.Schema) -> str:
     """
     Get the vector column name
diff --git a/python/tests/test_pydantic.py b/python/tests/test_pydantic.py
index b37373ee..6ad14488 100644
--- a/python/tests/test_pydantic.py
+++ b/python/tests/test_pydantic.py
@@ -22,7 +22,13 @@ import pydantic
 import pytest
 from pydantic import Field
 
-from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
+from lancedb.pydantic import (
+    PYDANTIC_VERSION,
+    LanceModel,
+    Tensor,
+    Vector,
+    pydantic_to_schema,
+)
 
 
 @pytest.mark.skipif(
@@ -244,3 +250,37 @@ def test_lance_model():
 
     t = TestModel()
     assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
+
+
+def test_tensor():
+    class TestModel(LanceModel):
+        tensor: Tensor((3, 3))
+
+    schema = pydantic_to_schema(TestModel)
+    assert schema == TestModel.to_arrow_schema()
+    assert TestModel.field_names() == ["tensor"]
+
+    if PYDANTIC_VERSION >= (2,):
+        json_schema = TestModel.model_json_schema()
+    else:
+        json_schema = TestModel.schema()
+
+    assert json_schema == {
+        "properties": {
+            "tensor": {
+                "items": {
+                    "items": {"type": "number"},
+                    "maxItems": 3,
+                    "minItems": 3,
+                    "type": "array",
+                },
+                "maxItems": 3,
+                "minItems": 3,
+                "title": "Tensor",
+                "type": "array",
+            }
+        },
+        "required": ["tensor"],
+        "title": "TestModel",
+        "type": "object",
+    }
diff --git a/python/tests/test_table.py b/python/tests/test_table.py
index 9282a5c6..7337968b 100644
--- a/python/tests/test_table.py
+++ b/python/tests/test_table.py
@@ -31,7 +31,7 @@ import lancedb
 from lancedb.conftest import MockTextEmbeddingFunction
 from lancedb.db import LanceDBConnection
 from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
-from lancedb.pydantic import LanceModel, Vector
+from lancedb.pydantic import LanceModel, Tensor, Vector
 from lancedb.table import LanceTable
 
 
@@ -898,3 +898,18 @@ def test_restore_consistency(tmp_path):
     table.add([{"id": 2}])
     assert table_fixed.version == table.version - 1
     assert table_ref_latest.version == table.version
+
+
+def test_tensor_type(tmp_path):
+    # create a model with a tensor column
+    class MyTable(LanceModel):
+        tensor: Tensor((256, 256, 3))
+
+    db = lancedb.connect(tmp_path)
+    table = LanceTable.create(db, "my_table", schema=MyTable)
+
+    tensor = np.random.rand(256, 256, 3)
+    table.add([{"tensor": tensor}, {"tensor": tensor.tolist()}])
+
+    result = table.search().limit(2).to_pandas()
+    assert np.allclose(result.tensor[0], result.tensor[1])