mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
feat: support multivector type (#2005)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -68,6 +68,60 @@ async def table_struct_async(tmp_path) -> AsyncTable:
|
||||
return await conn.create_table("test_struct", table)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multivec_table() -> lancedb.table.Table:
|
||||
db = lancedb.connect("memory://")
|
||||
# Generate 256 rows of data
|
||||
num_rows = 256
|
||||
|
||||
# Generate data for each column
|
||||
vector_data = [
|
||||
[[i, i + 1], [i + 2, i + 3]] for i in range(num_rows)
|
||||
] # Adjust to match nested structure
|
||||
id_data = list(range(1, num_rows + 1))
|
||||
float_field_data = [float(i) for i in range(1, num_rows + 1)]
|
||||
|
||||
# Create the Arrow table
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
|
||||
),
|
||||
"id": pa.array(id_data),
|
||||
"float_field": pa.array(float_field_data),
|
||||
}
|
||||
)
|
||||
return db.create_table("test", df)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def multivec_table_async(tmp_path) -> AsyncTable:
|
||||
conn = await lancedb.connect_async(
|
||||
"memory://", read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
# Generate 256 rows of data
|
||||
num_rows = 256
|
||||
|
||||
# Generate data for each column
|
||||
vector_data = [
|
||||
[[i, i + 1], [i + 2, i + 3]] for i in range(num_rows)
|
||||
] # Adjust to match nested structure
|
||||
id_data = list(range(1, num_rows + 1))
|
||||
float_field_data = [float(i) for i in range(1, num_rows + 1)]
|
||||
|
||||
# Create the Arrow table
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
|
||||
),
|
||||
"id": pa.array(id_data),
|
||||
"float_field": pa.array(float_field_data),
|
||||
}
|
||||
)
|
||||
return await conn.create_table("test_async", df)
|
||||
|
||||
|
||||
def test_cast(table):
|
||||
class TestModel(LanceModel):
|
||||
vector: Vector(2)
|
||||
@@ -177,6 +231,62 @@ async def test_distance_range_async(table_async: AsyncTable):
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
def test_multivector(multivec_table: lancedb.table.Table):
|
||||
# create index on multivector
|
||||
multivec_table.create_index(
|
||||
metric="cosine",
|
||||
vector_column_name="vector",
|
||||
index_type="IVF_PQ",
|
||||
num_partitions=1,
|
||||
num_sub_vectors=2,
|
||||
)
|
||||
|
||||
# query with single vector
|
||||
q = [1, 2]
|
||||
rs = multivec_table.search(q).to_arrow()
|
||||
|
||||
# query with multiple vectors
|
||||
q = [[1, 2], [1, 2]]
|
||||
rs2 = multivec_table.search(q).to_arrow()
|
||||
assert len(rs2) == len(rs)
|
||||
for i in range(2):
|
||||
assert rs2["_distance"][i].as_py() == rs["_distance"][i].as_py() * 2
|
||||
|
||||
# can't query with vector that dim not matched
|
||||
with pytest.raises(Exception):
|
||||
multivec_table.search([1, 2, 3]).to_arrow()
|
||||
# can't query with vector list that some dim not matched
|
||||
with pytest.raises(Exception):
|
||||
multivec_table.search([[1, 2], [1, 2, 3]]).to_arrow()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multivector_async(multivec_table_async: AsyncTable):
|
||||
# create index on multivector
|
||||
await multivec_table_async.create_index(
|
||||
"vector",
|
||||
config=IvfPq(distance_type="cosine", num_partitions=1, num_sub_vectors=2),
|
||||
)
|
||||
|
||||
# query with single vector
|
||||
q = [1, 2]
|
||||
rs = await multivec_table_async.query().nearest_to(q).to_arrow()
|
||||
|
||||
# query with multiple vectors
|
||||
q = [[1, 2], [1, 2]]
|
||||
rs2 = await multivec_table_async.query().nearest_to(q).to_arrow()
|
||||
assert len(rs2) == len(rs)
|
||||
for i in range(2):
|
||||
assert rs2["_distance"][i].as_py() == rs["_distance"][i].as_py() * 2
|
||||
|
||||
# can't query with vector that dim not matched
|
||||
with pytest.raises(Exception):
|
||||
await multivec_table_async.query().nearest_to([1, 2, 3]).to_arrow()
|
||||
# can't query with vector list that some dim not matched
|
||||
with pytest.raises(Exception):
|
||||
await multivec_table_async.query().nearest_to([[1, 2], [1, 2, 3]]).to_arrow()
|
||||
|
||||
|
||||
def test_vector_query_with_no_limit(table):
|
||||
with pytest.raises(ValueError):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
|
||||
@@ -448,11 +558,13 @@ async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_polars_async(table_async: AsyncTable):
|
||||
schema = await table_async.schema()
|
||||
num_columns = len(schema.names)
|
||||
df = await table_async.query().to_polars()
|
||||
assert df.shape == (2, 5)
|
||||
assert df.shape == (2, num_columns)
|
||||
|
||||
df = await table_async.query().where("id < 0").to_polars()
|
||||
assert df.shape == (0, 5)
|
||||
assert df.shape == (0, num_columns)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user