python: Bug fixes / tests (#1210)

Closes #1194, closes #1172, closes #1124, closes #1208


@wjones127: `if query_type != "fts":` is needed because both FTS and
vector search create a `LanceQueryBuilder`, which has `vector_column_name`
as a required attribute.
This commit is contained in:
Raghav Dixit
2024-04-10 13:17:14 -04:00
committed by GitHub
parent 1d23af213b
commit a6aa67baed
7 changed files with 52 additions and 5 deletions

View File

@@ -57,6 +57,7 @@ tests = [
"duckdb",
"pytz",
"polars>=0.19",
"tantivy"
]
dev = ["ruff", "pre-commit"]
docs = [

View File

@@ -78,6 +78,9 @@ class BedRockText(TextEmbeddingFunction):
class Config:
keep_untouched = (cached_property,)
else:
model_config = dict()
model_config["ignored_types"] = (cached_property,)
def ndims(self):
# return len(self._generate_embedding("test"))

View File

@@ -94,6 +94,9 @@ class GeminiText(TextEmbeddingFunction):
class Config:
keep_untouched = (cached_property,)
else:
model_config = dict()
model_config["ignored_types"] = (cached_property,)
def ndims(self):
# TODO: fix hardcoding

View File

@@ -22,6 +22,8 @@ from .base import EmbeddingFunction
from .registry import register
from .utils import AUDIO, IMAGES, TEXT
from lancedb.pydantic import PYDANTIC_VERSION
@register("imagebind")
class ImageBindEmbeddings(EmbeddingFunction):
@@ -38,8 +40,13 @@ class ImageBindEmbeddings(EmbeddingFunction):
device: str = "cpu"
normalize: bool = False
class Config:
keep_untouched = (cached_property,)
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
class Config:
keep_untouched = (cached_property,)
else:
model_config = dict()
model_config["ignored_types"] = (cached_property,)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -17,6 +17,7 @@ from typing import List, Any
import numpy as np
from pydantic import PrivateAttr
from lancedb.pydantic import PYDANTIC_VERSION
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
@@ -53,8 +54,13 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
self._model = transformers.AutoModel.from_pretrained(self.name)
class Config:
keep_untouched = (cached_property,)
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
class Config:
keep_untouched = (cached_property,)
else:
model_config = dict()
model_config["ignored_types"] = (cached_property,)
def ndims(self):
self._ndims = self._model.config.hidden_size

View File

@@ -95,6 +95,9 @@ def _sanitize_data(
data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
)
if isinstance(data, LanceModel):
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
@@ -1403,7 +1406,14 @@ class LanceTable(Table):
vector and the returned vector.
"""
if vector_column_name is None and query is not None:
vector_column_name = inf_vector_column_query(self.schema)
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:
if query_type == "fts":
vector_column_name = ""
else:
raise e
return LanceQueryBuilder.create(
self,
query,

View File

@@ -28,13 +28,25 @@ def test_basic(tmp_path):
assert db.uri == str(tmp_path)
assert db.table_names() == []
class SimpleModel(LanceModel):
item: str
price: float
vector: Vector(2)
table = db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
schema=SimpleModel,
)
with pytest.raises(
ValueError, match="Cannot add a single LanceModel to a table. Use a list."
):
table.add(SimpleModel(item="baz", price=30.0, vector=[1.0, 2.0]))
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
@@ -43,6 +55,11 @@ def test_basic(tmp_path):
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
table.create_fts_index(["item"])
rs = table.search("bar", query_type="fts").to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
assert db.table_names() == ["test"]
assert "test" in db
assert len(db) == 1