From 2fdcb307eb03094f0f00c784def3918e1145eeb1 Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Sat, 15 Jul 2023 03:47:42 +0800 Subject: [PATCH] [python] Fix a few minor bugs (#304) --- python/lancedb/db.py | 21 +++++++++------------ python/lancedb/pydantic.py | 8 +++++--- python/lancedb/query.py | 1 + python/tests/test_db.py | 4 ++++ python/tests/test_query.py | 1 + 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 3ae286bb..721091a4 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -356,18 +356,15 @@ class LanceDBConnection(DBConnection): if mode.lower() not in ["create", "overwrite"]: raise ValueError("mode must be either 'create' or 'overwrite'") - if data is not None: - tbl = LanceTable.create( - self, - name, - data, - schema, - mode=mode, - on_bad_vectors=on_bad_vectors, - fill_value=fill_value, - ) - else: - tbl = LanceTable.open(self, name) + tbl = LanceTable.create( + self, + name, + data, + schema, + mode=mode, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) return tbl def open_table(self, name: str) -> LanceTable: diff --git a/python/lancedb/pydantic.py b/python/lancedb/pydantic.py index 64bc02cc..2584d075 100644 --- a/python/lancedb/pydantic.py +++ b/python/lancedb/pydantic.py @@ -18,7 +18,7 @@ from __future__ import annotations import inspect import sys import types -from abc import ABC, abstractstaticmethod +from abc import ABC, abstractmethod from typing import Any, List, Type, Union, _GenericAlias import pyarrow as pa @@ -27,11 +27,13 @@ from pydantic_core import CoreSchema, core_schema class FixedSizeListMixin(ABC): - @abstractstaticmethod + @staticmethod + @abstractmethod def dim() -> int: raise NotImplementedError - @abstractstaticmethod + @staticmethod + @abstractmethod def value_arrow_type() -> pa.DataType: raise NotImplementedError diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 9262143b..a96f6682 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -226,6 +226,7 @@ class LanceQueryBuilder: columns=self._columns, nprobes=self._nprobes, refine_factor=self._refine_factor, + vector_column=self._vector_column, ) return self._table._execute_query(query) diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 7e9a96f7..a3cb5ba5 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import pytest import lancedb @@ -131,6 +132,9 @@ def test_empty_or_nonexistent_table(tmp_path): with pytest.raises(Exception): db.open_table("does_not_exist") + schema = pa.schema([pa.field("a", pa.int32())]) + db.create_table("test", schema=schema) + def test_replace_index(tmp_path): db = lancedb.connect(uri=tmp_path) diff --git a/python/tests/test_query.py b/python/tests/test_query.py index f1f15b26..8e4678e1 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -119,6 +119,7 @@ def test_query_builder_with_different_vector_column(): columns=["b"], nprobes=20, refine_factor=None, + vector_column="foo_vector", ) )