mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
fix: support list of numpy f16 floats as query vector (#1931)
User reported on Discord, when using `table.vector_search([np.float16(1.0), np.float16(2.0), ...])`, it yields `TypeError: 'numpy.float16' object is not iterable`
This commit is contained in:
@@ -1,15 +1,5 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -1644,7 +1634,7 @@ class AsyncQuery(AsyncQueryBase):
|
||||
if (
|
||||
isinstance(query_vector, list)
|
||||
and len(query_vector) > 0
|
||||
and not isinstance(query_vector[0], (float, int))
|
||||
and isinstance(query_vector[0], (list, np.ndarray, pa.Array))
|
||||
):
|
||||
# multiple have been passed
|
||||
query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import lancedb
|
||||
from lancedb.index import IvfPq
|
||||
@@ -384,3 +385,19 @@ async def test_query_to_list_async(table_async: AsyncTable):
|
||||
assert len(list) == 2
|
||||
assert list[0]["vector"] == [1, 2]
|
||||
assert list[1]["vector"] == [3, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_f16(tmp_path: Path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
f16_arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
||||
|
||||
df = pa.table(
|
||||
{
|
||||
"vector": pa.FixedSizeListArray.from_arrays(f16_arr, 2),
|
||||
"id": pa.array([1, 2]),
|
||||
}
|
||||
)
|
||||
tbl = await db.create_table("test", df)
|
||||
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
|
||||
assert len(results) == 2
|
||||
|
||||
Reference in New Issue
Block a user