mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
This mimics the behavior of CREATE TABLE IF NOT EXISTS. We add an `exist_ok` parameter to `db.create_table(...)`, used as `db.create_table(..., exist_ok=True)`. It defaults to False, so trying to create a table whose name is already taken raises an exception. If it is set to True and the table already exists, the existing table is simply opened instead of raising. If you pass in a schema, it is checked against the existing table's schema to make sure you get what you want. If you pass in data, it will NOT be added to the existing table.
372 lines
10 KiB
Python
372 lines
10 KiB
Python
# Copyright 2023 LanceDB Developers
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pyarrow as pa
|
|
import pytest
|
|
|
|
import lancedb
|
|
from lancedb.pydantic import LanceModel, Vector
|
|
|
|
|
|
def test_basic(tmp_path):
    """Create a table from a list of dicts, then check search and catalog access."""
    db = lancedb.connect(tmp_path)

    # A fresh connection echoes the path it was given and lists no tables.
    assert db.uri == str(tmp_path)
    assert db.table_names() == []

    rows = [
        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
    ]
    table = db.create_table("test", data=rows)
    query = [100, 100]

    # Unfiltered search capped at one hit is expected to return "bar".
    top_hit = table.search(query).limit(1).to_pandas()
    assert len(top_hit) == 1
    assert top_hit["item"].iloc[0] == "bar"

    # With the price filter only "foo" qualifies, even though the limit is 2.
    cheap_hits = table.search(query).where("price < 15").limit(2).to_pandas()
    assert len(cheap_hits) == 1
    assert cheap_hits["item"].iloc[0] == "foo"

    # The connection supports listing, membership tests and len().
    assert db.table_names() == ["test"]
    assert "test" in db
    assert len(db) == 1

    # open_table() and item access both resolve to a table of the same name.
    opened = db.open_table("test")
    indexed = db["test"]
    assert opened.name == indexed.name
|
|
|
|
|
|
def test_ingest_pd(tmp_path):
    """Create a table from a pandas DataFrame, then check search and catalog access."""
    db = lancedb.connect(tmp_path)

    # A fresh connection echoes the path it was given and lists no tables.
    assert db.uri == str(tmp_path)
    assert db.table_names() == []

    frame = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    table = db.create_table("test", data=frame)
    query = [100, 100]

    # Unfiltered search capped at one hit is expected to return "bar".
    top_hit = table.search(query).limit(1).to_pandas()
    assert len(top_hit) == 1
    assert top_hit["item"].iloc[0] == "bar"

    # With the price filter only "foo" qualifies, even though the limit is 2.
    cheap_hits = table.search(query).where("price < 15").limit(2).to_pandas()
    assert len(cheap_hits) == 1
    assert cheap_hits["item"].iloc[0] == "foo"

    # The connection supports listing, membership tests and len().
    assert db.table_names() == ["test"]
    assert "test" in db
    assert len(db) == 1

    # open_table() and item access both resolve to a table of the same name.
    opened = db.open_table("test")
    indexed = db["test"]
    assert opened.name == indexed.name
|
|
|
|
|
|
def test_ingest_iterator(tmp_path):
    """Ingest an iterator of mixed batch types under an Arrow and a pydantic schema."""

    class PydanticSchema(LanceModel):
        vector: Vector(2)
        item: str
        price: float

    arrow_schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2)),
            pa.field("item", pa.utf8()),
            pa.field("price", pa.float32()),
        ]
    )

    def make_batches():
        # Five rounds; each round yields five two-row batches, one per supported
        # container type, for 50 rows in total.
        for _ in range(5):
            batches = [
                # pandas
                pd.DataFrame(
                    {
                        "vector": [[3.1, 4.1], [1, 1]],
                        "item": ["foo", "bar"],
                        "price": [10.0, 20.0],
                    }
                ),
                # pylist
                [
                    {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
                    {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
                ],
                # recordbatch
                pa.RecordBatch.from_arrays(
                    [
                        pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
                        pa.array(["foo", "bar"]),
                        pa.array([10.0, 20.0]),
                    ],
                    ["vector", "item", "price"],
                ),
                # pa Table
                pa.Table.from_arrays(
                    [
                        pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
                        pa.array(["foo", "bar"]),
                        pa.array([10.0, 20.0]),
                    ],
                    ["vector", "item", "price"],
                ),
                # pydantic list
                [
                    PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
                    PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
                ],
                # TODO: test pydict separately. it is unique column number and
                # names constraint
            ]
            yield from batches

    def run_tests(schema):
        db = lancedb.connect(tmp_path)
        tbl = db.create_table(
            "table2",
            make_batches(),
            schema=schema,
            mode="overwrite",
        )
        # Smoke check: the ingested table converts to pandas.
        tbl.to_pandas()

        # Searching for a stored vector must return a hit at distance zero.
        for stored_vector in ([3.1, 4.1], [5.9, 26.5]):
            nearest = tbl.search(stored_vector).limit(1).to_pandas()
            assert nearest["_distance"][0] == 0.0

        rows_before = len(tbl)
        tbl.add(make_batches())
        assert rows_before == 50
        # Adding the same batches again doubles the row count.
        assert len(tbl) == rows_before * 2
        assert len(tbl.list_versions()) == 3

        # Leave a clean directory for the next schema run.
        db.drop_database()

    for schema in (arrow_schema, PydanticSchema):
        run_tests(schema)
|
|
|
|
|
|
def test_table_names(tmp_path):
    """table_names() lists tables in sorted order, not in creation order."""
    db = lancedb.connect(tmp_path)
    frame = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )

    # Create the tables deliberately out of order.
    for name in ("test2", "test1", "test3"):
        db.create_table(name, data=frame)

    assert db.table_names() == ["test1", "test2", "test3"]
|
|
|
|
|
|
def test_create_mode(tmp_path):
    """Re-creating an existing table fails unless ``mode="overwrite"`` is passed."""
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test", data=data)

    # Creating a table under a taken name must fail. Expect OSError, the
    # type test_create_exist_ok asserts for this same call, rather than a
    # bare Exception that would also pass on any unrelated error.
    with pytest.raises(OSError):
        db.create_table("test", data=data)

    new_data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["fizz", "buzz"],
            "price": [10.0, 20.0],
        }
    )

    # Overwrite mode replaces the table contents with the new rows.
    tbl = db.create_table("test", data=new_data, mode="overwrite")
    assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
|
|
|
|
|
|
def test_create_exist_ok(tmp_path):
    """``exist_ok=True`` opens an existing table and validates a supplied schema."""
    db = lancedb.connect(tmp_path)
    frame = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    original = db.create_table("test", data=frame)

    # Without exist_ok, a second create under the same name raises.
    with pytest.raises(OSError):
        db.create_table("test", data=frame)

    # open the table but don't add more rows
    reopened = db.create_table("test", data=frame, exist_ok=True)
    assert original.name == reopened.name
    assert original.schema == reopened.schema
    assert len(original) == len(reopened)

    matching_fields = [
        pa.field("vector", pa.list_(pa.float32(), list_size=2)),
        pa.field("item", pa.utf8()),
        pa.field("price", pa.float64()),
    ]

    # A schema equal to the stored one is accepted.
    schema = pa.schema(matching_fields)
    by_schema = db.create_table("test", schema=schema, exist_ok=True)
    assert by_schema.schema == schema

    # A schema with an extra column is rejected with ValueError.
    bad_schema = pa.schema([*matching_fields, pa.field("extra", pa.float32())])
    with pytest.raises(ValueError):
        db.create_table("test", schema=bad_schema, exist_ok=True)
|
|
|
|
|
|
def test_delete_table(tmp_path):
    """A dropped table disappears from the listing and its name can be reused."""
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    db.create_table("test", data=data)

    # Creating a table under a taken name must fail. Expect OSError, the
    # type test_create_exist_ok asserts for this same call, rather than a
    # bare Exception that would also pass on any unrelated error.
    with pytest.raises(OSError):
        db.create_table("test", data=data)

    assert db.table_names() == ["test"]

    db.drop_table("test")
    assert db.table_names() == []

    # Once dropped, the name is free again.
    db.create_table("test", data=data)
    assert db.table_names() == ["test"]

    # Dropping a table that does not exist should pass
    # if ignore_missing=True.
    db.drop_table("does_not_exist", ignore_missing=True)
|
|
|
|
|
|
def test_drop_database(tmp_path):
    """drop_database() removes every table and is safe to call when none exist."""
    db = lancedb.connect(tmp_path)
    data = pd.DataFrame(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", "bar"],
            "price": [10.0, 20.0],
        }
    )
    new_data = pd.DataFrame(
        {
            "vector": [[5.1, 4.1], [5.9, 10.5]],
            "item": ["kiwi", "avocado"],
            "price": [12.0, 17.0],
        }
    )
    db.create_table("test", data=data)

    # Creating a table under a taken name must fail. Expect OSError, the
    # type test_create_exist_ok asserts for this same call, rather than a
    # bare Exception that would also pass on any unrelated error.
    with pytest.raises(OSError):
        db.create_table("test", data=data)

    assert db.table_names() == ["test"]

    # Dropping the database removes all tables at once.
    db.create_table("new_test", data=new_data)
    db.drop_database()
    assert db.table_names() == []

    # It should also pass when no tables are present.
    db.create_table("test", data=new_data)
    db.drop_table("test")
    assert db.table_names() == []
    db.drop_database()
    assert db.table_names() == []

    # Create an empty table from a schema alone.
    schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
    db.create_table("empty_table", schema=schema)

    # Dropping a database that holds only an empty table should pass.
    db.drop_database()
    assert db.table_names() == []
|
|
|
|
|
|
def test_empty_or_nonexistent_table(tmp_path):
    """Invalid create/open calls raise; schema-only creation yields empty tables."""
    db = lancedb.connect(tmp_path)

    # Neither data nor schema was supplied, so creation is expected to fail.
    with pytest.raises(Exception):
        db.create_table("test_with_no_data")

    # Opening a table that was never created is expected to fail.
    with pytest.raises(Exception):
        db.open_table("does_not_exist")

    arrow_schema = pa.schema([pa.field("a", pa.int64(), nullable=False)])
    from_arrow = db.create_table("test", schema=arrow_schema)

    class TestModel(LanceModel):
        a: int

    from_model = db.create_table("test2", schema=TestModel)

    # The Arrow schema and the pydantic model must produce equal table schemas.
    assert from_arrow.schema == from_model.schema
|
|
|
|
|
|
def test_replace_index(tmp_path):
    """A second create_index() call fails with replace=False, works with replace=True."""
    db = lancedb.connect(uri=tmp_path)
    rows = [
        {"vector": np.random.rand(128), "item": "foo", "price": float(i)}
        for i in range(1000)
    ]
    table = db.create_table("test", rows)

    # Index settings shared by all three create_index() calls below.
    index_params = {"num_partitions": 2, "num_sub_vectors": 4}
    table.create_index(**index_params)

    # An index already exists, so refusing to replace it must raise.
    with pytest.raises(Exception):
        table.create_index(replace=False, **index_params)

    # Explicit replacement of the existing index is allowed.
    table.create_index(replace=True, index_cache_size=10, **index_params)
|
|
|
|
|
|
def test_prefilter_with_index(tmp_path):
    """A prefiltered search on an indexed table returns only the matching row."""
    db = lancedb.connect(uri=tmp_path)
    rows = [
        {"vector": np.random.rand(128), "item": "foo", "price": float(i)}
        for i in range(1000)
    ]
    sample_key = rows[100]["vector"]

    table = db.create_table("test", rows)
    table.create_index(num_partitions=2, num_sub_vectors=4)

    # Exactly one row has price 500, so the prefilter leaves a single
    # candidate even though up to five results are requested.
    query = table.search(sample_key)
    query = query.where("price == 500", prefilter=True)
    matches = query.limit(5).to_arrow()

    assert matches.num_rows == 1
|