fix: support list of numpy f16 floats as query vector (#1931)

User reported on Discord, when using
`table.vector_search([np.float16(1.0), np.float16(2.0), ...])`, it
yields `TypeError: 'numpy.float16' object is not iterable`
This commit is contained in:
Lei Xu
2024-12-10 16:17:28 -08:00
committed by GitHub
parent 3324e7d525
commit 347515aa51
2 changed files with 20 additions and 13 deletions

View File

@@ -3,6 +3,7 @@
import unittest.mock as mock
from datetime import timedelta
from pathlib import Path
import lancedb
from lancedb.index import IvfPq
@@ -384,3 +385,19 @@ async def test_query_to_list_async(table_async: AsyncTable):
assert len(list) == 2
assert list[0]["vector"] == [1, 2]
assert list[1]["vector"] == [3, 4]
@pytest.mark.asyncio
async def test_query_with_f16(tmp_path: Path):
db = await lancedb.connect_async(tmp_path)
f16_arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
df = pa.table(
{
"vector": pa.FixedSizeListArray.from_arrays(f16_arr, 2),
"id": pa.array([1, 2]),
}
)
tbl = await db.create_table("test", df)
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
assert len(results) == 2