diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 379d4a4a..2012b765 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -1,15 +1,5 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors from __future__ import annotations @@ -1644,7 +1634,7 @@ class AsyncQuery(AsyncQueryBase): if ( isinstance(query_vector, list) and len(query_vector) > 0 - and not isinstance(query_vector[0], (float, int)) + and isinstance(query_vector[0], (list, np.ndarray, pa.Array)) ): # multiple have been passed query_vectors = [AsyncQuery._query_vec_to_array(v) for v in query_vector] diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index e19ea798..d1f4bf3e 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -3,6 +3,7 @@ import unittest.mock as mock from datetime import timedelta +from pathlib import Path import lancedb from lancedb.index import IvfPq @@ -384,3 +385,19 @@ async def test_query_to_list_async(table_async: AsyncTable): assert len(list) == 2 assert list[0]["vector"] == [1, 2] assert list[1]["vector"] == [3, 4] + + +@pytest.mark.asyncio +async def test_query_with_f16(tmp_path: Path): + db = await lancedb.connect_async(tmp_path) + f16_arr = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) + + df = pa.table( + { + "vector": pa.FixedSizeListArray.from_arrays(f16_arr, 2), + "id": pa.array([1, 2]), + } + ) + tbl = await db.create_table("test", df) + results = await tbl.vector_search([np.float16(1), np.float16(2)]).to_pandas() + assert len(results) == 2