feat(python): support optional vector field in pydantic model (#1097)

The LanceDB embeddings registry allows users to annotate the pydantic
model used as table schema with the desired embedding function, e.g.:

```python
class Schema(LanceModel):
    id: str
    vector: Vector(openai.ndims()) = openai.VectorField()
    text: str = openai.SourceField()
```

Tables created like this do not require embeddings to be calculated by
the user explicitly, e.g. this works:

```python
table.add([{"id": "foo", "text": "rust all the things"}])
```

However, trying to construct pydantic model instances without a vector
doesn't work, because `vector` is a required field.

Instead, you need to add a default value:

```python
class Schema(LanceModel):
    id: str
    vector: Vector(openai.ndims()) = openai.VectorField(default=None)
    text: str = openai.SourceField()
```

then this completes without errors:
```python
table.add([Schema(id="foo", text="rust all the things")])
```

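However, all of the vectors are then filled with zeros instead of real
embeddings. To fix this, `_append_vector_col` gets an additional check:
when the vector column is present but entirely null, embedding generation
is still called and the column is replaced with the computed vectors.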
This commit is contained in:
Chang She
2024-03-13 14:35:08 -07:00
committed by Weston Pace
parent 723defbe7e
commit 377832e532
2 changed files with 46 additions and 4 deletions

View File

@@ -117,7 +117,8 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_column, conf in functions.items():
func = conf.function
if vector_column not in data.column_names:
no_vector_column = vector_column not in data.column_names
if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py():
col_data = func.compute_source_embeddings_with_retry(
data[conf.source_column]
)
@@ -125,9 +126,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
dtype = schema.field(vector_column).type
else:
dtype = pa.list_(pa.float32(), len(col_data[0]))
data = data.append_column(
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
)
if no_vector_column:
data = data.append_column(
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
)
else:
data = data.set_column(
data.column_names.index(vector_column),
pa.field(vector_column, type=dtype),
pa.array(col_data, type=dtype),
)
return data

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import List, Union
import lance
import lancedb
@@ -23,6 +24,8 @@ from lancedb.embeddings import (
EmbeddingFunctionRegistry,
with_embeddings,
)
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.pydantic import LanceModel, Vector
@@ -112,3 +115,34 @@ def test_embedding_function_rate_limit(tmp_path):
table.add([{"text": "hello world"}])
table.add([{"text": "hello world"}])
assert len(table) == 2
def test_add_optional_vector(tmp_path):
    """Check that adding a model instance with an unset vector embeds it.

    The schema declares its vector field with ``default=None`` so that a
    ``LanceSchema`` instance can be built from ``id`` and ``text`` alone.
    After adding such an instance, the stored vector must not be all
    (near-)zero.
    """

    @register("mock-embedding")
    class MockEmbeddingFunction(TextEmbeddingFunction):
        """Embedding function that returns random 128-dimensional vectors."""

        def ndims(self):
            return 128

        def generate_embeddings(
            self, texts: Union[List[str], np.ndarray]
        ) -> List[List[float]]:
            """
            Generate the embeddings for the given texts
            """
            # One random vector per input text; `.tolist()` yields plain
            # Python lists of floats, hence the return annotation.
            return [np.random.randn(self.ndims()).tolist() for _ in texts]

    registry = get_registry()
    model = registry.get("mock-embedding").create()

    class LanceSchema(LanceModel):
        id: str
        # default=None makes the vector optional when constructing instances.
        vector: Vector(model.ndims()) = model.VectorField(default=None)
        text: str = model.SourceField()

    db = lancedb.connect(tmp_path)
    tbl = db.create_table("optional_vector", schema=LanceSchema)

    # add works
    expected = LanceSchema(id="id", text="text")
    tbl.add([expected])

    # The stored vector must not be all (near-)zero.
    assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()