feat(python): check vector query is not None (#1847)

Fix the type hints of `nearest_to` method, and raise `ValueError` when
the input is None
This commit is contained in:
Lei Xu
2024-11-18 14:15:22 -08:00
committed by GitHub
parent cc72050206
commit 267aa83bf8
3 changed files with 15 additions and 15 deletions

View File

@@ -1495,7 +1495,8 @@ class AsyncQuery(AsyncQueryBase):
return pa.array(vec)
def nearest_to(
self, query_vector: Optional[Union[VEC, Tuple, List[VEC]]] = None
self,
query_vector: Union[VEC, Tuple, List[VEC]],
) -> AsyncVectorQuery:
"""
Find the nearest vectors to the given query vector.
@@ -1542,6 +1543,9 @@ class AsyncQuery(AsyncQueryBase):
will be added to the results. This column will contain the index of the
query vector that the result is nearest to.
"""
if query_vector is None:
raise ValueError("query_vector can not be None")
if (
isinstance(query_vector, list)
and len(query_vector) > 0
@@ -1618,7 +1622,7 @@ class AsyncVectorQuery(AsyncQueryBase):
"""
Set the number of partitions to search (probe)
This argument is only used when the vector column has an IVF PQ index.
This argument is only used when the vector column has an IVF-based index.
If there is no index then this value is ignored.
The IVF stage of IVF PQ divides the input into partitions (clusters) of

View File

@@ -2697,7 +2697,7 @@ class AsyncTable:
def vector_search(
self,
query_vector: Optional[Union[VEC, Tuple]] = None,
query_vector: Union[VEC, Tuple],
) -> AsyncVectorQuery:
"""
Search the table with a given query vector.

View File

@@ -1,15 +1,5 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import unittest.mock as mock
from datetime import timedelta
@@ -342,6 +332,12 @@ async def test_query_to_pandas_async(table_async: AsyncTable):
assert df.shape == (0, 4)
@pytest.mark.asyncio
async def test_none_query(table_async: AsyncTable):
with pytest.raises(ValueError):
await table_async.query().nearest_to(None).to_arrow()
@pytest.mark.asyncio
async def test_fast_search_async(tmp_path):
db = await lancedb.connect_async(tmp_path)