mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 05:49:57 +00:00
Compare commits
1 Commits
add-python
...
changhiskh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96a7c1ab42 |
@@ -27,6 +27,7 @@ from typing import (
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
_GenericAlias,
|
||||
@@ -37,6 +38,11 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
import semver
|
||||
|
||||
from lancedb.util import safe_import_tf, safe_import_torch
|
||||
|
||||
torch = safe_import_torch()
|
||||
tf = safe_import_tf()
|
||||
|
||||
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
|
||||
try:
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
@@ -79,9 +85,6 @@ def Vector(
|
||||
) -> Type[FixedSizeListMixin]:
|
||||
"""Pydantic Vector Type.
|
||||
|
||||
!!! warning
|
||||
Experimental feature.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dim : int
|
||||
@@ -155,6 +158,142 @@ def Vector(
|
||||
return FixedSizeList
|
||||
|
||||
|
||||
class FixedShapeTensorMixin(ABC):
    """Interface of the tensor field types produced by ``Tensor``.

    Both members are static so they can be called on the class itself;
    ``_pydantic_to_arrow_type`` calls them on ``field.annotation`` to build
    the matching ``pa.fixed_shape_tensor`` Arrow type.
    """

    @staticmethod
    @abstractmethod
    def shape() -> Tuple[int, ...]:
        """Return the fixed shape of the tensor, one entry per dimension."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def value_arrow_type() -> pa.DataType:
        """Return the Arrow data type of the tensor's elements."""
        raise NotImplementedError
|
||||
|
||||
|
||||
def Tensor(
    shape: Tuple[int, ...], value_type: pa.DataType = pa.float32()
) -> Type[FixedShapeTensorMixin]:
    """Pydantic Tensor Type.

    !!! warning
        Experimental feature.

    Builds a field type for fixed-shape tensors. Values are checked against
    ``shape`` and converted to ``numpy.ndarray``.

    Parameters
    ----------
    shape : tuple of int
        The shape of the tensor. Any sequence of ints is accepted and is
        converted to a tuple.
    value_type : pyarrow.DataType, optional
        The value type of the tensor, by default pa.float32()

    Returns
    -------
    Type[FixedShapeTensorMixin]
        A new class to be used as a field annotation on a model.

    Examples
    --------

    >>> import pydantic
    >>> from lancedb.pydantic import LanceModel, Tensor, Vector
    ...
    >>> class MyModel(LanceModel):
    ...     id: int
    ...     url: str
    ...     tensor: Tensor((3, 3))
    ...     embeddings: Vector(768)
    >>> schema = pydantic_to_schema(MyModel)
    >>> assert schema == pa.schema([
    ...     pa.field("id", pa.int64(), False),
    ...     pa.field("url", pa.utf8(), False),
    ...     pa.field("tensor", pa.fixed_shape_tensor(pa.float32(), (3, 3)), False),
    ...     pa.field("embeddings", pa.list_(pa.float32(), 768), False)
    ... ])
    """
    # Normalise once: ``ndarray.shape`` is always a tuple, so a list passed by
    # the caller would otherwise never compare equal during validation.
    dims = tuple(shape)

    # TODO: make a public parameterized type.
    class FixedShapeTensor(FixedShapeTensorMixin):
        def __repr__(self):
            return f"FixedShapeTensor(shape={dims})"

        @staticmethod
        def shape() -> Tuple[int, ...]:
            return dims

        @staticmethod
        def value_arrow_type() -> pa.DataType:
            return value_type

        # For pydantic v2
        @classmethod
        def __get_pydantic_core_schema__(
            cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
        ) -> CoreSchema:
            # Validate as nested fixed-length lists of floats (one list level
            # per dimension), then convert the result to an ndarray.
            return core_schema.no_info_after_validator_function(
                np.asarray,
                nested_schema(dims, core_schema.float_schema()),
            )

        # For pydantic v1
        @classmethod
        def __get_validators__(cls) -> Generator[Callable, None, None]:
            yield cls.validate

        @classmethod
        def validate(cls, v):
            """Check ``v`` against the tensor shape and return an ndarray.

            Raises ``TypeError`` for unsupported input types and
            ``ValueError`` when the shape does not match.
            """
            if isinstance(v, list):
                v = cls._validate_list(v, dims)
            elif isinstance(v, np.ndarray):
                v = cls._validate_ndarray(v, dims)
            elif torch is not None and isinstance(v, torch.Tensor):
                v = cls._validate_torch(v, dims)
            elif tf is not None and isinstance(v, tf.Tensor):
                v = cls._validate_tf(v, dims)
            else:
                raise TypeError(
                    "A list of numbers, numpy.ndarray, torch.Tensor, "
                    f"or tf.Tensor is needed but got {type(v)} instead."
                )
            return np.asarray(v)

        @classmethod
        def _validate_list(cls, v, shape):
            v = np.asarray(v)
            return cls._validate_ndarray(v, shape)

        @classmethod
        def _validate_ndarray(cls, v, shape):
            # Compare tuple to tuple regardless of how ``shape`` was passed.
            if v.shape != tuple(shape):
                raise ValueError(f"Invalid shape {v.shape}, expected {shape}")
            return v

        @classmethod
        def _validate_torch(cls, v, shape):
            v = v.detach().cpu().numpy()
            return cls._validate_ndarray(v, shape)

        @classmethod
        def _validate_tf(cls, v, shape):
            v = v.numpy()
            return cls._validate_ndarray(v, shape)

        if PYDANTIC_VERSION < (2, 0):

            @classmethod
            def __modify_schema__(cls, field_schema: Dict[str, Any], field):
                # Describe the tensor as nested fixed-length arrays of
                # numbers, built from the last dimension outwards.
                # ``field`` is unused; it is kept so the signature is
                # unchanged.
                json_schema: Dict[str, Any] = {"type": "number"}
                for dim in reversed(dims):
                    json_schema = {
                        "type": "array",
                        "items": json_schema,
                        "minItems": dim,
                        "maxItems": dim,
                    }
                field_schema.update(json_schema)

    return FixedShapeTensor
|
||||
|
||||
|
||||
def nested_schema(shape, items_schema):
    """Wrap ``items_schema`` in one fixed-length list schema per dimension.

    Each entry of ``shape`` adds a list level whose minimum and maximum
    length both equal that entry. An empty ``shape`` returns ``items_schema``
    unchanged.
    """
    schema = items_schema
    # Build from the last dimension outwards so that the first entry of
    # ``shape`` ends up as the outermost list.
    for dim in reversed(shape):
        schema = core_schema.list_schema(
            min_length=dim,
            max_length=dim,
            items_schema=schema,
        )
    return schema
|
||||
|
||||
|
||||
def _py_type_to_arrow_type(py_type: Type[Any], field: FieldInfo) -> pa.DataType:
|
||||
"""Convert a field with native Python type to Arrow data type.
|
||||
|
||||
@@ -230,6 +369,10 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
|
||||
return pa.struct(fields)
|
||||
elif issubclass(field.annotation, FixedSizeListMixin):
|
||||
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
|
||||
elif issubclass(field.annotation, FixedShapeTensorMixin):
|
||||
return pa.fixed_shape_tensor(
|
||||
field.annotation.value_arrow_type(), field.annotation.shape()
|
||||
)
|
||||
return _py_type_to_arrow_type(field.annotation, field)
|
||||
|
||||
|
||||
|
||||
@@ -1568,7 +1568,7 @@ def _sanitize_schema(
|
||||
# is a vector column. This is definitely a bit hacky.
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_float32(field.type.value_type)
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and field.type.list_size >= 10
|
||||
)
|
||||
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
|
||||
@@ -1581,6 +1581,11 @@ def _sanitize_schema(
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
is_tensor_type = isinstance(field.type, pa.FixedShapeTensorType)
|
||||
if is_tensor_type and field.name in data.column_names:
|
||||
data = _sanitize_tensor_column(data, column_name=field.name)
|
||||
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
@@ -1649,6 +1654,31 @@ def _sanitize_vector_column(
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_tensor_column(data: pa.Table, column_name: str) -> pa.Table:
    """
    Pass one tensor column through ``ensure_tensor`` and put it back.

    Parameters
    ----------
    data: pa.Table
        The table to sanitize.
    column_name: str
        The name of the tensor column.

    Returns
    -------
    pa.Table
        A table in which the column has been replaced by the array returned
        from ``ensure_tensor``.

    Raises
    ------
    TypeError
        If the column's type is not a ``pa.FixedShapeTensorType``.
    """
    column = data[column_name]
    # A single contiguous array is easier to handle than a ChunkedArray.
    combined = column.combine_chunks()
    if not isinstance(column.type, pa.FixedShapeTensorType):
        raise TypeError(f"Unsupported tensor column type: {combined.type}")

    checked = ensure_tensor(combined)
    position = data.column_names.index(column_name)
    return data.set_column(position, column_name, checked)
|
||||
|
||||
|
||||
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
values = vec_arr.values
|
||||
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
|
||||
@@ -1661,6 +1691,11 @@ def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
return vec_arr
|
||||
|
||||
|
||||
def ensure_tensor(tensor_arr) -> pa.FixedShapeTensorArray:
    """Check that ``tensor_arr`` is a fixed-shape tensor array and return it.

    Parameters
    ----------
    tensor_arr
        The array to check; its ``type`` must be a
        ``pa.FixedShapeTensorType``.

    Returns
    -------
    pa.FixedShapeTensorArray
        ``tensor_arr`` itself, unchanged.

    Raises
    ------
    TypeError
        If ``tensor_arr`` does not have a fixed-shape tensor type.
    """
    # An explicit check instead of a bare ``assert``: asserts are stripped
    # under ``python -O`` and carry no message for the caller.
    if not isinstance(tensor_arr.type, pa.FixedShapeTensorType):
        raise TypeError(f"Unsupported tensor column type: {tensor_arr.type}")
    # NOTE(review): the element type is not converted here; add a cast if a
    # specific value type (e.g. float32) is required downstream.
    return tensor_arr
|
||||
|
||||
|
||||
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize jagged vectors."""
|
||||
if on_bad_vectors == "error":
|
||||
|
||||
@@ -153,6 +153,24 @@ def safe_import_polars():
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_torch():
    """Return the ``torch`` module, or ``None`` if importing it fails."""
    try:
        import torch
    except ImportError:
        return None
    else:
        return torch
|
||||
|
||||
|
||||
def safe_import_tf():
    """Return the ``tensorflow`` module, or ``None`` if importing it fails."""
    try:
        import tensorflow as tf
    except ImportError:
        return None
    else:
        return tf
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
"""
|
||||
Get the vector column name
|
||||
|
||||
@@ -22,7 +22,13 @@ import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
LanceModel,
|
||||
Tensor,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -244,3 +250,37 @@ def test_lance_model():
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
|
||||
def test_tensor():
    """A ``Tensor((3, 3))`` field yields matching Arrow and JSON schemas."""

    class TestModel(LanceModel):
        tensor: Tensor((3, 3))

    schema = pydantic_to_schema(TestModel)
    assert schema == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["tensor"]

    # pydantic >= 2 exposes the JSON schema through ``model_json_schema``,
    # older versions through ``schema``. Only the chosen attribute is read.
    build_json_schema = (
        TestModel.model_json_schema if PYDANTIC_VERSION >= (2,) else TestModel.schema
    )

    # One row of the tensor: exactly three numbers.
    row = {
        "items": {"type": "number"},
        "maxItems": 3,
        "minItems": 3,
        "type": "array",
    }
    expected = {
        "properties": {
            "tensor": {
                "items": row,
                "maxItems": 3,
                "minItems": 3,
                "title": "Tensor",
                "type": "array",
            }
        },
        "required": ["tensor"],
        "title": "TestModel",
        "type": "object",
    }
    assert build_json_schema() == expected
|
||||
|
||||
@@ -31,7 +31,7 @@ import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.pydantic import LanceModel, Tensor, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -898,3 +898,18 @@ def test_restore_consistency(tmp_path):
|
||||
table.add([{"id": 2}])
|
||||
assert table_fixed.version == table.version - 1
|
||||
assert table_ref_latest.version == table.version
|
||||
|
||||
|
||||
def test_tensor_type(tmp_path):
    """The same tensor added as an ndarray and as nested lists reads back equal."""

    # create a model with a tensor column
    class MyTable(LanceModel):
        tensor: Tensor((256, 256, 3))

    connection = lancedb.connect(tmp_path)
    tbl = LanceTable.create(connection, "my_table", schema=MyTable)

    # Insert identical values twice: once as an ndarray, once as nested lists.
    values = np.random.rand(256, 256, 3)
    rows = [{"tensor": values}, {"tensor": values.tolist()}]
    tbl.add(rows)

    frame = tbl.search().limit(2).to_pandas()
    first, second = frame.tensor[0], frame.tensor[1]
    assert np.allclose(first, second)
|
||||
|
||||
Reference in New Issue
Block a user