mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
fix: support list of numpy f16 floats as query vector (#1931)
User reported on Discord, when using `table.vector_search([np.float16(1.0), np.float16(2.0), ...])`, it yields `TypeError: 'numpy.float16' object is not iterable`
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq
|
||||
@@ -384,3 +385,19 @@ async def test_query_to_list_async(table_async: AsyncTable):
|
||||
assert len(list) == 2
|
||||
assert list[0]["vector"] == [1, 2]
|
||||
assert list[1]["vector"] == [3, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_f16(tmp_path: Path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
f16_arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
||||
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.FixedSizeListArray.from_arrays(f16_arr, 2),
|
||||
"id": pa.array([1, 2]),
|
||||
}
|
||||
)
|
||||
tbl = await db.create_table("test", df)
|
||||
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
|
||||
assert len(results) == 2
|
||||
|
||||
Reference in New Issue
Block a user