feat(python): async-sync feature parity on Table (#1914)

### Changes to sync API
* Update the `LanceTable` and `LanceDBConnection` reprs
* Add `storage_options`, `data_storage_version`, and
`enable_v2_manifest_paths` to sync create table API.
* Add `storage_options` to `open_table` in sync API.
* Add `list_indices()` and `index_stats()` to sync API
* `create_table()` now creates only one version when data is passed.
Previously it always created two versions: one to create an empty
table and one to add the data to it.

### Changes to async API
* Add `embedding_functions` to async `create_table()` API.
* Add `head()` to async API

### Refactors
* Refactor index parameters into dataclasses so they are easier to use
from Python
* Move most tests to an in-memory DB so they no longer need to create so
many temp directories

Closes #1792
Closes #1932

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Will Jones
2024-12-13 12:56:44 -08:00
committed by GitHub
parent d83e5a0208
commit 980aa70e2d
23 changed files with 1296 additions and 1324 deletions

View File

@@ -98,7 +98,7 @@ def test_ingest_pd(tmp_path):
assert db.open_table("test").name == db["test"].name
def test_ingest_iterator(tmp_path):
def test_ingest_iterator(mem_db: lancedb.DBConnection):
class PydanticSchema(LanceModel):
vector: Vector(2)
item: str
@@ -156,8 +156,7 @@ def test_ingest_iterator(tmp_path):
]
def run_tests(schema):
db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl = mem_db.create_table("table2", make_batches(), schema=schema)
tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
@@ -165,15 +164,14 @@ def test_ingest_iterator(tmp_path):
tbl.add(make_batches())
assert tbl_len == 50
assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 3
db.drop_database()
assert len(tbl.list_versions()) == 2
mem_db.drop_database()
run_tests(arrow_schema)
run_tests(PydanticSchema)
def test_table_names(tmp_path):
db = lancedb.connect(tmp_path)
def test_table_names(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -181,10 +179,10 @@ def test_table_names(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test2", data=data)
db.create_table("test1", data=data)
db.create_table("test3", data=data)
assert db.table_names() == ["test1", "test2", "test3"]
tmp_db.create_table("test2", data=data)
tmp_db.create_table("test1", data=data)
tmp_db.create_table("test3", data=data)
assert tmp_db.table_names() == ["test1", "test2", "test3"]
@pytest.mark.asyncio
@@ -209,8 +207,7 @@ async def test_table_names_async(tmp_path):
assert await db.table_names(start_after="test1") == ["test2", "test3"]
def test_create_mode(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_mode(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -218,10 +215,10 @@ def test_create_mode(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
new_data = pd.DataFrame(
{
@@ -230,13 +227,11 @@ def test_create_mode(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=new_data, mode="overwrite")
tbl = tmp_db.create_table("test", data=new_data, mode="overwrite")
assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
def test_create_table_from_iterator(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_table_from_iterator(mem_db: lancedb.DBConnection):
def gen_data():
for _ in range(10):
yield pa.RecordBatch.from_arrays(
@@ -248,14 +243,12 @@ def test_create_table_from_iterator(tmp_path):
["vector", "item", "price"],
)
table = db.create_table("test", data=gen_data())
table = mem_db.create_table("test", data=gen_data())
assert table.count_rows() == 10
@pytest.mark.asyncio
async def test_create_table_from_iterator_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConnection):
def gen_data():
for _ in range(10):
yield pa.RecordBatch.from_arrays(
@@ -267,12 +260,11 @@ async def test_create_table_from_iterator_async(tmp_path):
["vector", "item", "price"],
)
table = await db.create_table("test", data=gen_data())
table = await mem_db_async.create_table("test", data=gen_data())
assert await table.count_rows() == 10
def test_create_exist_ok(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_exist_ok(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -280,13 +272,13 @@ def test_create_exist_ok(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=data)
tbl = tmp_db.create_table("test", data=data)
with pytest.raises(OSError):
db.create_table("test", data=data)
with pytest.raises(ValueError):
tmp_db.create_table("test", data=data)
# open the table but don't add more rows
tbl2 = db.create_table("test", data=data, exist_ok=True)
tbl2 = tmp_db.create_table("test", data=data, exist_ok=True)
assert tbl.name == tbl2.name
assert tbl.schema == tbl2.schema
assert len(tbl) == len(tbl2)
@@ -298,7 +290,7 @@ def test_create_exist_ok(tmp_path):
pa.field("price", pa.float64()),
]
)
tbl3 = db.create_table("test", schema=schema, exist_ok=True)
tbl3 = tmp_db.create_table("test", schema=schema, exist_ok=True)
assert tbl3.schema == schema
bad_schema = pa.schema(
@@ -310,7 +302,7 @@ def test_create_exist_ok(tmp_path):
]
)
with pytest.raises(ValueError):
db.create_table("test", schema=bad_schema, exist_ok=True)
tmp_db.create_table("test", schema=bad_schema, exist_ok=True)
@pytest.mark.asyncio
@@ -325,26 +317,24 @@ async def test_connect(tmp_path):
@pytest.mark.asyncio
async def test_close(tmp_path):
db = await lancedb.connect_async(tmp_path)
assert db.is_open()
db.close()
assert not db.is_open()
async def test_close(mem_db_async: lancedb.AsyncConnection):
assert mem_db_async.is_open()
mem_db_async.close()
assert not mem_db_async.is_open()
with pytest.raises(RuntimeError, match="is closed"):
await db.table_names()
await mem_db_async.table_names()
@pytest.mark.asyncio
async def test_context_manager(tmp_path):
with await lancedb.connect_async(tmp_path) as db:
async def test_context_manager():
with await lancedb.connect_async("memory://") as db:
assert db.is_open()
assert not db.is_open()
@pytest.mark.asyncio
async def test_create_mode_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -352,10 +342,10 @@ async def test_create_mode_async(tmp_path):
"price": [10.0, 20.0],
}
)
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
new_data = pd.DataFrame(
{
@@ -364,15 +354,14 @@ async def test_create_mode_async(tmp_path):
"price": [10.0, 20.0],
}
)
_tbl = await db.create_table("test", data=new_data, mode="overwrite")
_tbl = await tmp_db_async.create_table("test", data=new_data, mode="overwrite")
# MIGRATION: to_pandas() is not available in async
# assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
@pytest.mark.asyncio
async def test_create_exist_ok_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -380,13 +369,13 @@ async def test_create_exist_ok_async(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = await db.create_table("test", data=data)
tbl = await tmp_db_async.create_table("test", data=data)
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
# open the table but don't add more rows
tbl2 = await db.create_table("test", data=data, exist_ok=True)
tbl2 = await tmp_db_async.create_table("test", data=data, exist_ok=True)
assert tbl.name == tbl2.name
assert await tbl.schema() == await tbl2.schema()
@@ -397,7 +386,7 @@ async def test_create_exist_ok_async(tmp_path):
pa.field("price", pa.float64()),
]
)
tbl3 = await db.create_table("test", schema=schema, exist_ok=True)
tbl3 = await tmp_db_async.create_table("test", schema=schema, exist_ok=True)
assert await tbl3.schema() == schema
# Migration: When creating a table, but the table already exists, but
@@ -448,13 +437,12 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
assert re.match(r"\d{20}\.manifest", manifest)
def test_open_table_sync(tmp_path):
db = lancedb.connect(tmp_path)
db.create_table("test", data=[{"id": 0}])
assert db.open_table("test").count_rows() == 1
assert db.open_table("test", index_cache_size=0).count_rows() == 1
with pytest.raises(FileNotFoundError, match="does not exist"):
db.open_table("does_not_exist")
def test_open_table_sync(tmp_db: lancedb.DBConnection):
tmp_db.create_table("test", data=[{"id": 0}])
assert tmp_db.open_table("test").count_rows() == 1
assert tmp_db.open_table("test", index_cache_size=0).count_rows() == 1
with pytest.raises(ValueError, match="Table 'does_not_exist' was not found"):
tmp_db.open_table("does_not_exist")
@pytest.mark.asyncio
@@ -494,8 +482,7 @@ async def test_open_table(tmp_path):
await db.open_table("does_not_exist")
def test_delete_table(tmp_path):
db = lancedb.connect(tmp_path)
def test_delete_table(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -503,26 +490,25 @@ def test_delete_table(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
assert db.table_names() == ["test"]
assert tmp_db.table_names() == ["test"]
db.drop_table("test")
assert db.table_names() == []
tmp_db.drop_table("test")
assert tmp_db.table_names() == []
db.create_table("test", data=data)
assert db.table_names() == ["test"]
tmp_db.create_table("test", data=data)
assert tmp_db.table_names() == ["test"]
# dropping a table that does not exist should pass
# if ignore_missing=True
db.drop_table("does_not_exist", ignore_missing=True)
tmp_db.drop_table("does_not_exist", ignore_missing=True)
def test_drop_database(tmp_path):
db = lancedb.connect(tmp_path)
def test_drop_database(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -537,51 +523,50 @@ def test_drop_database(tmp_path):
"price": [12.0, 17.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
assert db.table_names() == ["test"]
assert tmp_db.table_names() == ["test"]
db.create_table("new_test", data=new_data)
db.drop_database()
assert db.table_names() == []
tmp_db.create_table("new_test", data=new_data)
tmp_db.drop_database()
assert tmp_db.table_names() == []
# it should pass when no tables are present
db.create_table("test", data=new_data)
db.drop_table("test")
assert db.table_names() == []
db.drop_database()
assert db.table_names() == []
tmp_db.create_table("test", data=new_data)
tmp_db.drop_table("test")
assert tmp_db.table_names() == []
tmp_db.drop_database()
assert tmp_db.table_names() == []
# creating an empty database with schema
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
db.create_table("empty_table", schema=schema)
tmp_db.create_table("empty_table", schema=schema)
# dropping a empty database should pass
db.drop_database()
assert db.table_names() == []
tmp_db.drop_database()
assert tmp_db.table_names() == []
def test_empty_or_nonexistent_table(tmp_path):
db = lancedb.connect(tmp_path)
def test_empty_or_nonexistent_table(mem_db: lancedb.DBConnection):
with pytest.raises(Exception):
db.create_table("test_with_no_data")
mem_db.create_table("test_with_no_data")
with pytest.raises(Exception):
db.open_table("does_not_exist")
mem_db.open_table("does_not_exist")
schema = pa.schema([pa.field("a", pa.int64(), nullable=False)])
test = db.create_table("test", schema=schema)
test = mem_db.create_table("test", schema=schema)
class TestModel(LanceModel):
a: int
test2 = db.create_table("test2", schema=TestModel)
test2 = mem_db.create_table("test2", schema=TestModel)
assert test.schema == test2.schema
@pytest.mark.asyncio
async def test_create_in_v2_mode(tmp_path):
async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
def make_data():
for i in range(10):
yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
@@ -591,10 +576,8 @@ async def test_create_in_v2_mode(tmp_path):
schema = pa.schema([pa.field("x", pa.int64())])
db = await lancedb.connect_async(tmp_path)
# Create table in v1 mode
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test", data=make_data(), schema=schema, data_storage_version="legacy"
)
@@ -610,7 +593,7 @@ async def test_create_in_v2_mode(tmp_path):
assert not await is_in_v2_mode(tbl)
# Create table in v2 mode
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_v2", data=make_data(), schema=schema, use_legacy_format=False
)
@@ -622,7 +605,7 @@ async def test_create_in_v2_mode(tmp_path):
assert await is_in_v2_mode(tbl)
# Create empty table in v2 mode and add data
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_empty_v2", data=None, schema=schema, use_legacy_format=False
)
await tbl.add(make_table())
@@ -630,7 +613,7 @@ async def test_create_in_v2_mode(tmp_path):
assert await is_in_v2_mode(tbl)
# Create empty table uses v1 mode by default
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy"
)
await tbl.add(make_table())
@@ -638,18 +621,17 @@ async def test_create_in_v2_mode(tmp_path):
assert not await is_in_v2_mode(tbl)
def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
table = db.create_table(
def test_replace_index(mem_db: lancedb.DBConnection):
table = mem_db.create_table(
"test",
[
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
{"vector": np.random.rand(32), "item": "foo", "price": float(i)}
for i in range(512)
],
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_sub_vectors=2,
)
with pytest.raises(Exception):
@@ -660,27 +642,26 @@ def test_replace_index(tmp_path):
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_partitions=1,
num_sub_vectors=2,
replace=True,
index_cache_size=10,
)
def test_prefilter_with_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
def test_prefilter_with_index(mem_db: lancedb.DBConnection):
data = [
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
{"vector": np.random.rand(32), "item": "foo", "price": float(i)}
for i in range(512)
]
sample_key = data[100]["vector"]
table = db.create_table(
table = mem_db.create_table(
"test",
data,
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_sub_vectors=2,
)
table = (
table.search(sample_key)
@@ -691,13 +672,12 @@ def test_prefilter_with_index(tmp_path):
assert table.num_rows == 1
def test_create_table_with_invalid_names(tmp_path):
db = lancedb.connect(uri=tmp_path)
def test_create_table_with_invalid_names(tmp_db: lancedb.DBConnection):
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
with pytest.raises(ValueError):
db.create_table("foo/bar", data)
tmp_db.create_table("foo/bar", data)
with pytest.raises(ValueError):
db.create_table("foo bar", data)
tmp_db.create_table("foo bar", data)
with pytest.raises(ValueError):
db.create_table("foo$$bar", data)
db.create_table("foo.bar", data)
tmp_db.create_table("foo$$bar", data)
tmp_db.create_table("foo.bar", data)