feat: support multivector type (#2005)

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2025-01-14 06:10:40 +08:00
committed by GitHub
parent ce9506db71
commit 66cbf6b6c5
9 changed files with 255 additions and 72 deletions

View File

@@ -68,6 +68,60 @@ async def table_struct_async(tmp_path) -> AsyncTable:
return await conn.create_table("test_struct", table)
@pytest.fixture
def multivec_table() -> lancedb.table.Table:
    """Create an in-memory table whose ``vector`` column stores multivectors.

    The table has 256 rows. Row ``i`` (counting from 0) holds the two
    2-dimensional float32 vectors ``[i, i + 1]`` and ``[i + 2, i + 3]``,
    alongside an ``id`` column counting from 1 and a ``float_field`` column
    carrying the same values as floats.
    """
    connection = lancedb.connect("memory://")
    row_count = 256

    # One multivector (a list of two 2-d vectors) per row.
    multivectors = []
    for i in range(row_count):
        multivectors.append([[i, i + 1], [i + 2, i + 3]])

    ids = list(range(1, row_count + 1))
    floats = [float(value) for value in ids]

    # Outer list is variable-length; each inner vector is fixed at size 2.
    multivector_type = pa.list_(pa.list_(pa.float32(), list_size=2))
    columns = {
        "vector": pa.array(multivectors, type=multivector_type),
        "id": pa.array(ids),
        "float_field": pa.array(floats),
    }
    arrow_table = pa.table(columns)
    return connection.create_table("test", arrow_table)
@pytest_asyncio.fixture
async def multivec_table_async(tmp_path) -> AsyncTable:
    """Create an in-memory async table whose ``vector`` column stores multivectors.

    The table has 256 rows. Row ``i`` (counting from 0) holds the two
    2-dimensional float32 vectors ``[i, i + 1]`` and ``[i + 2, i + 3]``,
    alongside an ``id`` column counting from 1 and a ``float_field`` column
    carrying the same values as floats. ``tmp_path`` is accepted but not
    used by this fixture.
    """
    conn = await lancedb.connect_async(
        "memory://",
        read_consistency_interval=timedelta(seconds=0),
    )
    row_count = 256

    # One multivector (a list of two 2-d vectors) per row.
    multivectors = []
    for i in range(row_count):
        multivectors.append([[i, i + 1], [i + 2, i + 3]])

    ids = list(range(1, row_count + 1))
    floats = [float(value) for value in ids]

    # Outer list is variable-length; each inner vector is fixed at size 2.
    multivector_type = pa.list_(pa.list_(pa.float32(), list_size=2))
    columns = {
        "vector": pa.array(multivectors, type=multivector_type),
        "id": pa.array(ids),
        "float_field": pa.array(floats),
    }
    arrow_table = pa.table(columns)
    return await conn.create_table("test_async", arrow_table)
def test_cast(table):
class TestModel(LanceModel):
vector: Vector(2)
@@ -177,6 +231,62 @@ async def test_distance_range_async(table_async: AsyncTable):
assert res["_distance"].to_pylist() == [min_dist, max_dist]
def test_multivector(multivec_table: lancedb.table.Table):
    """Exercise vector search against an indexed multivector column.

    Builds an IVF_PQ index, then checks that querying with the same vector
    supplied twice returns as many rows as the single-vector query, with the
    first two distances exactly doubled. Finally checks that queries whose
    dimension does not match the column are rejected.
    """
    # Index the multivector column.
    index_options = {
        "metric": "cosine",
        "vector_column_name": "vector",
        "index_type": "IVF_PQ",
        "num_partitions": 1,
        "num_sub_vectors": 2,
    }
    multivec_table.create_index(**index_options)

    # Baseline: a single query vector.
    single_hits = multivec_table.search([1, 2]).to_arrow()

    # Same vector supplied twice as a multivector query.
    repeated_hits = multivec_table.search([[1, 2], [1, 2]]).to_arrow()

    assert len(repeated_hits) == len(single_hits)
    single_distances = single_hits["_distance"]
    repeated_distances = repeated_hits["_distance"]
    for row in (0, 1):
        assert repeated_distances[row].as_py() == single_distances[row].as_py() * 2

    # Queries must be rejected when the dimension does not match the column:
    # first a single vector of the wrong size, then a list of vectors in
    # which only one has the wrong size.
    mismatched_queries = ([1, 2, 3], [[1, 2], [1, 2, 3]])
    for bad_query in mismatched_queries:
        with pytest.raises(Exception):
            multivec_table.search(bad_query).to_arrow()
@pytest.mark.asyncio
async def test_multivector_async(multivec_table_async: AsyncTable):
    """Exercise async nearest-neighbour queries on an indexed multivector column.

    Builds an IvfPq index, then checks that querying with the same vector
    supplied twice returns as many rows as the single-vector query, with the
    first two distances exactly doubled. Finally checks that queries whose
    dimension does not match the column are rejected.
    """
    # Index the multivector column.
    pq_config = IvfPq(distance_type="cosine", num_partitions=1, num_sub_vectors=2)
    await multivec_table_async.create_index("vector", config=pq_config)

    # Baseline: a single query vector.
    single_hits = await multivec_table_async.query().nearest_to([1, 2]).to_arrow()

    # Same vector supplied twice as a multivector query.
    repeated_hits = (
        await multivec_table_async.query().nearest_to([[1, 2], [1, 2]]).to_arrow()
    )

    assert len(repeated_hits) == len(single_hits)
    single_distances = single_hits["_distance"]
    repeated_distances = repeated_hits["_distance"]
    for row in (0, 1):
        assert repeated_distances[row].as_py() == single_distances[row].as_py() * 2

    # Queries must be rejected when the dimension does not match the column:
    # first a single vector of the wrong size, then a list of vectors in
    # which only one has the wrong size.
    mismatched_queries = ([1, 2, 3], [[1, 2], [1, 2, 3]])
    for bad_query in mismatched_queries:
        with pytest.raises(Exception):
            await multivec_table_async.query().nearest_to(bad_query).to_arrow()
def test_vector_query_with_no_limit(table):
with pytest.raises(ValueError):
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
@@ -448,11 +558,13 @@ async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable):
@pytest.mark.asyncio
async def test_query_to_polars_async(table_async: AsyncTable):
    """Check the shape of polars DataFrames produced by async queries.

    The expected column count is read from the table schema instead of being
    hard-coded, so the test keeps passing when the fixture's schema changes.
    An unfiltered query is expected to yield 2 rows; the ``id < 0`` filter is
    expected to yield none while keeping every column.
    """
    schema = await table_async.schema()
    num_columns = len(schema.names)

    df = await table_async.query().to_polars()
    assert df.shape == (2, num_columns)

    # A filter that matches nothing still returns all columns.
    df = await table_async.query().where("id < 0").to_polars()
    assert df.shape == (0, num_columns)
@pytest.mark.asyncio