feat: support multivector type (#2005)

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2025-01-14 06:10:40 +08:00
committed by GitHub
parent ce9506db71
commit 66cbf6b6c5
9 changed files with 255 additions and 72 deletions

View File

@@ -1741,12 +1741,14 @@ class AsyncQuery(AsyncQueryBase):
a default `limit` of 10 will be used.
Typically, a single vector is passed in as the query. However, you can also
pass in multiple vectors. This can be useful if you want to find the nearest
vectors to multiple query vectors. This is not expected to be faster than
making multiple queries concurrently; it is just a convenience method.
If multiple vectors are passed in then an additional column `query_index`
will be added to the results. This column will contain the index of the
query vector that the result is nearest to.
pass in multiple vectors. When multiple vectors are passed in, if the vector
column is with multivector type, then the vectors will be treated as a single
query. Or the vectors will be treated as multiple queries, this can be useful
if you want to find the nearest vectors to multiple query vectors.
This is not expected to be faster than making multiple queries concurrently;
it is just a convenience method. If multiple vectors are passed in then
an additional column `query_index` will be added to the results. This column
will contain the index of the query vector that the result is nearest to.
"""
if query_vector is None:
raise ValueError("query_vector can not be None")

View File

@@ -2856,6 +2856,8 @@ class AsyncTable:
async_query = async_query.with_row_id()
if query.vector:
# we need the schema to get the vector column type
# to determine whether the vectors is batch queries or not
async_query = (
async_query.nearest_to(query.vector)
.distance_type(query.metric)

View File

@@ -223,7 +223,7 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
vector_col_count = 0
for field_name in schema.names:
field = schema.field(field_name)
if pa.types.is_fixed_size_list(field.type):
if is_vector_column(field.type):
vector_col_count += 1
if vector_col_count > 1:
raise ValueError(
@@ -231,7 +231,6 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
"Please specify the vector column name "
"for vector search"
)
break
elif vector_col_count == 1:
vector_col_name = field_name
if vector_col_count == 0:
@@ -242,6 +241,29 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
return vector_col_name
def is_vector_column(data_type: pa.DataType) -> bool:
"""
Check if the column is a vector column.
Parameters
----------
data_type : pa.DataType
The data type of the column.
Returns
-------
bool: True if the column is a vector column.
"""
if pa.types.is_fixed_size_list(data_type) and (
pa.types.is_floating(data_type.value_type)
or pa.types.is_uint8(data_type.value_type)
):
return True
elif pa.types.is_list(data_type):
return is_vector_column(data_type.value_type)
return False
def infer_vector_column_name(
schema: pa.Schema,
query_type: str,

View File

@@ -68,6 +68,60 @@ async def table_struct_async(tmp_path) -> AsyncTable:
return await conn.create_table("test_struct", table)
@pytest.fixture
def multivec_table() -> lancedb.table.Table:
db = lancedb.connect("memory://")
# Generate 256 rows of data
num_rows = 256
# Generate data for each column
vector_data = [
[[i, i + 1], [i + 2, i + 3]] for i in range(num_rows)
] # Adjust to match nested structure
id_data = list(range(1, num_rows + 1))
float_field_data = [float(i) for i in range(1, num_rows + 1)]
# Create the Arrow table
df = pa.table(
{
"vector": pa.array(
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
),
"id": pa.array(id_data),
"float_field": pa.array(float_field_data),
}
)
return db.create_table("test", df)
@pytest_asyncio.fixture
async def multivec_table_async(tmp_path) -> AsyncTable:
conn = await lancedb.connect_async(
"memory://", read_consistency_interval=timedelta(seconds=0)
)
# Generate 256 rows of data
num_rows = 256
# Generate data for each column
vector_data = [
[[i, i + 1], [i + 2, i + 3]] for i in range(num_rows)
] # Adjust to match nested structure
id_data = list(range(1, num_rows + 1))
float_field_data = [float(i) for i in range(1, num_rows + 1)]
# Create the Arrow table
df = pa.table(
{
"vector": pa.array(
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
),
"id": pa.array(id_data),
"float_field": pa.array(float_field_data),
}
)
return await conn.create_table("test_async", df)
def test_cast(table):
class TestModel(LanceModel):
vector: Vector(2)
@@ -177,6 +231,62 @@ async def test_distance_range_async(table_async: AsyncTable):
assert res["_distance"].to_pylist() == [min_dist, max_dist]
def test_multivector(multivec_table: lancedb.table.Table):
# create index on multivector
multivec_table.create_index(
metric="cosine",
vector_column_name="vector",
index_type="IVF_PQ",
num_partitions=1,
num_sub_vectors=2,
)
# query with single vector
q = [1, 2]
rs = multivec_table.search(q).to_arrow()
# query with multiple vectors
q = [[1, 2], [1, 2]]
rs2 = multivec_table.search(q).to_arrow()
assert len(rs2) == len(rs)
for i in range(2):
assert rs2["_distance"][i].as_py() == rs["_distance"][i].as_py() * 2
# can't query with vector that dim not matched
with pytest.raises(Exception):
multivec_table.search([1, 2, 3]).to_arrow()
# can't query with vector list that some dim not matched
with pytest.raises(Exception):
multivec_table.search([[1, 2], [1, 2, 3]]).to_arrow()
@pytest.mark.asyncio
async def test_multivector_async(multivec_table_async: AsyncTable):
# create index on multivector
await multivec_table_async.create_index(
"vector",
config=IvfPq(distance_type="cosine", num_partitions=1, num_sub_vectors=2),
)
# query with single vector
q = [1, 2]
rs = await multivec_table_async.query().nearest_to(q).to_arrow()
# query with multiple vectors
q = [[1, 2], [1, 2]]
rs2 = await multivec_table_async.query().nearest_to(q).to_arrow()
assert len(rs2) == len(rs)
for i in range(2):
assert rs2["_distance"][i].as_py() == rs["_distance"][i].as_py() * 2
# can't query with vector that dim not matched
with pytest.raises(Exception):
await multivec_table_async.query().nearest_to([1, 2, 3]).to_arrow()
# can't query with vector list that some dim not matched
with pytest.raises(Exception):
await multivec_table_async.query().nearest_to([[1, 2], [1, 2, 3]]).to_arrow()
def test_vector_query_with_no_limit(table):
with pytest.raises(ValueError):
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
@@ -448,11 +558,13 @@ async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable):
@pytest.mark.asyncio
async def test_query_to_polars_async(table_async: AsyncTable):
schema = await table_async.schema()
num_columns = len(schema.names)
df = await table_async.query().to_polars()
assert df.shape == (2, 5)
assert df.shape == (2, num_columns)
df = await table_async.query().where("id < 0").to_polars()
assert df.shape == (0, 5)
assert df.shape == (0, num_columns)
@pytest.mark.asyncio