mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-14 15:52:57 +00:00
feat(python): add Tensor pydantic type
- [x] Can be used to declare data model - [ ] Can be used to ingest data
This commit is contained in:
@@ -22,7 +22,13 @@ import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from lancedb.pydantic import (
|
||||
PYDANTIC_VERSION,
|
||||
LanceModel,
|
||||
Tensor,
|
||||
Vector,
|
||||
pydantic_to_schema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -244,3 +250,37 @@ def test_lance_model():
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
|
||||
def test_tensor():
|
||||
class TestModel(LanceModel):
|
||||
tensor: Tensor((3, 3))
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["tensor"]
|
||||
|
||||
if PYDANTIC_VERSION >= (2,):
|
||||
json_schema = TestModel.model_json_schema()
|
||||
else:
|
||||
json_schema = TestModel.schema()
|
||||
|
||||
assert json_schema == {
|
||||
"properties": {
|
||||
"tensor": {
|
||||
"items": {
|
||||
"items": {"type": "number"},
|
||||
"maxItems": 3,
|
||||
"minItems": 3,
|
||||
"type": "array",
|
||||
},
|
||||
"maxItems": 3,
|
||||
"minItems": 3,
|
||||
"title": "Tensor",
|
||||
"type": "array",
|
||||
}
|
||||
},
|
||||
"required": ["tensor"],
|
||||
"title": "TestModel",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.pydantic import LanceModel, Tensor, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -898,3 +898,18 @@ def test_restore_consistency(tmp_path):
|
||||
table.add([{"id": 2}])
|
||||
assert table_fixed.version == table.version - 1
|
||||
assert table_ref_latest.version == table.version
|
||||
|
||||
|
||||
def test_tensor_type(tmp_path):
|
||||
# create a model with a tensor column
|
||||
class MyTable(LanceModel):
|
||||
tensor: Tensor((256, 256, 3))
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = LanceTable.create(db, "my_table", schema=MyTable)
|
||||
|
||||
tensor = np.random.rand(256, 256, 3)
|
||||
table.add([{"tensor": tensor}, {"tensor": tensor.tolist()}])
|
||||
|
||||
result = table.search().limit(2).to_pandas()
|
||||
assert np.allclose(result.tensor[0], result.tensor[1])
|
||||
|
||||
Reference in New Issue
Block a user