[Python] Pydantic vector field with default value (#474)

Rename `lance.pydantic.vector` to `Vector` and deprecate `vector(dim)`
This commit is contained in:
Lei Xu
2023-09-08 22:35:31 -07:00
committed by GitHub
parent aa7806cf0d
commit b315ea3978
8 changed files with 63 additions and 47 deletions

View File

@@ -46,7 +46,19 @@ class FixedSizeListMixin(ABC):
raise NotImplementedError
def vector(
def vector(dim: int, value_type: pa.DataType = pa.float32()):
# TODO: remove in future release
from warnings import warn
warn(
"lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector instead."
"This function will be removed in future release",
DeprecationWarning,
)
return Vector(dim, value_type)
def Vector(
dim: int, value_type: pa.DataType = pa.float32()
) -> Type[FixedSizeListMixin]:
"""Pydantic Vector Type.
@@ -65,12 +77,12 @@ def vector(
--------
>>> import pydantic
>>> from lancedb.pydantic import vector
>>> from lancedb.pydantic import Vector
...
>>> class MyModel(pydantic.BaseModel):
... id: int
... url: str
... embeddings: vector(768)
... embeddings: Vector(768)
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
@@ -258,11 +270,11 @@ class LanceModel(pydantic.BaseModel):
Examples
--------
>>> import lancedb
>>> from lancedb.pydantic import LanceModel, vector
>>> from lancedb.pydantic import LanceModel, Vector
>>>
>>> class TestModel(LanceModel):
... name: str
... vector: vector(2)
... vector: Vector(2)
...
>>> db = lancedb.connect("/tmp")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())

View File

@@ -17,7 +17,7 @@ import pyarrow as pa
import pytest
import lancedb
from lancedb.pydantic import LanceModel, vector
from lancedb.pydantic import LanceModel, Vector
def test_basic(tmp_path):
@@ -79,7 +79,7 @@ def test_ingest_pd(tmp_path):
def test_ingest_iterator(tmp_path):
class PydanticSchema(LanceModel):
vector: vector(2)
vector: Vector(2)
item: str
price: float

View File

@@ -19,8 +19,9 @@ from typing import List, Optional
import pyarrow as pa
import pydantic
import pytest
from pydantic import Field
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
@pytest.mark.skipif(
@@ -107,7 +108,7 @@ def test_pydantic_to_arrow_py38():
def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel):
vec: vector(16)
vec: Vector(16)
li: List[int]
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
@@ -154,7 +155,7 @@ def test_fixed_size_list_field():
def test_fixed_size_list_validation():
class TestModel(pydantic.BaseModel):
vec: vector(8)
vec: Vector(8)
with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(9))
@@ -167,9 +168,12 @@ def test_fixed_size_list_validation():
def test_lance_model():
class TestModel(LanceModel):
vec: vector(16)
li: List[int]
vector: Vector(16) = Field(default=[0.0] * 16)
li: List[int] = Field(default=[1, 2, 3])
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["vec", "li"]
assert TestModel.field_names() == ["vector", "li"]
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])

View File

@@ -20,7 +20,7 @@ import pyarrow as pa
import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
@@ -67,7 +67,7 @@ def table(tmp_path) -> MockTable:
def test_cast(table):
class TestModel(LanceModel):
vector: vector(2)
vector: Vector(2)
id: int
str_field: str
float_field: float

View File

@@ -24,7 +24,7 @@ import pytest
from lancedb.conftest import MockEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable
@@ -140,7 +140,7 @@ def test_add(db):
def test_add_pydantic_model(db):
class TestModel(LanceModel):
vector: vector(16)
vector: Vector(16)
li: List[int]
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
@@ -354,7 +354,7 @@ def test_update(db):
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: vector(10)
vector: Vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
@@ -379,7 +379,7 @@ def test_create_with_embedding_function(db):
def test_add_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: vector(10)
vector: Vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
table = LanceTable.create(
@@ -407,8 +407,8 @@ def test_add_with_embedding_function(db):
def test_multiple_vector_columns(db):
class MyTable(LanceModel):
text: str
vector1: vector(10)
vector2: vector(10)
vector1: Vector(10)
vector2: Vector(10)
table = LanceTable.create(
db,