mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 11:52:57 +00:00
fix: support list of numpy f16 floats as query vector (#1931)
User reported on Discord, when using `table.vector_search([np.float16(1.0), np.float16(2.0), ...])`, it yields `TypeError: 'numpy.float16' object is not iterable`
This commit is contained in:
@@ -1,15 +1,5 @@
|
|||||||
# Copyright 2023 LanceDB Developers
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
#
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -1644,7 +1634,7 @@ class AsyncQuery(AsyncQueryBase):
|
|||||||
if (
|
if (
|
||||||
isinstance(query_vector, list)
|
isinstance(query_vector, list)
|
||||||
and len(query_vector) > 0
|
and len(query_vector) > 0
|
||||||
and not isinstance(query_vector[0], (float, int))
|
and isinstance(query_vector[0], (list, np.ndarray, pa.Array))
|
||||||
):
|
):
|
||||||
# multiple have been passed
|
# multiple have been passed
|
||||||
query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]
|
query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import unittest.mock as mock
|
import unittest.mock as mock
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.index import IvfPq
|
from lancedb.index import IvfPq
|
||||||
@@ -384,3 +385,19 @@ async def test_query_to_list_async(table_async: AsyncTable):
|
|||||||
assert len(list) == 2
|
assert len(list) == 2
|
||||||
assert list[0]["vector"] == [1, 2]
|
assert list[0]["vector"] == [1, 2]
|
||||||
assert list[1]["vector"] == [3, 4]
|
assert list[1]["vector"] == [3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_with_f16(tmp_path: Path):
|
||||||
|
db = await lancedb.connect_async(tmp_path)
|
||||||
|
f16_arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
|
||||||
|
|
||||||
|
df = pa.table(
|
||||||
|
{
|
||||||
|
"vector": pa.FixedSizeListArray.from_arrays(f16_arr, 2),
|
||||||
|
"id": pa.array([1, 2]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tbl = await db.create_table("test", df)
|
||||||
|
results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas()
|
||||||
|
assert len(results) == 2
|
||||||
|
|||||||
Reference in New Issue
Block a user