feat(python): add Tensor pydantic type

- [x] Can be used to declare a data model
- [ ] Can be used to ingest data
This commit is contained in:
Chang She
2024-02-17 10:29:50 -08:00
parent e0277383a5
commit 96a7c1ab42
5 changed files with 257 additions and 6 deletions

View File

@@ -22,7 +22,13 @@ import pydantic
import pytest
from pydantic import Field
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
from lancedb.pydantic import (
PYDANTIC_VERSION,
LanceModel,
Tensor,
Vector,
pydantic_to_schema,
)
@pytest.mark.skipif(
@@ -244,3 +250,37 @@ def test_lance_model():
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
def test_tensor():
    """Check the schemas produced for a model with one ``Tensor((3, 3))`` field.

    Three things are asserted:

    * ``pydantic_to_schema(TestModel)`` equals ``TestModel.to_arrow_schema()``;
    * ``TestModel.field_names()`` is exactly ``["tensor"]``;
    * the JSON schema describes ``tensor`` as an array of exactly three
      arrays, each holding exactly three numbers.
    """

    class TestModel(LanceModel):
        tensor: Tensor((3, 3))

    assert pydantic_to_schema(TestModel) == TestModel.to_arrow_schema()
    assert TestModel.field_names() == ["tensor"]

    # The JSON schema accessor depends on the installed pydantic major
    # version: ``model_json_schema()`` for >= 2, ``schema()`` otherwise.
    actual = (
        TestModel.model_json_schema()
        if PYDANTIC_VERSION >= (2,)
        else TestModel.schema()
    )

    # Inner dimension: a fixed-length array of three numbers.
    inner_row = {
        "items": {"type": "number"},
        "maxItems": 3,
        "minItems": 3,
        "type": "array",
    }
    # Outer dimension: a fixed-length array of three inner rows.
    tensor_property = {
        "items": inner_row,
        "maxItems": 3,
        "minItems": 3,
        "title": "Tensor",
        "type": "array",
    }
    expected = {
        "properties": {"tensor": tensor_property},
        "required": ["tensor"],
        "title": "TestModel",
        "type": "object",
    }
    assert actual == expected

View File

@@ -31,7 +31,7 @@ import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.pydantic import LanceModel, Tensor, Vector
from lancedb.table import LanceTable
@@ -898,3 +898,18 @@ def test_restore_consistency(tmp_path):
table.add([{"id": 2}])
assert table_fixed.version == table.version - 1
assert table_ref_latest.version == table.version
def test_tensor_type(tmp_path):
    """Add the same tensor as an ndarray and as nested lists, then compare rows.

    A table is created from a model with a ``Tensor((256, 256, 3))`` column.
    Two rows are added that carry the same random values, once as a NumPy
    array and once as its ``tolist()`` form. The test asserts that the two
    rows read back through ``search().limit(2).to_pandas()`` are close to
    each other.
    """

    # Model with a single tensor column.
    class MyTable(LanceModel):
        tensor: Tensor((256, 256, 3))

    conn = lancedb.connect(tmp_path)
    tbl = LanceTable.create(conn, "my_table", schema=MyTable)

    values = np.random.rand(256, 256, 3)
    rows = [{"tensor": values}, {"tensor": values.tolist()}]
    tbl.add(rows)

    frame = tbl.search().limit(2).to_pandas()
    from_array = frame.tensor[0]
    from_list = frame.tensor[1]
    assert np.allclose(from_array, from_list)