mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
feat: support multivector type (#2005)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
@@ -1741,12 +1741,14 @@ class AsyncQuery(AsyncQueryBase):
|
||||
a default `limit` of 10 will be used.
|
||||
|
||||
Typically, a single vector is passed in as the query. However, you can also
|
||||
pass in multiple vectors. This can be useful if you want to find the nearest
|
||||
vectors to multiple query vectors. This is not expected to be faster than
|
||||
making multiple queries concurrently; it is just a convenience method.
|
||||
If multiple vectors are passed in then an additional column `query_index`
|
||||
will be added to the results. This column will contain the index of the
|
||||
query vector that the result is nearest to.
|
||||
pass in multiple vectors. When multiple vectors are passed in, if the vector
|
||||
column is with multivector type, then the vectors will be treated as a single
|
||||
query. Or the vectors will be treated as multiple queries, this can be useful
|
||||
if you want to find the nearest vectors to multiple query vectors.
|
||||
This is not expected to be faster than making multiple queries concurrently;
|
||||
it is just a convenience method. If multiple vectors are passed in then
|
||||
an additional column `query_index` will be added to the results. This column
|
||||
will contain the index of the query vector that the result is nearest to.
|
||||
"""
|
||||
if query_vector is None:
|
||||
raise ValueError("query_vector can not be None")
|
||||
|
||||
@@ -2856,6 +2856,8 @@ class AsyncTable:
|
||||
async_query = async_query.with_row_id()
|
||||
|
||||
if query.vector:
|
||||
# we need the schema to get the vector column type
|
||||
# to determine whether the vectors is batch queries or not
|
||||
async_query = (
|
||||
async_query.nearest_to(query.vector)
|
||||
.distance_type(query.metric)
|
||||
|
||||
@@ -223,7 +223,7 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
vector_col_count = 0
|
||||
for field_name in schema.names:
|
||||
field = schema.field(field_name)
|
||||
if pa.types.is_fixed_size_list(field.type):
|
||||
if is_vector_column(field.type):
|
||||
vector_col_count += 1
|
||||
if vector_col_count > 1:
|
||||
raise ValueError(
|
||||
@@ -231,7 +231,6 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
"Please specify the vector column name "
|
||||
"for vector search"
|
||||
)
|
||||
break
|
||||
elif vector_col_count == 1:
|
||||
vector_col_name = field_name
|
||||
if vector_col_count == 0:
|
||||
@@ -242,6 +241,29 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
return vector_col_name
|
||||
|
||||
|
||||
def is_vector_column(data_type: pa.DataType) -> bool:
|
||||
"""
|
||||
Check if the column is a vector column.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_type : pa.DataType
|
||||
The data type of the column.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool: True if the column is a vector column.
|
||||
"""
|
||||
if pa.types.is_fixed_size_list(data_type) and (
|
||||
pa.types.is_floating(data_type.value_type)
|
||||
or pa.types.is_uint8(data_type.value_type)
|
||||
):
|
||||
return True
|
||||
elif pa.types.is_list(data_type):
|
||||
return is_vector_column(data_type.value_type)
|
||||
return False
|
||||
|
||||
|
||||
def infer_vector_column_name(
|
||||
schema: pa.Schema,
|
||||
query_type: str,
|
||||
|
||||
@@ -68,6 +68,60 @@ async def table_struct_async(tmp_path) -> AsyncTable:
|
||||
return await conn.create_table("test_struct", table)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multivec_table() -> lancedb.table.Table:
|
||||
db = lancedb.connect("memory://")
|
||||
# Generate 256 rows of data
|
||||
num_rows = 256
|
||||
|
||||
# Generate data for each column
|
||||
vector_data = [
|
||||
[[i, i + 1], [i + 2, i + 3]] for i in range(num_rows)
|
||||
] # Adjust to match nested structure
|
||||
id_data = list(range(1, num_rows + 1))
|
||||
float_field_data = [float(i) for i in range(1, num_rows + 1)]
|
||||
|
||||
# Create the Arrow table
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
|
||||
),
|
||||
"id": pa.array(id_data),
|
||||
"float_field": pa.array(float_field_data),
|
||||
}
|
||||
)
|
||||
return db.create_table("test", df)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def multivec_table_async(tmp_path) -> AsyncTable:
|
||||
conn = await lancedb.connect_async(
|
||||
"memory://", read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
# Generate 256 rows of data
|
||||
num_rows = 256
|
||||
|
||||
# Generate data for each column
|
||||
vector_data = [
|
||||
[[i, i + 1], [i + 2, i + 3]] for i in range(num_rows)
|
||||
] # Adjust to match nested structure
|
||||
id_data = list(range(1, num_rows + 1))
|
||||
float_field_data = [float(i) for i in range(1, num_rows + 1)]
|
||||
|
||||
# Create the Arrow table
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
|
||||
),
|
||||
"id": pa.array(id_data),
|
||||
"float_field": pa.array(float_field_data),
|
||||
}
|
||||
)
|
||||
return await conn.create_table("test_async", df)
|
||||
|
||||
|
||||
def test_cast(table):
|
||||
class TestModel(LanceModel):
|
||||
vector: Vector(2)
|
||||
@@ -177,6 +231,62 @@ async def test_distance_range_async(table_async: AsyncTable):
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
def test_multivector(multivec_table: lancedb.table.Table):
|
||||
# create index on multivector
|
||||
multivec_table.create_index(
|
||||
metric="cosine",
|
||||
vector_column_name="vector",
|
||||
index_type="IVF_PQ",
|
||||
num_partitions=1,
|
||||
num_sub_vectors=2,
|
||||
)
|
||||
|
||||
# query with single vector
|
||||
q = [1, 2]
|
||||
rs = multivec_table.search(q).to_arrow()
|
||||
|
||||
# query with multiple vectors
|
||||
q = [[1, 2], [1, 2]]
|
||||
rs2 = multivec_table.search(q).to_arrow()
|
||||
assert len(rs2) == len(rs)
|
||||
for i in range(2):
|
||||
assert rs2["_distance"][i].as_py() == rs["_distance"][i].as_py() * 2
|
||||
|
||||
# can't query with vector that dim not matched
|
||||
with pytest.raises(Exception):
|
||||
multivec_table.search([1, 2, 3]).to_arrow()
|
||||
# can't query with vector list that some dim not matched
|
||||
with pytest.raises(Exception):
|
||||
multivec_table.search([[1, 2], [1, 2, 3]]).to_arrow()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multivector_async(multivec_table_async: AsyncTable):
|
||||
# create index on multivector
|
||||
await multivec_table_async.create_index(
|
||||
"vector",
|
||||
config=IvfPq(distance_type="cosine", num_partitions=1, num_sub_vectors=2),
|
||||
)
|
||||
|
||||
# query with single vector
|
||||
q = [1, 2]
|
||||
rs = await multivec_table_async.query().nearest_to(q).to_arrow()
|
||||
|
||||
# query with multiple vectors
|
||||
q = [[1, 2], [1, 2]]
|
||||
rs2 = await multivec_table_async.query().nearest_to(q).to_arrow()
|
||||
assert len(rs2) == len(rs)
|
||||
for i in range(2):
|
||||
assert rs2["_distance"][i].as_py() == rs["_distance"][i].as_py() * 2
|
||||
|
||||
# can't query with vector that dim not matched
|
||||
with pytest.raises(Exception):
|
||||
await multivec_table_async.query().nearest_to([1, 2, 3]).to_arrow()
|
||||
# can't query with vector list that some dim not matched
|
||||
with pytest.raises(Exception):
|
||||
await multivec_table_async.query().nearest_to([[1, 2], [1, 2, 3]]).to_arrow()
|
||||
|
||||
|
||||
def test_vector_query_with_no_limit(table):
|
||||
with pytest.raises(ValueError):
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(0).select(
|
||||
@@ -448,11 +558,13 @@ async def test_query_to_pandas_flatten_async(table_struct_async: AsyncTable):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_to_polars_async(table_async: AsyncTable):
|
||||
schema = await table_async.schema()
|
||||
num_columns = len(schema.names)
|
||||
df = await table_async.query().to_polars()
|
||||
assert df.shape == (2, 5)
|
||||
assert df.shape == (2, num_columns)
|
||||
|
||||
df = await table_async.query().where("id < 0").to_polars()
|
||||
assert df.shape == (0, 5)
|
||||
assert df.shape == (0, num_columns)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user