docs: claim LanceDB supports float16/float32/float64 for multivector (#2040)

This commit is contained in:
BubbleCal
2025-01-21 07:04:15 +08:00
committed by GitHub
parent f059372137
commit 214d0debf5
4 changed files with 16 additions and 6 deletions

View File

@@ -69,7 +69,7 @@ async def table_struct_async(tmp_path) -> AsyncTable:
@pytest.fixture
def multivec_table() -> lancedb.table.Table:
def multivec_table(vector_value_type=pa.float32()) -> lancedb.table.Table:
db = lancedb.connect("memory://")
# Generate 256 rows of data
num_rows = 256
@@ -85,7 +85,7 @@ def multivec_table() -> lancedb.table.Table:
df = pa.table(
{
"vector": pa.array(
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
vector_data, type=pa.list_(pa.list_(vector_value_type, list_size=2))
),
"id": pa.array(id_data),
"float_field": pa.array(float_field_data),
@@ -95,7 +95,7 @@ def multivec_table() -> lancedb.table.Table:
@pytest_asyncio.fixture
async def multivec_table_async(tmp_path) -> AsyncTable:
async def multivec_table_async(vector_value_type=pa.float32()) -> AsyncTable:
conn = await lancedb.connect_async(
"memory://", read_consistency_interval=timedelta(seconds=0)
)
@@ -113,7 +113,7 @@ async def multivec_table_async(tmp_path) -> AsyncTable:
df = pa.table(
{
"vector": pa.array(
vector_data, type=pa.list_(pa.list_(pa.float32(), list_size=2))
vector_data, type=pa.list_(pa.list_(vector_value_type, list_size=2))
),
"id": pa.array(id_data),
"float_field": pa.array(float_field_data),
@@ -231,6 +231,9 @@ async def test_distance_range_async(table_async: AsyncTable):
assert res["_distance"].to_pylist() == [min_dist, max_dist]
@pytest.mark.parametrize(
"multivec_table", [pa.float16(), pa.float32(), pa.float64()], indirect=True
)
def test_multivector(multivec_table: lancedb.table.Table):
# create index on multivector
multivec_table.create_index(
@@ -261,6 +264,9 @@ def test_multivector(multivec_table: lancedb.table.Table):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"multivec_table_async", [pa.float16(), pa.float32(), pa.float64()], indirect=True
)
async def test_multivector_async(multivec_table_async: AsyncTable):
# create index on multivector
await multivec_table_async.create_index(