[Python] Pydantic vector field with default value (#474)

Rename `lance.pydantic.vector` to `Vector` and deprecate `vector(dim)`
This commit is contained in:
Lei Xu
2023-09-08 22:35:31 -07:00
committed by GitHub
parent aa7806cf0d
commit b315ea3978
8 changed files with 63 additions and 47 deletions

View File

@@ -49,11 +49,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
db.create_table("table2", data)
db["table2"].head()
db["table2"].head()
```
!!! info "Note"
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
```python
custom_schema = pa.schema([
pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -66,7 +66,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
### From PyArrow Tables
You can also create LanceDB tables directly from pyarrow tables
```python
table = pa.Table.from_arrays(
[
@@ -87,15 +87,15 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
LanceDB supports to create Apache Arrow Schema from a Pydantic BaseModel via pydantic_to_schema() method.
```python
from lancedb.pydantic import vector, LanceModel
from lancedb.pydantic import Vector, LanceModel
class Content(LanceModel):
movie_id: int
vector: vector(128)
vector: Vector(128)
genres: str
title: str
imdb_id: int
@property
def imdb_url(self) -> str:
return f"https://www.imdb.com/title/tt{self.imdb_id}"
@@ -103,7 +103,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
import pyarrow as pa
db = lancedb.connect("~/.lancedb")
table_name = "movielens_small"
table = db.create_table(table_name, schema=Content.to_arrow_schema())
table = db.create_table(table_name, schema=Content)
```
### Using Iterators / Writing Large Datasets
@@ -113,7 +113,7 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
LanceDB additionally supports pyarrow's `RecordBatch` Iterators or other generators producing supported data types.
Here's an example using using `RecordBatch` iterator for creating tables.
```python
import pyarrow as pa
@@ -142,11 +142,11 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
## Creating Empty Table
You can also create empty tables in python. Initialize it with schema and later ingest data into it.
```python
import lancedb
import pyarrow as pa
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
@@ -168,8 +168,8 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
from lancedb.pydantic import LanceModel, vector
class Model(LanceModel):
vector: vector(2)
vector: Vector(2)
tbl = db.create_table("table5", schema=Model.to_arrow_schema())
```
@@ -249,7 +249,7 @@ After a table has been created, you can always add more data to it using
You can also add a large dataset batch in one go using Iterator of any supported data types.
### Adding to table using Iterator
```python
import pandas as pd
@@ -261,10 +261,10 @@ After a table has been created, you can always add more data to it using
"item": ["foo", "bar"],
"price": [10.0, 20.0],
})
tbl.add(make_batches())
```
The other arguments accepted:
| Name | Type | Description | Default |
@@ -274,7 +274,7 @@ After a table has been created, you can always add more data to it using
| on_bad_vectors | str | What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". | drop |
| fill value | float | The value to use when filling vectors: Only used if on_bad_vectors="fill". | 0.0 |
=== "Javascript/Typescript"
```javascript
@@ -312,7 +312,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
# x vector
# 0 1 [1.0, 2.0]
# 1 3 [5.0, 6.0]
```
```
### Delete from a list of values
@@ -325,7 +325,7 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
# x vector
# 0 3 [5.0, 6.0]
```
=== "Javascript/Typescript"
```javascript

View File

@@ -249,11 +249,11 @@
}
],
"source": [
"from lancedb.pydantic import vector, LanceModel\n",
"from lancedb.pydantic import Vector, LanceModel\n",
"\n",
"class Content(LanceModel):\n",
" movie_id: int\n",
" vector: vector(128)\n",
" vector: Vector(128)\n",
" genres: str\n",
" title: str\n",
" imdb_id: int\n",
@@ -359,7 +359,7 @@
"import pandas as pd\n",
"\n",
"class PydanticSchema(LanceModel):\n",
" vector: vector(2)\n",
" vector: Vector(2)\n",
" item: str\n",
" price: float\n",
"\n",
@@ -394,10 +394,10 @@
"outputs": [],
"source": [
"import lancedb\n",
"from lancedb.pydantic import LanceModel, vector\n",
"from lancedb.pydantic import LanceModel, Vector\n",
"\n",
"class Model(LanceModel):\n",
" vector: vector(2)\n",
" vector: Vector(2)\n",
"\n",
"tbl = db.create_table(\"table6\", schema=Model.to_arrow_schema())"
]

View File

@@ -13,10 +13,10 @@ via [pydantic_to_schema()](python.md##lancedb.pydantic.pydantic_to_schema) metho
## Vector Field
LanceDB provides a [`vector(dim)`](python.md#lancedb.pydantic.vector) method to define a
LanceDB provides a [`Vector(dim)`](python.md#lancedb.pydantic.Vector) method to define a
vector Field in a Pydantic Model.
::: lancedb.pydantic.vector
::: lancedb.pydantic.Vector
## Type Conversion
@@ -33,4 +33,4 @@ Current supported type conversions:
| `str` | `pyarrow.utf8()` |
| `list` | `pyarrow.List` |
| `BaseModel` | `pyarrow.Struct` |
| `vector(n)` | `pyarrow.FixedSizeList(float32, n)` |
| `Vector(n)` | `pyarrow.FixedSizeList(float32, n)` |

View File

@@ -46,7 +46,19 @@ class FixedSizeListMixin(ABC):
raise NotImplementedError
def vector(
def vector(dim: int, value_type: pa.DataType = pa.float32()):
# TODO: remove in future release
from warnings import warn
warn(
"lancedb.pydantic.vector() is deprecated, use lancedb.pydantic.Vector instead."
"This function will be removed in future release",
DeprecationWarning,
)
return Vector(dim, value_type)
def Vector(
dim: int, value_type: pa.DataType = pa.float32()
) -> Type[FixedSizeListMixin]:
"""Pydantic Vector Type.
@@ -65,12 +77,12 @@ def vector(
--------
>>> import pydantic
>>> from lancedb.pydantic import vector
>>> from lancedb.pydantic import Vector
...
>>> class MyModel(pydantic.BaseModel):
... id: int
... url: str
... embeddings: vector(768)
... embeddings: Vector(768)
>>> schema = pydantic_to_schema(MyModel)
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
@@ -258,11 +270,11 @@ class LanceModel(pydantic.BaseModel):
Examples
--------
>>> import lancedb
>>> from lancedb.pydantic import LanceModel, vector
>>> from lancedb.pydantic import LanceModel, Vector
>>>
>>> class TestModel(LanceModel):
... name: str
... vector: vector(2)
... vector: Vector(2)
...
>>> db = lancedb.connect("/tmp")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())

View File

@@ -17,7 +17,7 @@ import pyarrow as pa
import pytest
import lancedb
from lancedb.pydantic import LanceModel, vector
from lancedb.pydantic import LanceModel, Vector
def test_basic(tmp_path):
@@ -79,7 +79,7 @@ def test_ingest_pd(tmp_path):
def test_ingest_iterator(tmp_path):
class PydanticSchema(LanceModel):
vector: vector(2)
vector: Vector(2)
item: str
price: float

View File

@@ -19,8 +19,9 @@ from typing import List, Optional
import pyarrow as pa
import pydantic
import pytest
from pydantic import Field
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, pydantic_to_schema, vector
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
@pytest.mark.skipif(
@@ -107,7 +108,7 @@ def test_pydantic_to_arrow_py38():
def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel):
vec: vector(16)
vec: Vector(16)
li: List[int]
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
@@ -154,7 +155,7 @@ def test_fixed_size_list_field():
def test_fixed_size_list_validation():
class TestModel(pydantic.BaseModel):
vec: vector(8)
vec: Vector(8)
with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(9))
@@ -167,9 +168,12 @@ def test_fixed_size_list_validation():
def test_lance_model():
class TestModel(LanceModel):
vec: vector(16)
li: List[int]
vector: Vector(16) = Field(default=[0.0] * 16)
li: List[int] = Field(default=[1, 2, 3])
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["vec", "li"]
assert TestModel.field_names() == ["vector", "li"]
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])

View File

@@ -20,7 +20,7 @@ import pyarrow as pa
import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
@@ -67,7 +67,7 @@ def table(tmp_path) -> MockTable:
def test_cast(table):
class TestModel(LanceModel):
vector: vector(2)
vector: Vector(2)
id: int
str_field: str
float_field: float

View File

@@ -24,7 +24,7 @@ import pytest
from lancedb.conftest import MockEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, vector
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable
@@ -140,7 +140,7 @@ def test_add(db):
def test_add_pydantic_model(db):
class TestModel(LanceModel):
vector: vector(16)
vector: Vector(16)
li: List[int]
data = TestModel(vector=list(range(16)), li=[1, 2, 3])
@@ -354,7 +354,7 @@ def test_update(db):
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: vector(10)
vector: Vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
@@ -379,7 +379,7 @@ def test_create_with_embedding_function(db):
def test_add_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: vector(10)
vector: Vector(10)
func = MockEmbeddingFunction(source_column="text", vector_column="vector")
table = LanceTable.create(
@@ -407,8 +407,8 @@ def test_add_with_embedding_function(db):
def test_multiple_vector_columns(db):
class MyTable(LanceModel):
text: str
vector1: vector(10)
vector2: vector(10)
vector1: Vector(10)
vector2: Vector(10)
table = LanceTable.create(
db,