mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 03:12:57 +00:00
[Python] Pydantic vector field with default value (#474)
Rename `lance.pydantic.vector` to `Vector` and deprecate `vector(dim)`
This commit is contained in:
@@ -46,7 +46,19 @@ class FixedSizeListMixin(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def vector(
|
||||
def vector(dim: int, value_type: pa.DataType = pa.float32()):
|
||||
# TODO: remove in future release
|
||||
from warnings import warn
|
||||
|
||||
warn(
|
||||
"lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector instead."
|
||||
"This function will be removed in future release",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return Vector(dim, value_type)
|
||||
|
||||
|
||||
def Vector(
|
||||
dim: int, value_type: pa.DataType = pa.float32()
|
||||
) -> Type[FixedSizeListMixin]:
|
||||
"""Pydantic Vector Type.
|
||||
@@ -65,12 +77,12 @@ def vector(
|
||||
--------
|
||||
|
||||
>>> import pydantic
|
||||
>>> from lancedb.pydantic import vector
|
||||
>>> from lancedb.pydantic import Vector
|
||||
...
|
||||
>>> class MyModel(pydantic.BaseModel):
|
||||
... id: int
|
||||
... url: str
|
||||
... embeddings: vector(768)
|
||||
... embeddings: Vector(768)
|
||||
>>> schema = pydantic_to_schema(MyModel)
|
||||
>>> assert schema == pa.schema([
|
||||
... pa.field("id", pa.int64(), False),
|
||||
@@ -258,11 +270,11 @@ class LanceModel(pydantic.BaseModel):
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> from lancedb.pydantic import LanceModel, vector
|
||||
>>> from lancedb.pydantic import LanceModel, Vector
|
||||
>>>
|
||||
>>> class TestModel(LanceModel):
|
||||
... name: str
|
||||
... vector: vector(2)
|
||||
... vector: Vector(2)
|
||||
...
|
||||
>>> db = lancedb.connect("/tmp")
|
||||
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
|
||||
|
||||
@@ -17,7 +17,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
def test_basic(tmp_path):
|
||||
@@ -79,7 +79,7 @@ def test_ingest_pd(tmp_path):
|
||||
|
||||
def test_ingest_iterator(tmp_path):
|
||||
class PydanticSchema(LanceModel):
|
||||
vector: vector(2)
|
||||
vector: Vector(2)
|
||||
item: str
|
||||
price: float
|
||||
|
||||
|
||||
@@ -19,8 +19,9 @@ from typing import List, Optional
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -107,7 +108,7 @@ def test_pydantic_to_arrow_py38():
|
||||
|
||||
def test_fixed_size_list_field():
|
||||
class TestModel(pydantic.BaseModel):
|
||||
vec: vector(16)
|
||||
vec: Vector(16)
|
||||
li: List[int]
|
||||
|
||||
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
|
||||
@@ -154,7 +155,7 @@ def test_fixed_size_list_field():
|
||||
|
||||
def test_fixed_size_list_validation():
|
||||
class TestModel(pydantic.BaseModel):
|
||||
vec: vector(8)
|
||||
vec: Vector(8)
|
||||
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
TestModel(vec=range(9))
|
||||
@@ -167,9 +168,12 @@ def test_fixed_size_list_validation():
|
||||
|
||||
def test_lance_model():
|
||||
class TestModel(LanceModel):
|
||||
vec: vector(16)
|
||||
li: List[int]
|
||||
vector: Vector(16) = Field(default=[0.0] * 16)
|
||||
li: List[int] = Field(default=[1, 2, 3])
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
assert schema == TestModel.to_arrow_schema()
|
||||
assert TestModel.field_names() == ["vec", "li"]
|
||||
assert TestModel.field_names() == ["vector", "li"]
|
||||
|
||||
t = TestModel()
|
||||
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
||||
|
||||
@@ -20,7 +20,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import LanceVectorQueryBuilder, Query
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -67,7 +67,7 @@ def table(tmp_path) -> MockTable:
|
||||
|
||||
def test_cast(table):
|
||||
class TestModel(LanceModel):
|
||||
vector: vector(2)
|
||||
vector: Vector(2)
|
||||
id: int
|
||||
str_field: str
|
||||
float_field: float
|
||||
|
||||
@@ -24,7 +24,7 @@ import pytest
|
||||
|
||||
from lancedb.conftest import MockEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, vector
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ def test_add(db):
|
||||
|
||||
def test_add_pydantic_model(db):
|
||||
class TestModel(LanceModel):
|
||||
vector: vector(16)
|
||||
vector: Vector(16)
|
||||
li: List[int]
|
||||
|
||||
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
|
||||
@@ -354,7 +354,7 @@ def test_update(db):
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: vector(10)
|
||||
vector: Vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
|
||||
@@ -379,7 +379,7 @@ def test_create_with_embedding_function(db):
|
||||
def test_add_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector: vector(10)
|
||||
vector: Vector(10)
|
||||
|
||||
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
|
||||
table = LanceTable.create(
|
||||
@@ -407,8 +407,8 @@ def test_add_with_embedding_function(db):
|
||||
def test_multiple_vector_columns(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector1: vector(10)
|
||||
vector2: vector(10)
|
||||
vector1: Vector(10)
|
||||
vector2: Vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
|
||||
Reference in New Issue
Block a user