feat: add a basic async python client starting point (#1014)

This changes `lancedb` from a "pure python" setuptools project to a
maturin project and adds a rust lancedb dependency.

The async python client is extremely minimal (only `connect` and
`Connection.table_names` are supported). The purpose of this PR is to
get the infrastructure in place for building out the rest of the async
client.

Although this is not technically a breaking change (no APIs are
changing) it is still a considerable change in the way the wheels are
built because they now include the native shared library.
This commit is contained in:
Weston Pace
2024-02-27 04:52:02 -08:00
committed by GitHub
parent 5af74b5aca
commit a6bcbd007b
82 changed files with 1029 additions and 153 deletions

View File

@@ -1,32 +0,0 @@
from click.testing import CliRunner
from lancedb.cli.cli import cli
from lancedb.utils import CONFIG
def test_entry():
runner = CliRunner()
result = runner.invoke(cli)
assert result.exit_code == 0 # Main check
assert "lancedb" in result.output.lower() # lazy check
def test_diagnostics():
runner = CliRunner()
result = runner.invoke(cli, ["diagnostics", "--disabled"])
assert result.exit_code == 0 # Main check
assert not CONFIG["diagnostics"]
result = runner.invoke(cli, ["diagnostics", "--enabled"])
assert result.exit_code == 0 # Main check
assert CONFIG["diagnostics"]
def test_config():
runner = CliRunner()
result = runner.invoke(cli, ["config"])
assert result.exit_code == 0 # Main check
cfg = CONFIG.copy()
cfg.pop("uuid")
for item in cfg: # check for keys only as formatting is subject to change
assert item in result.output

View File

@@ -1,77 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pandas as pd
import pytest
from lancedb.context import contextualize
@pytest.fixture
def raw_df() -> pd.DataFrame:
return pd.DataFrame(
{
"token": [
"The",
"quick",
"brown",
"fox",
"jumped",
"over",
"the",
"lazy",
"dog",
"I",
"love",
"sandwiches",
],
"document_id": [1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2],
}
)
def test_contextualizer(raw_df: pd.DataFrame):
result = (
contextualize(raw_df)
.window(6)
.stride(3)
.text_col("token")
.groupby("document_id")
.to_pandas()["token"]
.to_list()
)
assert result == [
"The quick brown fox jumped over",
"fox jumped over the lazy dog",
"the lazy dog",
"I love sandwiches",
]
def test_contextualizer_with_threshold(raw_df: pd.DataFrame):
result = (
contextualize(raw_df)
.window(6)
.stride(3)
.text_col("token")
.groupby("document_id")
.min_window_size(4)
.to_pandas()["token"]
.to_list()
)
assert result == [
"The quick brown fox jumped over",
"fox jumped over the lazy dog",
]

View File

@@ -1,372 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
import lancedb
from lancedb.pydantic import LanceModel, Vector
def test_basic(tmp_path):
db = lancedb.connect(tmp_path)
assert db.uri == str(tmp_path)
assert db.table_names() == []
table = db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
assert db.table_names() == ["test"]
assert "test" in db
assert len(db) == 1
assert db.open_table("test").name == db["test"].name
def test_ingest_pd(tmp_path):
db = lancedb.connect(tmp_path)
assert db.uri == str(tmp_path)
assert db.table_names() == []
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
table = db.create_table("test", data=data)
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
assert db.table_names() == ["test"]
assert "test" in db
assert len(db) == 1
assert db.open_table("test").name == db["test"].name
def test_ingest_iterator(tmp_path):
class PydanticSchema(LanceModel):
vector: Vector(2)
item: str
price: float
arrow_schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float32()),
]
)
def make_batches():
for _ in range(5):
yield from [
# pandas
pd.DataFrame(
{
"vector": [[3.1, 4.1], [1, 1]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
),
# pylist
[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
# recordbatch
pa.RecordBatch.from_arrays(
[
pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]),
],
["vector", "item", "price"],
),
# pa Table
pa.Table.from_arrays(
[
pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]),
],
["vector", "item", "price"],
),
# pydantic list
[
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
],
# TODO: test pydict separately. it is unique column number and
# name constraints
]
def run_tests(schema):
db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
tbl_len = len(tbl)
tbl.add(make_batches())
assert tbl_len == 50
assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 3
db.drop_database()
run_tests(arrow_schema)
run_tests(PydanticSchema)
def test_table_names(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
db.create_table("test2", data=data)
db.create_table("test1", data=data)
db.create_table("test3", data=data)
assert db.table_names() == ["test1", "test2", "test3"]
def test_create_mode(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
new_data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["fizz", "buzz"],
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=new_data, mode="overwrite")
assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
def test_create_exist_ok(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=data)
with pytest.raises(OSError):
db.create_table("test", data=data)
# open the table but don't add more rows
tbl2 = db.create_table("test", data=data, exist_ok=True)
assert tbl.name == tbl2.name
assert tbl.schema == tbl2.schema
assert len(tbl) == len(tbl2)
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float64()),
]
)
tbl3 = db.create_table("test", schema=schema, exist_ok=True)
assert tbl3.schema == schema
bad_schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float64()),
pa.field("extra", pa.float32()),
]
)
with pytest.raises(ValueError):
db.create_table("test", schema=bad_schema, exist_ok=True)
def test_delete_table(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
assert db.table_names() == ["test"]
db.drop_table("test")
assert db.table_names() == []
db.create_table("test", data=data)
assert db.table_names() == ["test"]
# dropping a table that does not exist should pass
# if ignore_missing=True
db.drop_table("does_not_exist", ignore_missing=True)
def test_drop_database(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
new_data = pd.DataFrame(
{
"vector": [[5.1, 4.1], [5.9, 10.5]],
"item": ["kiwi", "avocado"],
"price": [12.0, 17.0],
}
)
db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
assert db.table_names() == ["test"]
db.create_table("new_test", data=new_data)
db.drop_database()
assert db.table_names() == []
# it should pass when no tables are present
db.create_table("test", data=new_data)
db.drop_table("test")
assert db.table_names() == []
db.drop_database()
assert db.table_names() == []
# creating an empty database with schema
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
db.create_table("empty_table", schema=schema)
# dropping a empty database should pass
db.drop_database()
assert db.table_names() == []
def test_empty_or_nonexistent_table(tmp_path):
db = lancedb.connect(tmp_path)
with pytest.raises(Exception):
db.create_table("test_with_no_data")
with pytest.raises(Exception):
db.open_table("does_not_exist")
schema = pa.schema([pa.field("a", pa.int64(), nullable=False)])
test = db.create_table("test", schema=schema)
class TestModel(LanceModel):
a: int
test2 = db.create_table("test2", schema=TestModel)
assert test.schema == test2.schema
def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
table = db.create_table(
"test",
[
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
],
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
)
with pytest.raises(Exception):
table.create_index(
num_partitions=2,
num_sub_vectors=4,
replace=False,
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
replace=True,
index_cache_size=10,
)
def test_prefilter_with_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
data = [
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
]
sample_key = data[100]["vector"]
table = db.create_table(
"test",
data,
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
)
table = (
table.search(sample_key)
.where("price == 500", prefilter=True)
.limit(5)
.to_arrow()
)
assert table.num_rows == 1

View File

@@ -1,27 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from lancedb import LanceDBConnection
# TODO: setup integ test mark and script
@pytest.mark.skip(reason="Need to set up a local server")
def test_against_local_server():
conn = LanceDBConnection("lancedb+http://localhost:10024")
table = conn.open_table("sift1m_ivf1024_pq16")
df = table.search(np.random.rand(128)).to_pandas()
assert len(df) == 10

View File

@@ -1,115 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import lance
import numpy as np
import pyarrow as pa
import pytest
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
EmbeddingFunctionRegistry,
with_embeddings,
)
from lancedb.pydantic import LanceModel, Vector
def mock_embed_func(input_data):
return [np.random.randn(128).tolist() for _ in range(len(input_data))]
def test_with_embeddings():
for wrap_api in [True, False]:
if wrap_api and sys.version_info.minor >= 11:
# ratelimiter package doesn't work on 3.11
continue
data = pa.Table.from_arrays(
[
pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]),
],
names=["text", "price"],
)
data = with_embeddings(mock_embed_func, data, wrap_api=wrap_api)
assert data.num_columns == 3
assert data.num_rows == 2
assert data.column_names == ["text", "price", "vector"]
assert data.column("text").to_pylist() == ["foo", "bar"]
assert data.column("price").to_pylist() == [10.0, 20.0]
def test_embedding_function(tmp_path):
registry = EmbeddingFunctionRegistry.get_instance()
# let's create a table
table = pa.table(
{
"text": pa.array(["hello world", "goodbye world"]),
"vector": [np.random.randn(10), np.random.randn(10)],
}
)
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
)
metadata = registry.get_table_metadata([conf])
table = table.replace_schema_metadata(metadata)
# Write it to disk
lance.write_dataset(table, tmp_path / "test.lance")
# Load this back
ds = lance.dataset(tmp_path / "test.lance")
# can we get the serialized version back out?
configs = registry.parse_functions(ds.schema.metadata)
conf = configs["vector"]
func = conf.function
actual = func.compute_query_embeddings("hello world")
# And we make sure we can call it
expected = func.compute_query_embeddings("hello world")
assert np.allclose(actual, expected)
@pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model):
class Schema(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
return Schema
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
model = registry.get("test-rate-limited").create(max_retries=0)
schema = _get_schema_from_model(model)
table = db.create_table("test", schema=schema, mode="overwrite")
table.add([{"text": "hello world"}])
with pytest.raises(Exception):
table.add([{"text": "hello world"}])
assert len(table) == 1
model = registry.get("test-rate-limited").create()
schema = _get_schema_from_model(model)
table = db.create_table("test", schema=schema, mode="overwrite")
table.add([{"text": "hello world"}])
table.add([{"text": "hello world"}])
assert len(table) == 2

View File

@@ -1,421 +0,0 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import io
import os
import numpy as np
import pandas as pd
import pytest
import requests
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or connection to external api
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
else:
_mlx = None
except Exception:
_mlx = None
try:
if importlib.util.find_spec("imagebind") is not None:
_imagebind = True
else:
_imagebind = None
except Exception:
_imagebind = None
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_basic_text_embeddings(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get(alias).create(max_retries=0)
func2 = registry.get(alias).create(max_retries=0)
class Words(LanceModel):
text: str = func.SourceField()
text2: str = func2.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
vector2: Vector(func2.ndims()) = func2.VectorField()
table = db.create_table("words", schema=Words)
table.add(
pd.DataFrame(
{
"text": [
"hello world",
"goodbye world",
"fizz",
"buzz",
"foo",
"bar",
"baz",
],
"text2": [
"to be or not to be",
"that is the question",
"for whether tis nobler",
"in the mind to suffer",
"the slings and arrows",
"of outrageous fortune",
"or to take arms",
],
}
)
)
query = "greetings"
actual = (
table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
)
vec = func.compute_query_embeddings(query)[0]
expected = (
table.search(vec, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
)
assert actual.text == expected.text
assert actual.text == "hello world"
assert not np.allclose(actual.vector, actual.vector2)
actual = (
table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
)
assert actual.text != "hello world"
assert not np.allclose(actual.vector, actual.vector2)
@pytest.mark.slow
def test_openclip(tmp_path):
from PIL import Image
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("open-clip").create(max_retries=0)
class Images(LanceModel):
label: str
image_uri: str = func.SourceField()
image_bytes: bytes = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
vec_from_bytes: Vector(func.ndims()) = func.VectorField()
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
)
# text search
actual = (
table.search("man's best friend", vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
frombytes = (
table.search("man's best friend", vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == frombytes.label
assert np.allclose(actual.vector, frombytes.vector)
# image search
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content
query_image = Image.open(io.BytesIO(image_bytes))
actual = (
table.search(query_image, vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
other = (
table.search(query_image, vector_column_name="vec_from_bytes")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == other.label
arrow_table = table.search().select(["vector", "vec_from_bytes"]).to_arrow()
assert np.allclose(
arrow_table["vector"].combine_chunks().values.to_numpy(),
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
)
@pytest.mark.skipif(
_imagebind is None,
reason="skip if imagebind not installed.",
)
@pytest.mark.slow
def test_imagebind(tmp_path):
import os
import shutil
import tempfile
import pandas as pd
import requests
import lancedb.embeddings.imagebind
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
with tempfile.TemporaryDirectory() as temp_dir:
print(f"Created temporary directory {temp_dir}")
def download_images(image_uris):
downloaded_image_paths = []
for uri in image_uris:
try:
response = requests.get(uri, stream=True)
if response.status_code == 200:
# Extract image name from URI
image_name = os.path.basename(uri)
image_path = os.path.join(temp_dir, image_name)
with open(image_path, "wb") as out_file:
shutil.copyfileobj(response.raw, out_file)
downloaded_image_paths.append(image_path)
except Exception as e: # noqa: PERF203
print(f"Failed to download {uri}. Error: {e}")
return temp_dir, downloaded_image_paths
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("imagebind").create(max_retries=0)
class Images(LanceModel):
label: str
image_uri: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
temp_dir, downloaded_images = download_images(uris)
table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))
# text search
actual = (
table.search("man's best friend", vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
# image search
query_image_uri = [
"https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
]
temp_dir, downloaded_images = download_images(query_image_uri)
query_image_uri = downloaded_images[0]
actual = (
table.search(query_image_uri, vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)
print(f"Deleted temporary directory {temp_dir}")
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
) # also skip if cohere not installed
def test_cohere_embedding_function():
cohere = (
get_registry()
.get("cohere")
.create(name="embed-multilingual-v2.0", max_retries=0)
)
class TextModel(LanceModel):
text: str = cohere.SourceField()
vector: Vector(cohere.ndims()) = cohere.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
@pytest.mark.slow
def test_instructor_embedding(tmp_path):
model = get_registry().get("instructor").create(max_retries=0)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
)
def test_gemini_embedding(tmp_path):
model = get_registry().get("gemini-text").create(max_retries=0)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@pytest.mark.skipif(
_mlx is None,
reason="mlx tests only required for apple users.",
)
@pytest.mark.slow
def test_gte_embedding(tmp_path):
import lancedb.embeddings.gte
model = get_registry().get("gte-text").create()
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
def aws_setup():
try:
import boto3
sts = boto3.client("sts")
sts.get_caller_identity()
return True
except Exception:
return False
@pytest.mark.slow
@pytest.mark.skipif(
not aws_setup(), reason="AWS credentials not set or libraries not installed"
)
def test_bedrock_embedding(tmp_path):
for name in [
"amazon.titan-embed-text-v1",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]:
model = get_registry().get("bedrock-text").create(max_retries=0, name=name)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_embedding(tmp_path):
def _get_table(model):
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
return tbl
model = get_registry().get("openai").create(max_retries=0)
tbl = _get_table(model)
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large")
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large", dim=1024)
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"

View File

@@ -1,184 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from unittest import mock
import numpy as np
import pandas as pd
import pytest
import tantivy
import lancedb as ldb
import lancedb.fts
@pytest.fixture
def table(tmp_path) -> ldb.table.LanceTable:
db = ldb.connect(tmp_path)
vectors = [np.random.randn(128) for _ in range(100)]
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
verbs = ("runs", "hits", "jumps", "drives", "barfs")
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
text = [
" ".join(
[
nouns[random.randrange(0, 5)],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
]
)
for _ in range(100)
]
table = db.create_table(
"test",
data=pd.DataFrame(
{
"vector": vectors,
"id": [i % 2 for i in range(100)],
"text": text,
"text2": text,
"nested": [{"text": t} for t in text],
}
),
)
return table
def test_create_index(tmp_path):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
assert isinstance(index, tantivy.Index)
assert os.path.exists(str(tmp_path / "index"))
def test_populate_index(tmp_path, table):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
assert ldb.fts.populate_index(index, table, ["text"]) == len(table)
def test_search_index(tmp_path, table):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
ldb.fts.populate_index(index, table, ["text"])
index.reload()
results = ldb.fts.search_index(index, query="puppy", limit=10)
assert len(results) == 2
assert len(results[0]) == 10 # row_ids
assert len(results[1]) == 10 # _distance
def test_create_index_from_table(tmp_path, table):
table.create_fts_index("text")
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
assert len(df) <= 10
assert "text" in df.columns
# Check whether it can be updated
table.add(
[
{
"vector": np.random.randn(128),
"id": 101,
"text": "gorilla",
"text2": "gorilla",
"nested": {"text": "gorilla"},
}
]
)
with pytest.raises(ValueError, match="already exists"):
table.create_fts_index("text")
table.create_fts_index("text", replace=True)
assert len(table.search("gorilla").limit(1).to_pandas()) == 1
def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"])
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 10
assert "text" in df.columns
assert "text2" in df.columns
def test_empty_rs(tmp_path, table, mocker):
table.create_fts_index(["text", "text2"])
mocker.patch("lancedb.fts.search_index", return_value=([], []))
df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 0
def test_nested_schema(tmp_path, table):
table.create_fts_index("nested.text")
rs = table.search("puppy").limit(10).to_list()
assert len(rs) == 10
def test_search_index_with_filter(table):
table.create_fts_index("text")
orig_import = __import__
def import_mock(name, *args):
if name == "duckdb":
raise ImportError
return orig_import(name, *args)
# no duckdb
with mock.patch("builtins.__import__", side_effect=import_mock):
rs = table.search("puppy").where("id=1").limit(10).to_list()
for r in rs:
assert r["id"] == 1
# yes duckdb
rs2 = table.search("puppy").where("id=1").limit(10).to_list()
for r in rs2:
assert r["id"] == 1
assert rs == rs2
def test_null_input(table):
table.add(
[
{
"vector": np.random.randn(128),
"id": 101,
"text": None,
"text2": None,
"nested": {"text": None},
}
]
)
table.create_fts_index("text")
def test_syntax(table):
# https://github.com/lancedb/lancedb/issues/769
table.create_fts_index("text")
with pytest.raises(ValueError, match="Syntax Error"):
table.search("they could have been dogs OR cats").limit(10).to_list()
table.search("they could have been dogs OR cats").phrase_query().limit(10).to_list()
# this should work
table.search('"they could have been dogs OR cats"').limit(10).to_list()
# this should work too
table.search('''"the cats OR dogs were not really 'pets' at all"''').limit(
10
).to_list()
table.search('the cats OR dogs were not really "pets" at all').phrase_query().limit(
10
).to_list()
table.search('the cats OR dogs were not really "pets" at all').phrase_query().limit(
10
).to_list()

View File

@@ -1,51 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import lancedb
# You need to setup AWS credentials an a base path to run this test. Example
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
@pytest.mark.skipif(
(os.environ.get("TEST_S3_BASE_URL") is None),
reason="please setup s3 base url",
)
def test_s3_io():
db = lancedb.connect(os.environ.get("TEST_S3_BASE_URL"))
assert db.table_names() == []
table = db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
assert db.table_names() == ["test"]
assert "test" in db
assert len(db) == 1
assert db.open_table("test").name == db["test"].name

View File

@@ -1,246 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from datetime import date, datetime
from typing import List, Optional, Tuple
import pyarrow as pa
import pydantic
import pytest
from pydantic import Field
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow():
class StructModel(pydantic.BaseModel):
a: str
b: Optional[float]
class TestModel(pydantic.BaseModel):
id: int
s: str
vec: list[float]
li: list[int]
lili: list[list[float]]
litu: list[tuple[float, float]]
opt: Optional[str] = None
st: StructModel
dt: date
dtt: datetime
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict
# TODO: test we can actually convert the model into data.
# m = TestModel(
# id=1,
# s="hello",
# vec=[1.0, 2.0, 3.0],
# li=[2, 3, 4],
# lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
# litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
# st=StructModel(a="a", b=1.0),
# dt=date.today(),
# dtt=datetime.now(),
# dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
# )
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.int64(), False),
pa.field("s", pa.utf8(), False),
pa.field("vec", pa.list_(pa.float64()), False),
pa.field("li", pa.list_(pa.int64()), False),
pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
pa.field("opt", pa.utf8(), True),
pa.field(
"st",
pa.struct(
[pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
),
False,
),
pa.field("dt", pa.date32(), False),
pa.field("dtt", pa.timestamp("us"), False),
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
]
)
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="using | type syntax requires python3.10 or higher",
)
def test_optional_types_py310():
class TestModel(pydantic.BaseModel):
a: str | None
b: None | str
c: Optional[str]
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("a", pa.utf8(), True),
pa.field("b", pa.utf8(), True),
pa.field("c", pa.utf8(), True),
]
)
assert schema == expect_schema
@pytest.mark.skipif(
sys.version_info > (3, 8),
reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow_py38():
class StructModel(pydantic.BaseModel):
a: str
b: Optional[float]
class TestModel(pydantic.BaseModel):
id: int
s: str
vec: List[float]
li: List[int]
lili: List[List[float]]
litu: List[Tuple[float, float]]
opt: Optional[str] = None
st: StructModel
dt: date
dtt: datetime
dt_with_tz: datetime = Field(json_schema_extra={"tz": "Asia/Shanghai"})
# d: dict
# TODO: test we can actually convert the model to Arrow data.
# m = TestModel(
# id=1,
# s="hello",
# vec=[1.0, 2.0, 3.0],
# li=[2, 3, 4],
# lili=[[2.5, 1.5], [3.5, 4.5], [5.5, 6.5]],
# litu=[(2.5, 1.5), (3.5, 4.5), (5.5, 6.5)],
# st=StructModel(a="a", b=1.0),
# dt=date.today(),
# dtt=datetime.now(),
# dt_with_tz=datetime.now(pytz.timezone("Asia/Shanghai")),
# )
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.int64(), False),
pa.field("s", pa.utf8(), False),
pa.field("vec", pa.list_(pa.float64()), False),
pa.field("li", pa.list_(pa.int64()), False),
pa.field("lili", pa.list_(pa.list_(pa.float64())), False),
pa.field("litu", pa.list_(pa.list_(pa.float64())), False),
pa.field("opt", pa.utf8(), True),
pa.field(
"st",
pa.struct(
[pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
),
False,
),
pa.field("dt", pa.date32(), False),
pa.field("dtt", pa.timestamp("us"), False),
pa.field("dt_with_tz", pa.timestamp("us", tz="Asia/Shanghai"), False),
]
)
assert schema == expect_schema
def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel):
vec: Vector(16)
li: List[int]
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
if PYDANTIC_VERSION >= (2,):
assert json.loads(data.model_dump_json()) == {
"vec": list(range(16)),
"li": [1, 2, 3],
}
else:
assert data.dict() == {
"vec": list(range(16)),
"li": [1, 2, 3],
}
schema = pydantic_to_schema(TestModel)
assert schema == pa.schema(
[
pa.field("vec", pa.list_(pa.float32(), 16), False),
pa.field("li", pa.list_(pa.int64()), False),
]
)
if PYDANTIC_VERSION >= (2,):
json_schema = TestModel.model_json_schema()
else:
json_schema = TestModel.schema()
assert json_schema == {
"properties": {
"vec": {
"items": {"type": "number"},
"maxItems": 16,
"minItems": 16,
"title": "Vec",
"type": "array",
},
"li": {"items": {"type": "integer"}, "title": "Li", "type": "array"},
},
"required": ["vec", "li"],
"title": "TestModel",
"type": "object",
}
def test_fixed_size_list_validation():
class TestModel(pydantic.BaseModel):
vec: Vector(8)
with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(9))
with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(7))
TestModel(vec=range(8))
def test_lance_model():
class TestModel(LanceModel):
vector: Vector(16) = Field(default=[0.0] * 16)
li: List[int] = Field(default=[1, 2, 3])
schema = pydantic_to_schema(TestModel)
assert schema == TestModel.to_arrow_schema()
assert TestModel.field_names() == ["vector", "li"]
t = TestModel()
assert t == TestModel(vec=[0.0] * 16, li=[1, 2, 3])

View File

@@ -1,176 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest.mock as mock
import lance
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pytest
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
class MockTable:
def __init__(self, tmp_path):
self.uri = tmp_path
self._conn = LanceDBConnection(self.uri)
def to_lance(self):
return lance.dataset(self.uri)
def _execute_query(self, query):
ds = self.to_lance()
return ds.to_table(
columns=query.columns,
filter=query.filter,
prefilter=query.prefilter,
nearest={
"column": query.vector_column,
"q": query.vector,
"k": query.k,
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
)
@pytest.fixture
def table(tmp_path) -> MockTable:
df = pa.table(
{
"vector": pa.array(
[[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
),
"id": pa.array([1, 2]),
"str_field": pa.array(["a", "b"]),
"float_field": pa.array([1.0, 2.0]),
}
)
lance.write_dataset(df, tmp_path)
return MockTable(tmp_path)
def test_cast(table):
class TestModel(LanceModel):
vector: Vector(2)
id: int
str_field: str
float_field: float
q = LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1)
results = q.to_pydantic(TestModel)
assert len(results) == 1
r0 = results[0]
assert isinstance(r0, TestModel)
assert r0.id == 1
assert r0.vector == [1, 2]
assert r0.str_field == "a"
assert r0.float_field == 1.0
def test_query_builder(table):
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(1)
.select(["id"])
.to_list()
)
assert rs[0]["id"] == 1
assert all(np.array(rs[0]["vector"]) == [1, 2])
def test_query_builder_with_filter(table):
rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
assert rs[0]["id"] == 2
assert all(np.array(rs[0]["vector"]) == [3, 4])
def test_query_builder_with_prefilter(table):
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2")
.limit(1)
.to_pandas()
)
assert len(df) == 0
df = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.where("id = 2", prefilter=True)
.limit(1)
.to_pandas()
)
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
def test_query_builder_with_metric(table):
query = [4, 8]
vector_column_name = "vector"
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas()
df_l2 = (
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("L2")
.to_pandas()
)
tm.assert_frame_equal(df_default, df_l2)
df_cosine = (
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.limit(1)
.to_pandas()
)
assert df_cosine._distance[0] == pytest.approx(
cosine_distance(query, df_cosine.vector[0]),
abs=1e-6,
)
assert 0 <= df_cosine._distance[0] <= 1
def test_query_builder_with_different_vector_column():
table = mock.MagicMock(spec=LanceTable)
query = [4, 8]
vector_column_name = "foo_vector"
builder = (
LanceVectorQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.where("b < 10")
.select(["b"])
.limit(2)
)
ds = mock.Mock()
table.to_lance.return_value = ds
builder.to_arrow()
table._execute_query.assert_called_once_with(
Query(
vector=query,
filter="b < 10",
k=2,
metric="cosine",
columns=["b"],
nprobes=20,
refine_factor=None,
vector_column="foo_vector",
)
)
def cosine_distance(vec1, vec2):
return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

View File

@@ -1,95 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attrs
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from aiohttp import web
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
@attrs.define
class MockLanceDBServer:
runner: web.AppRunner = attrs.field(init=False)
site: web.TCPSite = attrs.field(init=False)
async def query_handler(self, request: web.Request) -> web.Response:
table_name = request.match_info["table_name"]
assert table_name == "test_table"
await request.json()
# TODO: do some matching
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
ids = pd.Series(range(10), name="id")
df = pd.DataFrame([vecs, ids]).T
batch = pa.RecordBatch.from_pandas(
df,
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 128)),
pa.field("id", pa.int64()),
]
),
)
sink = pa.BufferOutputStream()
with pa.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
return web.Response(body=sink.getvalue().to_pybytes())
async def setup(self):
app = web.Application()
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
self.runner = web.AppRunner(app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", 8111)
async def start(self):
await self.site.start()
async def stop(self):
await self.runner.cleanup()
@pytest.mark.skip(reason="flaky somehow, fix later")
@pytest.mark.asyncio
async def test_e2e_with_mock_server():
mock_server = MockLanceDBServer()
await mock_server.setup()
await mock_server.start()
try:
client = RestfulLanceDBClient("lancedb+http://localhost:8111")
df = (
await client.query(
"test_table",
VectorQuery(
vector=np.random.rand(128).tolist(),
k=10,
_metric="L2",
columns=["id", "vector"],
),
)
).to_pandas()
assert "vector" in df.columns
assert "id" in df.columns
finally:
# make sure we don't leak resources
await mock_server.stop()

View File

@@ -1,42 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow as pa
import lancedb
from lancedb.remote.client import VectorQuery, VectorQueryResult
class FakeLanceDBClient:
def close(self):
pass
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
assert table_name == "test"
t = pa.schema([]).empty_table()
return VectorQueryResult(t)
def post(self, path: str):
pass
def mount_retry_adapter_for_table(self, table_name: str):
pass
def test_remote_db():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas()

View File

@@ -1,259 +0,0 @@
import os
import numpy as np
import pytest
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import (
CohereReranker,
ColbertReranker,
CrossEncoderReranker,
OpenaiReranker,
)
from lancedb.table import LanceTable
def get_test_table(tmp_path):
db = lancedb.connect(tmp_path)
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
# Initialize the table using the schema
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
# Need to test with a bunch of phrases to make sure sorting is consistent
phrases = [
"great kid don't get cocky",
"now that's a name I haven't heard in a long time",
"if you strike me down I shall become more powerful than you imagine",
"I find your lack of faith disturbing",
"I've got a bad feeling about this",
"never tell me the odds",
"I am your father",
"somebody has to save our skins",
"New strategy R2 let the wookiee win",
"Arrrrggghhhhhhh",
"I see a mansard roof through the trees",
"I see a salty message written in the eves",
"the ground beneath my feet",
"the hot garbage and concrete",
"and now the tops of buildings",
"everybody with a worried mind could never forgive the sight",
"of wicked snakes inside a place you thought was dignified",
"I don't wanna live like this",
"but I don't wanna die",
"The templars want control",
"the brotherhood of assassins want freedom",
"if only they could both see the world as it really is",
"there would be peace",
"but the war goes on",
"altair's legacy was a warning",
"Kratos had a son",
"he was a god",
"the god of war",
"but his son was mortal",
"there hasn't been a good battlefield game since 2142",
"I wish they would make another one",
"campains are not as good as they used to be",
"Multiplayer and open world games have destroyed the single player experience",
"Maybe the future is console games",
"I don't know",
]
# Add the phrases and vectors to the table
table.add([{"text": p} for p in phrases])
# Create a fts index
table.create_fts_index("text")
return table, MyTable
def test_linear_combination(tmp_path):
table, schema = get_test_table(tmp_path)
# The default reranker
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score")
.to_pydantic(schema)
)
result2 = ( # noqa
table.search("Our father who art in heaven.", query_type="hybrid")
.rerank(normalize="rank")
.to_pydantic(schema)
)
result3 = table.search(
"Our father who art in heaven..", query_type="hybrid"
).to_pydantic(schema)
assert result1 == result3 # 2 & 3 should be the same as they use score as score
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(normalize="score")
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_reranker(tmp_path):
pytest.importorskip("cohere")
table, schema = get_test_table(tmp_path)
# The default reranker
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CohereReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=CohereReranker())
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=CohereReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
def test_cross_encoder_reranker(tmp_path):
pytest.importorskip("sentence_transformers")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CrossEncoderReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=CrossEncoderReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), query_type="hybrid")
.limit(30)
.rerank(reranker=CrossEncoderReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=ColbertReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=ColbertReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=ColbertReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=OpenaiReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=OpenaiReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=OpenaiReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)

View File

@@ -1,926 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from copy import copy
from datetime import date, datetime, timedelta
from pathlib import Path
from time import sleep
from typing import List
from unittest.mock import PropertyMock, patch
import lance
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
from pydantic import BaseModel
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.table import LanceTable
class MockDB:
def __init__(self, uri: Path):
self.uri = uri
self.read_consistency_interval = None
@functools.cached_property
def is_managed_remote(self) -> bool:
return False
@pytest.fixture
def db(tmp_path) -> MockDB:
return MockDB(tmp_path)
def test_basic(db):
ds = LanceTable.create(
db,
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
).to_lance()
table = LanceTable(db, "test")
assert table.name == "test"
assert table.schema == ds.schema
assert table.to_lance().to_table() == ds.to_table()
def test_create_table(db):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float32()),
]
)
expected = pa.Table.from_arrays(
[
pa.FixedSizeListArray.from_arrays(pa.array([3.1, 4.1, 5.9, 26.5]), 2),
pa.array(["foo", "bar"]),
pa.array([10.0, 20.0]),
],
schema=schema,
)
data = [
[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
]
]
df = pd.DataFrame(data[0])
data.append(df)
data.append(pa.Table.from_pandas(df, schema=schema))
for i, d in enumerate(data):
tbl = (
LanceTable.create(db, f"test_{i}", data=d, schema=schema)
.to_lance()
.to_table()
)
assert expected == tbl
def test_empty_table(db):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float32()),
]
)
tbl = LanceTable.create(db, "test", schema=schema)
data = [
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
]
tbl.add(data=data)
def test_add(db):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float64()),
]
)
table = LanceTable.create(
db,
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
_add(table, schema)
table = LanceTable.create(db, "test2", schema=schema)
table.add(
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
_add(table, schema)
def test_add_pydantic_model(db):
# https://github.com/lancedb/lancedb/issues/562
class Metadata(BaseModel):
source: str
timestamp: datetime
class Document(BaseModel):
content: str
meta: Metadata
class LanceSchema(LanceModel):
id: str
vector: Vector(2)
li: List[int]
payload: Document
tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite")
assert tbl.schema == LanceSchema.to_arrow_schema()
# add works
expected = LanceSchema(
id="id",
vector=[0.0, 0.0],
li=[1, 2, 3],
payload=Document(
content="foo", meta=Metadata(source="bar", timestamp=datetime.now())
),
)
tbl.add([expected])
result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
assert result == expected
flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=1)
assert len(flattened.columns) == 6 # _distance is automatically added
really_flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=True)
assert len(really_flattened.columns) == 7
def test_polars(db):
data = {
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
# Ingest polars dataframe
table = LanceTable.create(db, "test", data=pl.DataFrame(data))
assert len(table) == 2
result = table.to_pandas()
assert np.allclose(result["vector"].tolist(), data["vector"])
assert result["item"].tolist() == data["item"]
assert np.allclose(result["price"].tolist(), data["price"])
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.large_string()),
pa.field("price", pa.float64()),
]
)
assert table.schema == schema
# search results to polars dataframe
q = [3.1, 4.1]
result = table.search(q).limit(1).to_polars()
assert np.allclose(result["vector"][0], q)
assert result["item"][0] == "foo"
assert np.allclose(result["price"][0], 10.0)
# enter table to polars dataframe
result = table.to_polars()
assert np.allclose(result.collect()["vector"].to_list(), data["vector"])
# make sure filtering isn't broken
filtered_result = result.filter(pl.col("item").is_in(["foo", "bar"])).collect()
assert len(filtered_result) == 2
def _add(table, schema):
# table = LanceTable(db, "test")
assert len(table) == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table) == 3
expected = pa.Table.from_arrays(
[
pa.FixedSizeListArray.from_arrays(
pa.array([3.1, 4.1, 5.9, 26.5, 6.3, 100.5]), 2
),
pa.array(["foo", "bar", "new"]),
pa.array([10.0, 20.0, 30.0]),
],
schema=schema,
)
assert expected == table.to_arrow()
def test_versioning(db):
table = LanceTable.create(
db,
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
assert len(table.list_versions()) == 2
assert table.version == 2
table.add([{"vector": [6.3, 100.5], "item": "new", "price": 30.0}])
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 3
table.checkout(2)
assert table.version == 2
assert len(table) == 2
def test_create_index_method():
with patch.object(
LanceTable, "_dataset_mut", new_callable=PropertyMock
) as mock_dataset:
# Setup mock responses
mock_dataset.return_value.create_index.return_value = None
# Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table")
# Call the create_index method
table.create_index(
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name="vector",
replace=True,
index_cache_size=256,
)
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
metric="L2",
num_partitions=256,
num_sub_vectors=96,
replace=True,
accelerator=None,
index_cache_size=256,
)
def test_add_with_nans(db):
# by default we raise an error on bad input vectors
bad_data = [
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
]
for row in bad_data:
with pytest.raises(ValueError):
LanceTable.create(
db,
"error_test",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
)
table = LanceTable.create(
db,
"drop_test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
],
on_bad_vectors="drop",
)
assert len(table) == 1
# We can fill bad input with some value
table = LanceTable.create(
db,
"fill_test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
],
on_bad_vectors="fill",
fill_value=0.0,
)
assert len(table) == 3
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
def test_restore(db):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "type": "vector"}],
)
table.add([{"vector": [0.5, 0.2], "type": "vector"}])
table.restore(2)
assert len(table.list_versions()) == 4
assert len(table) == 1
expected = table.to_arrow()
table.checkout(2)
table.restore()
assert len(table.list_versions()) == 5
assert table.to_arrow() == expected
table.restore(5) # latest version should be no-op
assert len(table.list_versions()) == 5
with pytest.raises(ValueError):
table.restore(6)
with pytest.raises(ValueError):
table.restore(0)
def test_merge(db, tmp_path):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
other_table = pa.table({"document": ["foo", "bar"], "id": [0, 1]})
table.merge(other_table, left_on="id")
assert len(table.list_versions()) == 3
expected = pa.table(
{"vector": [[1.1, 0.9], [1.2, 1.9]], "id": [0, 1], "document": ["foo", "bar"]},
schema=table.schema,
)
assert table.to_arrow() == expected
other_dataset = lance.write_dataset(other_table, tmp_path / "other_table.lance")
table.restore(1)
table.merge(other_dataset, left_on="id")
def test_delete(db):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 2
table.delete("id=0")
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 1
assert table.to_pandas()["id"].tolist() == [1]
def test_update(db):
table = LanceTable.create(
db,
"my_table",
data=[{"vector": [1.1, 0.9], "id": 0}, {"vector": [1.2, 1.9], "id": 1}],
)
assert len(table) == 2
assert len(table.list_versions()) == 2
table.update(where="id=0", values={"vector": [1.1, 1.1]})
assert len(table.list_versions()) == 3
assert table.version == 3
assert len(table) == 2
v = table.to_arrow()["vector"].combine_chunks()
v = v.values.to_numpy().reshape(2, 2)
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
def test_update_types(db):
table = LanceTable.create(
db,
"my_table",
data=[
{
"id": 0,
"str": "foo",
"float": 1.1,
"timestamp": datetime(2021, 1, 1),
"date": date(2021, 1, 1),
"vector1": [1.0, 0.0],
"vector2": [1.0, 1.0],
}
],
)
# Update with SQL
table.update(
values_sql=dict(
id="1",
str="'bar'",
float="2.2",
timestamp="TIMESTAMP '2021-01-02 00:00:00'",
date="DATE '2021-01-02'",
vector1="[2.0, 2.0]",
vector2="[3.0, 3.0]",
)
)
actual = table.to_arrow().to_pylist()[0]
expected = dict(
id=1,
str="bar",
float=2.2,
timestamp=datetime(2021, 1, 2),
date=date(2021, 1, 2),
vector1=[2.0, 2.0],
vector2=[3.0, 3.0],
)
assert actual == expected
# Update with values
table.update(
values=dict(
id=2,
str="baz",
float=3.3,
timestamp=datetime(2021, 1, 3),
date=date(2021, 1, 3),
vector1=[3.0, 3.0],
vector2=np.array([4.0, 4.0]),
)
)
actual = table.to_arrow().to_pylist()[0]
expected = dict(
id=2,
str="baz",
float=3.3,
timestamp=datetime(2021, 1, 3),
date=date(2021, 1, 3),
vector1=[3.0, 3.0],
vector2=[4.0, 4.0],
)
assert actual == expected
def test_merge_insert(db):
table = LanceTable.create(
db,
"my_table",
data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}),
)
assert len(table) == 3
version = table.version
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
# upsert
table.merge_insert(
"a"
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
# conditional update
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
new_data
)
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
# insert-if-not-exists
table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
# replace-range
table.merge_insert(
"a"
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete(
"a > 2"
).execute(new_data)
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
# replace-range no condition
table.merge_insert(
"a"
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete().execute(
new_data
)
expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
assert table.to_arrow().sort_by("a") == expected
def test_create_with_embedding_function(db):
class MyTable(LanceModel):
text: str
vector: Vector(10)
func = MockTextEmbeddingFunction()
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts, "vector": func.compute_source_embeddings(texts)})
conf = EmbeddingFunctionConfig(
source_column="text", vector_column="vector", function=func
)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
embedding_functions=[conf],
)
table.add(df)
query_str = "hi how are you?"
query_vector = func.compute_query_embeddings(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_create_f16_table(db):
class MyTable(LanceModel):
text: str
vector: Vector(128, value_type=pa.float16())
df = pd.DataFrame(
{
"text": [f"s-{i}" for i in range(10000)],
"vector": [np.random.randn(128).astype(np.float16) for _ in range(10000)],
}
)
table = LanceTable.create(
db,
"f16_tbl",
schema=MyTable,
)
table.add(df)
table.create_index(num_partitions=2, num_sub_vectors=8)
query = df.vector.iloc[2]
expected = table.search(query).limit(2).to_arrow()
assert "s-2" in expected["text"].to_pylist()
def test_add_with_embedding_function(db):
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
table = LanceTable.create(db, "my_table", schema=MyTable)
texts = ["hello world", "goodbye world", "foo bar baz fizz buzz"]
df = pd.DataFrame({"text": texts})
table.add(df)
texts = ["the quick brown fox", "jumped over the lazy dog"]
table.add([{"text": t} for t in texts])
query_str = "hi how are you?"
query_vector = emb.compute_query_embeddings(query_str)[0]
expected = table.search(query_vector).limit(2).to_arrow()
actual = table.search(query_str).limit(2).to_arrow()
assert actual == expected
def test_multiple_vector_columns(db):
class MyTable(LanceModel):
text: str
vector1: Vector(10)
vector2: Vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
assert result1["text"].iloc[0] != result2["text"].iloc[0]
def test_create_scalar_index(db):
vec_array = pa.array(
[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], pa.list_(pa.float32(), 2)
)
test_data = pa.Table.from_pydict(
{"x": ["c", "b", "a", "e", "b"], "y": [1, 2, 3, 4, 5], "vector": vec_array}
)
table = LanceTable.create(
db,
"my_table",
data=test_data,
)
table.create_scalar_index("x")
indices = table.to_lance().list_indices()
assert len(indices) == 1
scalar_index = indices[0]
assert scalar_index["type"] == "Scalar"
# Confirm that prefiltering still works with the scalar index column
results = table.search().where("x = 'c'").to_arrow()
assert results == test_data.slice(0, 1)
results = table.search([5, 5]).to_arrow()
assert results["_distance"][0].as_py() == 0
results = table.search([5, 5]).where("x != 'b'").to_arrow()
assert results["_distance"][0].as_py() > 0
def test_empty_query(db):
table = LanceTable.create(
db,
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
val = df.id.iloc[0]
assert val == 1
table = LanceTable.create(db, "my_table2", data=[{"id": i} for i in range(100)])
df = table.search().select(["id"]).to_pandas()
assert len(df) == 10
df = table.search().select(["id"]).limit(None).to_pandas()
assert len(df) == 100
df = table.search().select(["id"]).limit(-1).to_pandas()
assert len(df) == 100
def test_search_with_schema_inf_single_vector(db):
class MyTable(LanceModel):
text: str
vector_col: Vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector_col": v1, "text": "foo"},
{"vector_col": v2, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
result2 = table.search(q).limit(1).to_pandas()
assert result1["text"].iloc[0] == result2["text"].iloc[0]
def test_search_with_schema_inf_multiple_vector(db):
class MyTable(LanceModel):
text: str
vector1: Vector(10)
vector2: Vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
with pytest.raises(ValueError):
table.search(q).limit(1).to_pandas()
def test_compact_cleanup(db):
table = LanceTable.create(
db,
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
table.add([{"text": "baz", "id": 2}])
assert len(table) == 3
assert table.version == 3
stats = table.compact_files()
assert len(table) == 3
# Compact_files bump 2 versions.
assert table.version == 5
assert stats.fragments_removed > 0
assert stats.fragments_added == 1
stats = table.cleanup_old_versions()
assert stats.bytes_removed == 0
stats = table.cleanup_old_versions(older_than=timedelta(0), delete_unverified=True)
assert stats.bytes_removed > 0
assert table.version == 5
with pytest.raises(Exception, match="Version 3 no longer exists"):
table.checkout(3)
def test_count_rows(db):
table = LanceTable.create(
db,
"my_table",
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
)
assert len(table) == 2
assert table.count_rows() == 2
assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db, tmp_path):
db = MockDB(str(tmp_path))
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
# Initialize the table using the schema
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
# Create a list of 10 unique english phrases
phrases = [
"great kid don't get cocky",
"now that's a name I haven't heard in a long time",
"if you strike me down I shall become more powerful than you imagine",
"I find your lack of faith disturbing",
"I've got a bad feeling about this",
"never tell me the odds",
"I am your father",
"somebody has to save our skins",
"New strategy R2 let the wookiee win",
"Arrrrggghhhhhhh",
]
# Add the phrases and vectors to the table
table.add([{"text": p} for p in phrases])
# Create a fts index
table.create_fts_index("text")
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score")
.to_pydantic(MyTable)
)
result2 = ( # noqa
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank")
.to_pydantic(MyTable)
)
result3 = table.search(
"Our father who art in heaven", query_type="hybrid"
).to_pydantic(MyTable)
assert result1 == result3
@pytest.mark.parametrize(
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
)
def test_consistency(tmp_path, consistency_interval):
db = lancedb.connect(tmp_path)
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
table2 = db2.open_table("my_table")
assert table2.version == table.version
table.add([{"id": 1}])
if consistency_interval is None:
assert table2.version == table.version - 1
table2.checkout_latest()
assert table2.version == table.version
elif consistency_interval == timedelta(seconds=0):
assert table2.version == table.version
else:
# (consistency_interval == timedelta(seconds=0.1)
assert table2.version == table.version - 1
sleep(0.1)
assert table2.version == table.version
def test_restore_consistency(tmp_path):
db = lancedb.connect(tmp_path)
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
table2 = db2.open_table("my_table")
assert table2.version == table.version
# If we call checkout, it should lose consistency
table_fixed = copy(table2)
table_fixed.checkout(table.version)
# But if we call checkout_latest, it should be consistent again
table_ref_latest = copy(table_fixed)
table_ref_latest.checkout_latest()
table.add([{"id": 2}])
assert table_fixed.version == table.version - 1
assert table_ref_latest.version == table.version
# Schema evolution
def test_add_columns(tmp_path):
db = lancedb.connect(tmp_path)
data = pa.table({"id": [0, 1]})
table = LanceTable.create(db, "my_table", data=data)
table.add_columns({"new_col": "id + 2"})
assert table.to_arrow().column_names == ["id", "new_col"]
assert table.to_arrow()["new_col"].to_pylist() == [2, 3]
def test_alter_columns(tmp_path):
db = lancedb.connect(tmp_path)
data = pa.table({"id": [0, 1]})
table = LanceTable.create(db, "my_table", data=data)
table.alter_columns({"path": "id", "rename": "new_id"})
assert table.to_arrow().column_names == ["new_id"]
def test_drop_columns(tmp_path):
db = lancedb.connect(tmp_path)
data = pa.table({"id": [0, 1], "category": ["a", "b"]})
table = LanceTable.create(db, "my_table", data=data)
table.drop_columns(["category"])
assert table.to_arrow().column_names == ["id"]

View File

@@ -1,61 +0,0 @@
import json
import pytest
import lancedb
from lancedb.utils.events import _Events
@pytest.fixture(autouse=True)
def request_log_path(tmp_path):
return tmp_path / "request.json"
def mock_register_event(name: str, **kwargs):
if _Events._instance is None:
_Events._instance = _Events()
_Events._instance.enabled = True
_Events._instance.rate_limit = 0
_Events._instance(name, **kwargs)
def test_event_reporting(monkeypatch, request_log_path, tmp_path) -> None:
def mock_request(**kwargs):
json_data = kwargs.get("json", {})
with open(request_log_path, "w") as f:
json.dump(json_data, f)
monkeypatch.setattr(
lancedb.table, "register_event", mock_register_event
) # Force enable registering events and strip exception handling
monkeypatch.setattr(lancedb.utils.events, "threaded_request", mock_request)
db = lancedb.connect(tmp_path)
db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
mode="overwrite",
)
assert request_log_path.exists() # test if event was registered
with open(request_log_path, "r") as f:
json_data = json.load(f)
# TODO: don't hardcode these here. Instead create a module level json scehma in
# lancedb.utils.events for better evolvability
batch_keys = ["api_key", "distinct_id", "batch"]
event_keys = ["event", "properties", "timestamp", "distinct_id"]
property_keys = ["cli", "install", "platforms", "version", "session_id"]
assert all([key in json_data for key in batch_keys])
assert all([key in json_data["batch"][0] for key in event_keys])
assert all([key in json_data["batch"][0]["properties"] for key in property_keys])
# cleanup & reset
monkeypatch.undo()
_Events._instance = None

View File

@@ -1,87 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import pytest
from lancedb.util import get_uri_scheme, join_uri
def test_normalize_uri():
uris = [
"relative/path",
"/absolute/path",
"file:///absolute/path",
"s3://bucket/path",
"gs://bucket/path",
"c:\\windows\\path",
]
schemes = ["file", "file", "file", "s3", "gs", "file"]
for uri, expected_scheme in zip(uris, schemes):
parsed_scheme = get_uri_scheme(uri)
assert parsed_scheme == expected_scheme
def test_join_uri_remote():
schemes = ["s3", "az", "gs"]
for scheme in schemes:
expected = f"{scheme}://bucket/path/to/table.lance"
base_uri = f"{scheme}://bucket/path/to/"
parts = ["table.lance"]
assert join_uri(base_uri, *parts) == expected
base_uri = f"{scheme}://bucket"
parts = ["path", "to", "table.lance"]
assert join_uri(base_uri, *parts) == expected
# skip this test if on windows
@pytest.mark.skipif(os.name == "nt", reason="Windows paths are not POSIX")
def test_join_uri_posix():
for base in [
# relative path
"relative/path",
"relative/path/",
# an absolute path
"/absolute/path",
"/absolute/path/",
# a file URI
"file:///absolute/path",
"file:///absolute/path/",
]:
joined = join_uri(base, "table.lance")
assert joined == str(pathlib.Path(base) / "table.lance")
joined = join_uri(pathlib.Path(base), "table.lance")
assert joined == pathlib.Path(base) / "table.lance"
# skip this test if not on windows
@pytest.mark.skipif(os.name != "nt", reason="Windows paths are not POSIX")
def test_local_join_uri_windows():
# https://learn.microsoft.com/en-us/dotnet/standard/io/file-path-formats
for base in [
# windows relative path
"relative\\path",
"relative\\path\\",
# windows absolute path from current drive
"c:\\absolute\\path",
# relative path from root of current drive
"\\relative\\path",
]:
joined = join_uri(base, "table.lance")
assert joined == str(pathlib.Path(base) / "table.lance")
joined = join_uri(pathlib.Path(base), "table.lance")
assert joined == pathlib.Path(base) / "table.lance"