mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
python: Bug fixes / tests (#1210)
closes #1194 #1172 #1124 #1208 @wjones127 : `if query_type != "fts":` is needed because both fts and vector search create `LanceQueryBuilder` which has `vector_column_name` as a required attribute.
This commit is contained in:
@@ -57,6 +57,7 @@ tests = [
|
||||
"duckdb",
|
||||
"pytz",
|
||||
"polars>=0.19",
|
||||
"tantivy"
|
||||
]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = [
|
||||
|
||||
@@ -78,6 +78,9 @@ class BedRockText(TextEmbeddingFunction):
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
else:
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
def ndims(self):
|
||||
# return len(self._generate_embedding("test"))
|
||||
|
||||
@@ -94,6 +94,9 @@ class GeminiText(TextEmbeddingFunction):
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
else:
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
def ndims(self):
|
||||
# TODO: fix hardcoding
|
||||
|
||||
@@ -22,6 +22,8 @@ from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import AUDIO, IMAGES, TEXT
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
|
||||
@register("imagebind")
|
||||
class ImageBindEmbeddings(EmbeddingFunction):
|
||||
@@ -38,8 +40,13 @@ class ImageBindEmbeddings(EmbeddingFunction):
|
||||
device: str = "cpu"
|
||||
normalize: bool = False
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
else:
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -17,6 +17,7 @@ from typing import List, Any
|
||||
import numpy as np
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
@@ -53,8 +54,13 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
|
||||
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
|
||||
self._model = transformers.AutoModel.from_pretrained(self.name)
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
else:
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
def ndims(self):
|
||||
self._ndims = self._model.config.hidden_size
|
||||
|
||||
@@ -95,6 +95,9 @@ def _sanitize_data(
|
||||
data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
|
||||
)
|
||||
|
||||
if isinstance(data, LanceModel):
|
||||
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
|
||||
|
||||
if isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
@@ -1403,7 +1406,14 @@ class LanceTable(Table):
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None and query is not None:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
except Exception as e:
|
||||
if query_type == "fts":
|
||||
vector_column_name = ""
|
||||
else:
|
||||
raise e
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
query,
|
||||
|
||||
@@ -28,13 +28,25 @@ def test_basic(tmp_path):
|
||||
assert db.uri == str(tmp_path)
|
||||
assert db.table_names() == []
|
||||
|
||||
class SimpleModel(LanceModel):
|
||||
item: str
|
||||
price: float
|
||||
vector: Vector(2)
|
||||
|
||||
table = db.create_table(
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
schema=SimpleModel,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot add a single LanceModel to a table. Use a list."
|
||||
):
|
||||
table.add(SimpleModel(item="baz", price=30.0, vector=[1.0, 2.0]))
|
||||
|
||||
rs = table.search([100, 100]).limit(1).to_pandas()
|
||||
assert len(rs) == 1
|
||||
assert rs["item"].iloc[0] == "bar"
|
||||
@@ -43,6 +55,11 @@ def test_basic(tmp_path):
|
||||
assert len(rs) == 1
|
||||
assert rs["item"].iloc[0] == "foo"
|
||||
|
||||
table.create_fts_index(["item"])
|
||||
rs = table.search("bar", query_type="fts").to_pandas()
|
||||
assert len(rs) == 1
|
||||
assert rs["item"].iloc[0] == "bar"
|
||||
|
||||
assert db.table_names() == ["test"]
|
||||
assert "test" in db
|
||||
assert len(db) == 1
|
||||
|
||||
Reference in New Issue
Block a user