mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 22:09:58 +00:00
This PR adds ColPali support with ColPaliEmbeddings class (tagged "colpali") using ColQwen2.5 for multi-vector text/image embeddings. Also added MultiVector Pydantic type to handle the vector lists. I've added some integration test for the embedding model and some unit test for the new Pydantic type. Could be a template for other ColPali variants as well. or until transformers🤗 starts supporting it. Still `TODO`: - [ ] Documentation - [ ] Add an example _Could also allow Image as query, but didn't work well when testing it._ [ColPali-Engine](https://github.com/illuin-tech/colpali) version: 0.3.9.dev17+g3faee24 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced support for ColPali-based multimodal multi-vector embeddings for both text and images. - Added a new embedding class for generating multi-vector embeddings, configurable for various model and processing options. - Added a new Pydantic type for multi-vector embeddings, supporting validation and schema generation for lists of fixed-dimension vectors. - **Bug Fixes** - Ensured proper asynchronous index creation in query tests for improved reliability. - **Tests** - Added integration tests for ColPali embeddings, including text-to-image search and validation of multi-vector fields. - Added comprehensive tests for the new multi-vector Pydantic type, covering schema, validation, and default value behavior. - **Chores** - Updated optional dependencies to include the ColPali engine. - Added utility to check for availability of flash attention support. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
415 lines
12 KiB
Python
415 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
import json
|
|
import sys
|
|
from datetime import date, datetime
|
|
from typing import List, Optional, Tuple
|
|
|
|
import pyarrow as pa
|
|
import pydantic
|
|
import pytest
|
|
from lancedb.pydantic import (
|
|
PYDANTIC_VERSION,
|
|
LanceModel,
|
|
Vector,
|
|
pydantic_to_schema,
|
|
MultiVector,
|
|
)
|
|
from pydantic import BaseModel
|
|
from pydantic import Field
|
|
|
|
|
|
@pytest.mark.skipif(
|
|
sys.version_info < (3, 9),
|
|
reason="using native type alias requires python3.9 or higher",
|
|
)
|
|
def test_pydantic_to_arrow():
|
|
class StructModel(pydantic.BaseModel):
|
|
a: str
|
|
b: Optional[float]
|
|
|
|
class TestModel(pydantic.BaseModel):
|
|
id: int
|
|
s: str
|
|
vec: list[float]
|
|
li: list[int]
|
|
lili: list[list[float]]
|
|
litu: list[tuple[float, float]]
|
|
opt: Optional[str] = None
|
|
st: StructModel
|
|
dt: date
|
|
dtt: datetime
|
|
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
|
|
# d: dict
|
|
|
|
# TODO: test we can actually convert the model into data.
|
|
# m = TestModel(
|
|
# id=1,
|
|
# s="hello",
|
|
# vec=[1.0, 2.0, 3.0],
|
|
# li=[2, 3, 4],
|
|
# lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
|
|
# litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
|
|
# st=StructModel(a="a", b=1.0),
|
|
# dt=date.today(),
|
|
# dtt=datetime.now(),
|
|
# dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
|
|
# )
|
|
|
|
schema = pydantic_to_schema(TestModel)
|
|
|
|
expect_schema = pa.schema(
|
|
[
|
|
pa.field("id", pa.int64(), False),
|
|
pa.field("s", pa.utf8(), False),
|
|
pa.field("vec", pa.list_(pa.float64()), False),
|
|
pa.field("li", pa.list_(pa.int64()), False),
|
|
pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
|
|
pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
|
|
pa.field("opt", pa.utf8(), True),
|
|
pa.field(
|
|
"st",
|
|
pa.struct(
|
|
[pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
|
|
),
|
|
False,
|
|
),
|
|
pa.field("dt", pa.date32(), False),
|
|
pa.field("dtt", pa.timestamp("us"), False),
|
|
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
|
|
]
|
|
)
|
|
assert schema == expect_schema
|
|
|
|
|
|
@pytest.mark.skipif(
|
|
sys.version_info < (3, 10),
|
|
reason="using | type syntax requires python3.10 or higher",
|
|
)
|
|
def test_optional_types_py310():
|
|
class TestModel(pydantic.BaseModel):
|
|
a: str | None
|
|
b: None | str
|
|
c: Optional[str]
|
|
|
|
schema = pydantic_to_schema(TestModel)
|
|
|
|
expect_schema = pa.schema(
|
|
[
|
|
pa.field("a", pa.utf8(), True),
|
|
pa.field("b", pa.utf8(), True),
|
|
pa.field("c", pa.utf8(), True),
|
|
]
|
|
)
|
|
assert schema == expect_schema
|
|
|
|
|
|
@pytest.mark.skipif(
|
|
sys.version_info > (3, 8),
|
|
reason="using native type alias requires python3.9 or higher",
|
|
)
|
|
def test_pydantic_to_arrow_py38():
|
|
class StructModel(pydantic.BaseModel):
|
|
a: str
|
|
b: Optional[float]
|
|
|
|
class TestModel(pydantic.BaseModel):
|
|
id: int
|
|
s: str
|
|
vec: List[float]
|
|
li: List[int]
|
|
lili: List[List[float]]
|
|
litu: List[Tuple[float, float]]
|
|
opt: Optional[str] = None
|
|
st: StructModel
|
|
dt: date
|
|
dtt: datetime
|
|
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
|
|
# d: dict
|
|
|
|
# TODO: test we can actually convert the model to Arrow data.
|
|
# m = TestModel(
|
|
# id=1,
|
|
# s="hello",
|
|
# vec=[1.0, 2.0, 3.0],
|
|
# li=[2, 3, 4],
|
|
# lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
|
|
# litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
|
|
# st=StructModel(a="a", b=1.0),
|
|
# dt=date.today(),
|
|
# dtt=datetime.now(),
|
|
# dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
|
|
# )
|
|
|
|
schema = pydantic_to_schema(TestModel)
|
|
|
|
expect_schema = pa.schema(
|
|
[
|
|
pa.field("id", pa.int64(), False),
|
|
pa.field("s", pa.utf8(), False),
|
|
pa.field("vec", pa.list_(pa.float64()), False),
|
|
pa.field("li", pa.list_(pa.int64()), False),
|
|
pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
|
|
pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
|
|
pa.field("opt", pa.utf8(), True),
|
|
pa.field(
|
|
"st",
|
|
pa.struct(
|
|
[pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
|
|
),
|
|
False,
|
|
),
|
|
pa.field("dt", pa.date32(), False),
|
|
pa.field("dtt", pa.timestamp("us"), False),
|
|
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
|
|
]
|
|
)
|
|
assert schema == expect_schema
|
|
|
|
|
|
def test_nullable_vector():
|
|
class NullableModel(pydantic.BaseModel):
|
|
vec: Vector(16, nullable=False)
|
|
|
|
schema = pydantic_to_schema(NullableModel)
|
|
assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), False)])
|
|
|
|
class DefaultModel(pydantic.BaseModel):
|
|
vec: Vector(16)
|
|
|
|
schema = pydantic_to_schema(DefaultModel)
|
|
assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)])
|
|
|
|
class NotNullableModel(pydantic.BaseModel):
|
|
vec: Vector(16)
|
|
|
|
schema = pydantic_to_schema(NotNullableModel)
|
|
assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)])
|
|
|
|
|
|
def test_fixed_size_list_field():
|
|
class TestModel(pydantic.BaseModel):
|
|
vec: Vector(16)
|
|
li: List[int]
|
|
|
|
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
|
|
if PYDANTIC_VERSION.major >= 2:
|
|
assert json.loads(data.model_dump_json()) == {
|
|
"vec": list(range(16)),
|
|
"li": [1, 2, 3],
|
|
}
|
|
else:
|
|
assert data.dict() == {
|
|
"vec": list(range(16)),
|
|
"li": [1, 2, 3],
|
|
}
|
|
|
|
schema = pydantic_to_schema(TestModel)
|
|
assert schema == pa.schema(
|
|
[
|
|
pa.field("vec", pa.list_(pa.float32(), 16)),
|
|
pa.field("li", pa.list_(pa.int64()), False),
|
|
]
|
|
)
|
|
|
|
if PYDANTIC_VERSION.major >= 2:
|
|
json_schema = TestModel.model_json_schema()
|
|
else:
|
|
json_schema = TestModel.schema()
|
|
|
|
assert json_schema == {
|
|
"properties": {
|
|
"vec": {
|
|
"items": {"type": "number"},
|
|
"maxItems": 16,
|
|
"minItems": 16,
|
|
"title": "Vec",
|
|
"type": "array",
|
|
},
|
|
"li": {"items": {"type": "integer"}, "title": "Li", "type": "array"},
|
|
},
|
|
"required": ["vec", "li"],
|
|
"title": "TestModel",
|
|
"type": "object",
|
|
}
|
|
|
|
|
|
def test_fixed_size_list_validation():
|
|
class TestModel(pydantic.BaseModel):
|
|
vec: Vector(8)
|
|
|
|
with pytest.raises(pydantic.ValidationError):
|
|
TestModel(vec=range(9))
|
|
|
|
with pytest.raises(pydantic.ValidationError):
|
|
TestModel(vec=range(7))
|
|
|
|
TestModel(vec=range(8))
|
|
|
|
|
|
def test_lance_model():
|
|
class TestModel(LanceModel):
|
|
vector: Vector(16) = Field(default=[0.0] * 16)
|
|
li: List[int] = Field(default=[1, 2, 3])
|
|
|
|
schema = pydantic_to_schema(TestModel)
|
|
assert schema == TestModel.to_arrow_schema()
|
|
assert TestModel.field_names() == ["vector", "li"]
|
|
|
|
t = TestModel()
|
|
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])
|
|
|
|
|
|
def test_optional_nested_model():
|
|
class WAMedia(BaseModel):
|
|
url: str
|
|
mimetype: str
|
|
filename: Optional[str]
|
|
error: Optional[str]
|
|
data: bytes
|
|
|
|
class WALocation(BaseModel):
|
|
description: Optional[str]
|
|
latitude: str
|
|
longitude: str
|
|
|
|
class ReplyToMessage(BaseModel):
|
|
id: str
|
|
participant: str
|
|
body: str
|
|
|
|
class Message(BaseModel):
|
|
id: str
|
|
timestamp: int
|
|
from_: str
|
|
fromMe: bool
|
|
to: str
|
|
body: str
|
|
hasMedia: Optional[bool]
|
|
media: WAMedia
|
|
mediaUrl: Optional[str]
|
|
ack: Optional[int]
|
|
ackName: Optional[str]
|
|
author: Optional[str]
|
|
location: Optional[WALocation]
|
|
vCards: Optional[List[str]]
|
|
replyTo: Optional[ReplyToMessage]
|
|
|
|
class AnyEvent(LanceModel):
|
|
id: str
|
|
session: str
|
|
metadata: Optional[str] = None
|
|
engine: str
|
|
event: str
|
|
|
|
class MessageEvent(AnyEvent):
|
|
payload: Message
|
|
|
|
schema = pydantic_to_schema(MessageEvent)
|
|
|
|
payload = schema.field("payload")
|
|
assert payload.type == pa.struct(
|
|
[
|
|
pa.field("id", pa.utf8(), False),
|
|
pa.field("timestamp", pa.int64(), False),
|
|
pa.field("from_", pa.utf8(), False),
|
|
pa.field("fromMe", pa.bool_(), False),
|
|
pa.field("to", pa.utf8(), False),
|
|
pa.field("body", pa.utf8(), False),
|
|
pa.field("hasMedia", pa.bool_(), True),
|
|
pa.field(
|
|
"media",
|
|
pa.struct(
|
|
[
|
|
pa.field("url", pa.utf8(), False),
|
|
pa.field("mimetype", pa.utf8(), False),
|
|
pa.field("filename", pa.utf8(), True),
|
|
pa.field("error", pa.utf8(), True),
|
|
pa.field("data", pa.binary(), False),
|
|
]
|
|
),
|
|
False,
|
|
),
|
|
pa.field("mediaUrl", pa.utf8(), True),
|
|
pa.field("ack", pa.int64(), True),
|
|
pa.field("ackName", pa.utf8(), True),
|
|
pa.field("author", pa.utf8(), True),
|
|
pa.field(
|
|
"location",
|
|
pa.struct(
|
|
[
|
|
pa.field("description", pa.utf8(), True),
|
|
pa.field("latitude", pa.utf8(), False),
|
|
pa.field("longitude", pa.utf8(), False),
|
|
]
|
|
),
|
|
True, # Optional
|
|
),
|
|
pa.field("vCards", pa.list_(pa.utf8()), True),
|
|
pa.field(
|
|
"replyTo",
|
|
pa.struct(
|
|
[
|
|
pa.field("id", pa.utf8(), False),
|
|
pa.field("participant", pa.utf8(), False),
|
|
pa.field("body", pa.utf8(), False),
|
|
]
|
|
),
|
|
True,
|
|
),
|
|
]
|
|
)
|
|
|
|
|
|
def test_multi_vector():
|
|
class TestModel(pydantic.BaseModel):
|
|
vec: MultiVector(8)
|
|
|
|
schema = pydantic_to_schema(TestModel)
|
|
assert schema == pa.schema(
|
|
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 8)), True)]
|
|
)
|
|
|
|
with pytest.raises(pydantic.ValidationError):
|
|
TestModel(vec=[[1.0] * 7])
|
|
|
|
with pytest.raises(pydantic.ValidationError):
|
|
TestModel(vec=[[1.0] * 9])
|
|
|
|
TestModel(vec=[[1.0] * 8])
|
|
TestModel(vec=[[1.0] * 8, [2.0] * 8])
|
|
|
|
TestModel(vec=[])
|
|
|
|
|
|
def test_multi_vector_nullable():
|
|
class NullableModel(pydantic.BaseModel):
|
|
vec: MultiVector(16, nullable=False)
|
|
|
|
schema = pydantic_to_schema(NullableModel)
|
|
assert schema == pa.schema(
|
|
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), False)]
|
|
)
|
|
|
|
class DefaultModel(pydantic.BaseModel):
|
|
vec: MultiVector(16)
|
|
|
|
schema = pydantic_to_schema(DefaultModel)
|
|
assert schema == pa.schema(
|
|
[pa.field("vec", pa.list_(pa.list_(pa.float32(), 16)), True)]
|
|
)
|
|
|
|
|
|
def test_multi_vector_in_lance_model():
|
|
class TestModel(LanceModel):
|
|
id: int
|
|
vectors: MultiVector(16) = Field(default=[[0.0] * 16])
|
|
|
|
schema = pydantic_to_schema(TestModel)
|
|
assert schema == TestModel.to_arrow_schema()
|
|
assert TestModel.field_names() == ["id", "vectors"]
|
|
|
|
t = TestModel(id=1)
|
|
assert t.vectors == [[0.0] * 16]
|