docs: claim LanceDB supports float16/float32/float64 for multivector (#2040)

This commit is contained in:
BubbleCal
2025-01-21 07:04:15 +08:00
committed by GitHub
parent f059372137
commit 214d0debf5
4 changed files with 16 additions and 6 deletions

View File

@@ -1754,7 +1754,7 @@ class AsyncQuery(AsyncQueryBase):
raise ValueError("query_vector can not be None")
if (
isinstance(query_vector, list)
isinstance(query_vector, (list, np.ndarray, pa.Array))
and len(query_vector) > 0
and isinstance(query_vector[0], (list, np.ndarray, pa.Array))
):

View File

@@ -17,6 +17,7 @@ def test_multivector():
schema = pa.schema(
[
pa.field("id", pa.int64()),
# float16, float32, and float64 are supported
pa.field("vector", pa.list_(pa.list_(pa.float32(), 256))),
]
)
@@ -33,7 +34,7 @@ def test_multivector():
tbl.create_index(metric="cosine")
# query with single vector
query = np.random.random(256)
query = np.random.random(256).astype(np.float16)
tbl.search(query).to_arrow()
# query with multiple vectors
@@ -51,6 +52,7 @@ async def test_multivector_async():
schema = pa.schema(
[
pa.field("id", pa.int64()),
# float16, float32, and float64 are supported
pa.field("vector", pa.list_(pa.list_(pa.float32(), 256))),
]
)
@@ -72,6 +74,7 @@ async def test_multivector_async():
# query with multiple vectors
query = np.random.random(size=(2, 256))
await tbl.query().nearest_to(query).to_arrow()
# --8<-- [end:async_multivector]
await db.drop_table("my_table")

View File

@@ -69,7 +69,7 @@ async def table_struct_async(tmp_path) -> AsyncTable:
@pytest.fixture
def multivec_table() -> lancedb.table.Table:
def multivec_table(vector_value_type=pa.float32()) -> lancedb.table.Table:
db = lancedb.connect("memory://")
# Generate 256 rows of data
num_rows = 256
@@ -85,7 +85,7 @@ def multivec_table() -> lancedb.table.Table:
df = pa.table(
{
"vector": pa.array(
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
vector_data, type=pa.list_(pa.list_(vector_value_type, list_size=2))
),
"id": pa.array(id_data),
"float_field": pa.array(float_field_data),
@@ -95,7 +95,7 @@ def multivec_table() -> lancedb.table.Table:
@pytest_asyncio.fixture
async def multivec_table_async(tmp_path) -> AsyncTable:
async def multivec_table_async(vector_value_type=pa.float32()) -> AsyncTable:
conn = await lancedb.connect_async(
"memory://", read_consistency_interval=timedelta(seconds=0)
)
@@ -113,7 +113,7 @@ async def multivec_table_async(tmp_path) -> AsyncTable:
df = pa.table(
{
"vector": pa.array(
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
vector_data, type=pa.list_(pa.list_(vector_value_type, list_size=2))
),
"id": pa.array(id_data),
"float_field": pa.array(float_field_data),
@@ -231,6 +231,9 @@ async def test_distance_range_async(table_async: AsyncTable):
assert res["_distance"].to_pylist() == [min_dist, max_dist]
@pytest.mark.parametrize(
"multivec_table", [pa.float16(), pa.float32(), pa.float64()], indirect=True
)
def test_multivector(multivec_table: lancedb.table.Table):
# create index on multivector
multivec_table.create_index(
@@ -261,6 +264,9 @@ def test_multivector(multivec_table: lancedb.table.Table):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"multivec_table_async", [pa.float16(), pa.float32(), pa.float64()], indirect=True
)
async def test_multivector_async(multivec_table_async: AsyncTable):
# create index on multivector
await multivec_table_async.create_index(