fix: support list of numpy f16 floats as query vector (#1931)

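A user reported on Discord that calling
`table.vector_search([np.float16(1.0), np.float16(2.0), ...])`
raises `TypeError: 'numpy.float16' object is not iterable`.

The multi-vector check previously treated any list whose first element was
not a Python `float` or `int` as a list of vectors; it now requires the first
element to be a `list`, `np.ndarray`, or `pa.Array`.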
This commit is contained in:
Lei Xu
2024-12-10 16:17:28 -08:00
committed by GitHub
parent 3324e7d525
commit 347515aa51
2 changed files with 20 additions and 13 deletions

View File

@@ -1,15 +1,5 @@
# Copyright 2023 LanceDB Developers # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright The LanceDB Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations from __future__ import annotations
@@ -1644,7 +1634,7 @@ class AsyncQuery(AsyncQueryBase):
if ( if (
isinstance(query_vector, list) isinstance(query_vector, list)
and len(query_vector) > 0 and len(query_vector) > 0
and not isinstance(query_vector[0], (float, int)) and isinstance(query_vector[0], (list, np.ndarray, pa.Array))
): ):
# multiple have been passed # multiple have been passed
query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector] query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]

View File

@@ -3,6 +3,7 @@
import unittest.mock as mock import unittest.mock as mock
from datetime import timedelta from datetime import timedelta
from pathlib import Path
import lancedb import lancedb
from lancedb.index import IvfPq from lancedb.index import IvfPq
@@ -384,3 +385,19 @@ async def test_query_to_list_async(table_async: AsyncTable):
assert len(list) == 2 assert len(list) == 2
assert list[0]["vector"] == [1, 2] assert list[0]["vector"] == [1, 2]
assert list[1]["vector"] == [3, 4] assert list[1]["vector"] == [3, 4]
@pytest.mark.asyncio
async def test_query_with_f16(tmp_path: Path):
    """Vector search accepts a plain list of ``np.float16`` scalars as the query."""
    connection = await lancedb.connect_async(tmp_path)

    # Four float16 values grouped into fixed-size lists of 2 -> two row vectors.
    half_values = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
    vector_column = pa.FixedSizeListArray.from_arrays(half_values, 2)
    id_column = pa.array([1, 2])
    data = pa.table({"vector": vector_column, "id": id_column})

    table = await connection.create_table("test", data)

    # The query is a Python list of numpy scalars, not an ndarray.
    query = [np.float16(1), np.float16(2)]
    results = await table.vector_search(query).to_pandas()
    assert len(results) == 2