mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
docs: claim LanceDB supports float16/float32/float64 for multivector (#2040)
This commit is contained in:
@@ -149,6 +149,7 @@ You can index on a column with multivector type and search on it, the query can
|
||||
where `sim` is the similarity function (e.g. cosine).
|
||||
|
||||
For now, only `cosine` metric is supported for multivector search.
|
||||
The vector value type can be `float16`, `float32` or `float64`.
|
||||
|
||||
=== "Python"
|
||||
|
||||
|
||||
@@ -1754,7 +1754,7 @@ class AsyncQuery(AsyncQueryBase):
|
||||
raise ValueError("query_vector can not be None")
|
||||
|
||||
if (
|
||||
isinstance(query_vector, list)
|
||||
isinstance(query_vector, (list, np.ndarray, pa.Array))
|
||||
and len(query_vector) > 0
|
||||
and isinstance(query_vector[0], (list, np.ndarray, pa.Array))
|
||||
):
|
||||
|
||||
@@ -17,6 +17,7 @@ def test_multivector():
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
# float16, float32, and float64 are supported
|
||||
pa.field("vector", pa.list_(pa.list_(pa.float32(), 256))),
|
||||
]
|
||||
)
|
||||
@@ -33,7 +34,7 @@ def test_multivector():
|
||||
tbl.create_index(metric="cosine")
|
||||
|
||||
# query with single vector
|
||||
query = np.random.random(256)
|
||||
query = np.random.random(256).astype(np.float16)
|
||||
tbl.search(query).to_arrow()
|
||||
|
||||
# query with multiple vectors
|
||||
@@ -51,6 +52,7 @@ async def test_multivector_async():
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("id", pa.int64()),
|
||||
# float16, float32, and float64 are supported
|
||||
pa.field("vector", pa.list_(pa.list_(pa.float32(), 256))),
|
||||
]
|
||||
)
|
||||
@@ -72,6 +74,7 @@ async def test_multivector_async():
|
||||
|
||||
# query with multiple vectors
|
||||
query = np.random.random(size=(2, 256))
|
||||
await tbl.query().nearest_to(query).to_arrow()
|
||||
|
||||
# --8<-- [end:async_multivector]
|
||||
await db.drop_table("my_table")
|
||||
|
||||
@@ -69,7 +69,7 @@ async def table_struct_async(tmp_path) -> AsyncTable:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multivec_table() -> lancedb.table.Table:
|
||||
def multivec_table(vector_value_type=pa.float32()) -> lancedb.table.Table:
|
||||
db = lancedb.connect("memory://")
|
||||
# Generate 256 rows of data
|
||||
num_rows = 256
|
||||
@@ -85,7 +85,7 @@ def multivec_table() -> lancedb.table.Table:
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
|
||||
vector_data, type=pa.list_(pa.list_(vector_value_type, list_size=2))
|
||||
),
|
||||
"id": pa.array(id_data),
|
||||
"float_field": pa.array(float_field_data),
|
||||
@@ -95,7 +95,7 @@ def multivec_table() -> lancedb.table.Table:
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def multivec_table_async(tmp_path) -> AsyncTable:
|
||||
async def multivec_table_async(vector_value_type=pa.float32()) -> AsyncTable:
|
||||
conn = await lancedb.connect_async(
|
||||
"memory://", read_consistency_interval=timedelta(seconds=0)
|
||||
)
|
||||
@@ -113,7 +113,7 @@ async def multivec_table_async(tmp_path) -> AsyncTable:
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
|
||||
vector_data, type=pa.list_(pa.list_(vector_value_type, list_size=2))
|
||||
),
|
||||
"id": pa.array(id_data),
|
||||
"float_field": pa.array(float_field_data),
|
||||
@@ -231,6 +231,9 @@ async def test_distance_range_async(table_async: AsyncTable):
|
||||
assert res["_distance"].to_pylist() == [min_dist, max_dist]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"multivec_table", [pa.float16(), pa.float32(), pa.float64()], indirect=True
|
||||
)
|
||||
def test_multivector(multivec_table: lancedb.table.Table):
|
||||
# create index on multivector
|
||||
multivec_table.create_index(
|
||||
@@ -261,6 +264,9 @@ def test_multivector(multivec_table: lancedb.table.Table):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"multivec_table_async", [pa.float16(), pa.float32(), pa.float64()], indirect=True
|
||||
)
|
||||
async def test_multivector_async(multivec_table_async: AsyncTable):
|
||||
# create index on multivector
|
||||
await multivec_table_async.create_index(
|
||||
|
||||
Reference in New Issue
Block a user