feat(python): async-sync feature parity on Table (#1914)

### Changes to sync API
* Updated `LanceTable` and `LanceDBConnection` reprs
* Add `storage_options`, `data_storage_version`, and
`enable_v2_manifest_paths` to sync create table API.
* Add `storage_options` to `open_table` in sync API.
* Add `list_indices()` and `index_stats()` to sync API
* `create_table()` will now create only 1 version when data is passed.
Previously it would always create two versions: 1 to create an empty
table and 1 to add data to it.

### Changes to async API
* Add `embedding_functions` to async `create_table()` API.
* Added `head()` to async API

### Refactors
* Refactor index parameters into dataclasses so they are easier to use
from Python
* Moved most tests to use an in-memory DB so we don't need to create so
many temp directories

Closes #1792
Closes #1932

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Will Jones
2024-12-13 12:56:44 -08:00
committed by GitHub
parent d83e5a0208
commit 980aa70e2d
23 changed files with 1296 additions and 1324 deletions

View File

@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from datetime import timedelta
from lancedb.db import AsyncConnection, DBConnection
import lancedb
import pytest
import pytest_asyncio
# Use an in-memory database for most tests.
@pytest.fixture
def mem_db() -> DBConnection:
return lancedb.connect("memory://")
# Use a temporary directory when we need to inspect the database files.
@pytest.fixture
def tmp_db(tmp_path) -> DBConnection:
return lancedb.connect(tmp_path)
@pytest_asyncio.fixture
async def mem_db_async() -> AsyncConnection:
return await lancedb.connect_async("memory://")
@pytest_asyncio.fixture
async def tmp_db_async(tmp_path) -> AsyncConnection:
return await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=0)
)

View File

@@ -98,7 +98,7 @@ def test_ingest_pd(tmp_path):
assert db.open_table("test").name == db["test"].name
def test_ingest_iterator(tmp_path):
def test_ingest_iterator(mem_db: lancedb.DBConnection):
class PydanticSchema(LanceModel):
vector: Vector(2)
item: str
@@ -156,8 +156,7 @@ def test_ingest_iterator(tmp_path):
]
def run_tests(schema):
db = lancedb.connect(tmp_path)
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
tbl = mem_db.create_table("table2", make_batches(), schema=schema)
tbl.to_pandas()
assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
@@ -165,15 +164,14 @@ def test_ingest_iterator(tmp_path):
tbl.add(make_batches())
assert tbl_len == 50
assert len(tbl) == tbl_len * 2
assert len(tbl.list_versions()) == 3
db.drop_database()
assert len(tbl.list_versions()) == 2
mem_db.drop_database()
run_tests(arrow_schema)
run_tests(PydanticSchema)
def test_table_names(tmp_path):
db = lancedb.connect(tmp_path)
def test_table_names(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -181,10 +179,10 @@ def test_table_names(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test2", data=data)
db.create_table("test1", data=data)
db.create_table("test3", data=data)
assert db.table_names() == ["test1", "test2", "test3"]
tmp_db.create_table("test2", data=data)
tmp_db.create_table("test1", data=data)
tmp_db.create_table("test3", data=data)
assert tmp_db.table_names() == ["test1", "test2", "test3"]
@pytest.mark.asyncio
@@ -209,8 +207,7 @@ async def test_table_names_async(tmp_path):
assert await db.table_names(start_after="test1") == ["test2", "test3"]
def test_create_mode(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_mode(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -218,10 +215,10 @@ def test_create_mode(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
new_data = pd.DataFrame(
{
@@ -230,13 +227,11 @@ def test_create_mode(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=new_data, mode="overwrite")
tbl = tmp_db.create_table("test", data=new_data, mode="overwrite")
assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
def test_create_table_from_iterator(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_table_from_iterator(mem_db: lancedb.DBConnection):
def gen_data():
for _ in range(10):
yield pa.RecordBatch.from_arrays(
@@ -248,14 +243,12 @@ def test_create_table_from_iterator(tmp_path):
["vector", "item", "price"],
)
table = db.create_table("test", data=gen_data())
table = mem_db.create_table("test", data=gen_data())
assert table.count_rows() == 10
@pytest.mark.asyncio
async def test_create_table_from_iterator_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_table_from_iterator_async(mem_db_async: lancedb.AsyncConnection):
def gen_data():
for _ in range(10):
yield pa.RecordBatch.from_arrays(
@@ -267,12 +260,11 @@ async def test_create_table_from_iterator_async(tmp_path):
["vector", "item", "price"],
)
table = await db.create_table("test", data=gen_data())
table = await mem_db_async.create_table("test", data=gen_data())
assert await table.count_rows() == 10
def test_create_exist_ok(tmp_path):
db = lancedb.connect(tmp_path)
def test_create_exist_ok(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -280,13 +272,13 @@ def test_create_exist_ok(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=data)
tbl = tmp_db.create_table("test", data=data)
with pytest.raises(OSError):
db.create_table("test", data=data)
with pytest.raises(ValueError):
tmp_db.create_table("test", data=data)
# open the table but don't add more rows
tbl2 = db.create_table("test", data=data, exist_ok=True)
tbl2 = tmp_db.create_table("test", data=data, exist_ok=True)
assert tbl.name == tbl2.name
assert tbl.schema == tbl2.schema
assert len(tbl) == len(tbl2)
@@ -298,7 +290,7 @@ def test_create_exist_ok(tmp_path):
pa.field("price", pa.float64()),
]
)
tbl3 = db.create_table("test", schema=schema, exist_ok=True)
tbl3 = tmp_db.create_table("test", schema=schema, exist_ok=True)
assert tbl3.schema == schema
bad_schema = pa.schema(
@@ -310,7 +302,7 @@ def test_create_exist_ok(tmp_path):
]
)
with pytest.raises(ValueError):
db.create_table("test", schema=bad_schema, exist_ok=True)
tmp_db.create_table("test", schema=bad_schema, exist_ok=True)
@pytest.mark.asyncio
@@ -325,26 +317,24 @@ async def test_connect(tmp_path):
@pytest.mark.asyncio
async def test_close(tmp_path):
db = await lancedb.connect_async(tmp_path)
assert db.is_open()
db.close()
assert not db.is_open()
async def test_close(mem_db_async: lancedb.AsyncConnection):
assert mem_db_async.is_open()
mem_db_async.close()
assert not mem_db_async.is_open()
with pytest.raises(RuntimeError, match="is closed"):
await db.table_names()
await mem_db_async.table_names()
@pytest.mark.asyncio
async def test_context_manager(tmp_path):
with await lancedb.connect_async(tmp_path) as db:
async def test_context_manager():
with await lancedb.connect_async("memory://") as db:
assert db.is_open()
assert not db.is_open()
@pytest.mark.asyncio
async def test_create_mode_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_mode_async(tmp_db_async: lancedb.AsyncConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -352,10 +342,10 @@ async def test_create_mode_async(tmp_path):
"price": [10.0, 20.0],
}
)
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
new_data = pd.DataFrame(
{
@@ -364,15 +354,14 @@ async def test_create_mode_async(tmp_path):
"price": [10.0, 20.0],
}
)
_tbl = await db.create_table("test", data=new_data, mode="overwrite")
_tbl = await tmp_db_async.create_table("test", data=new_data, mode="overwrite")
# MIGRATION: to_pandas() is not available in async
# assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
@pytest.mark.asyncio
async def test_create_exist_ok_async(tmp_path):
db = await lancedb.connect_async(tmp_path)
async def test_create_exist_ok_async(tmp_db_async: lancedb.AsyncConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -380,13 +369,13 @@ async def test_create_exist_ok_async(tmp_path):
"price": [10.0, 20.0],
}
)
tbl = await db.create_table("test", data=data)
tbl = await tmp_db_async.create_table("test", data=data)
with pytest.raises(ValueError, match="already exists"):
await db.create_table("test", data=data)
await tmp_db_async.create_table("test", data=data)
# open the table but don't add more rows
tbl2 = await db.create_table("test", data=data, exist_ok=True)
tbl2 = await tmp_db_async.create_table("test", data=data, exist_ok=True)
assert tbl.name == tbl2.name
assert await tbl.schema() == await tbl2.schema()
@@ -397,7 +386,7 @@ async def test_create_exist_ok_async(tmp_path):
pa.field("price", pa.float64()),
]
)
tbl3 = await db.create_table("test", schema=schema, exist_ok=True)
tbl3 = await tmp_db_async.create_table("test", schema=schema, exist_ok=True)
assert await tbl3.schema() == schema
# Migration: When creating a table, but the table already exists, but
@@ -448,13 +437,12 @@ async def test_create_table_v2_manifest_paths_async(tmp_path):
assert re.match(r"\d{20}\.manifest", manifest)
def test_open_table_sync(tmp_path):
db = lancedb.connect(tmp_path)
db.create_table("test", data=[{"id": 0}])
assert db.open_table("test").count_rows() == 1
assert db.open_table("test", index_cache_size=0).count_rows() == 1
with pytest.raises(FileNotFoundError, match="does not exist"):
db.open_table("does_not_exist")
def test_open_table_sync(tmp_db: lancedb.DBConnection):
tmp_db.create_table("test", data=[{"id": 0}])
assert tmp_db.open_table("test").count_rows() == 1
assert tmp_db.open_table("test", index_cache_size=0).count_rows() == 1
with pytest.raises(ValueError, match="Table 'does_not_exist' was not found"):
tmp_db.open_table("does_not_exist")
@pytest.mark.asyncio
@@ -494,8 +482,7 @@ async def test_open_table(tmp_path):
await db.open_table("does_not_exist")
def test_delete_table(tmp_path):
db = lancedb.connect(tmp_path)
def test_delete_table(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -503,26 +490,25 @@ def test_delete_table(tmp_path):
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
assert db.table_names() == ["test"]
assert tmp_db.table_names() == ["test"]
db.drop_table("test")
assert db.table_names() == []
tmp_db.drop_table("test")
assert tmp_db.table_names() == []
db.create_table("test", data=data)
assert db.table_names() == ["test"]
tmp_db.create_table("test", data=data)
assert tmp_db.table_names() == ["test"]
# dropping a table that does not exist should pass
# if ignore_missing=True
db.drop_table("does_not_exist", ignore_missing=True)
tmp_db.drop_table("does_not_exist", ignore_missing=True)
def test_drop_database(tmp_path):
db = lancedb.connect(tmp_path)
def test_drop_database(tmp_db: lancedb.DBConnection):
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
@@ -537,51 +523,50 @@ def test_drop_database(tmp_path):
"price": [12.0, 17.0],
}
)
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
tmp_db.create_table("test", data=data)
assert db.table_names() == ["test"]
assert tmp_db.table_names() == ["test"]
db.create_table("new_test", data=new_data)
db.drop_database()
assert db.table_names() == []
tmp_db.create_table("new_test", data=new_data)
tmp_db.drop_database()
assert tmp_db.table_names() == []
# it should pass when no tables are present
db.create_table("test", data=new_data)
db.drop_table("test")
assert db.table_names() == []
db.drop_database()
assert db.table_names() == []
tmp_db.create_table("test", data=new_data)
tmp_db.drop_table("test")
assert tmp_db.table_names() == []
tmp_db.drop_database()
assert tmp_db.table_names() == []
# creating an empty database with schema
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
db.create_table("empty_table", schema=schema)
tmp_db.create_table("empty_table", schema=schema)
# dropping a empty database should pass
db.drop_database()
assert db.table_names() == []
tmp_db.drop_database()
assert tmp_db.table_names() == []
def test_empty_or_nonexistent_table(tmp_path):
db = lancedb.connect(tmp_path)
def test_empty_or_nonexistent_table(mem_db: lancedb.DBConnection):
with pytest.raises(Exception):
db.create_table("test_with_no_data")
mem_db.create_table("test_with_no_data")
with pytest.raises(Exception):
db.open_table("does_not_exist")
mem_db.open_table("does_not_exist")
schema = pa.schema([pa.field("a", pa.int64(), nullable=False)])
test = db.create_table("test", schema=schema)
test = mem_db.create_table("test", schema=schema)
class TestModel(LanceModel):
a: int
test2 = db.create_table("test2", schema=TestModel)
test2 = mem_db.create_table("test2", schema=TestModel)
assert test.schema == test2.schema
@pytest.mark.asyncio
async def test_create_in_v2_mode(tmp_path):
async def test_create_in_v2_mode(mem_db_async: lancedb.AsyncConnection):
def make_data():
for i in range(10):
yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
@@ -591,10 +576,8 @@ async def test_create_in_v2_mode(tmp_path):
schema = pa.schema([pa.field("x", pa.int64())])
db = await lancedb.connect_async(tmp_path)
# Create table in v1 mode
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test", data=make_data(), schema=schema, data_storage_version="legacy"
)
@@ -610,7 +593,7 @@ async def test_create_in_v2_mode(tmp_path):
assert not await is_in_v2_mode(tbl)
# Create table in v2 mode
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_v2", data=make_data(), schema=schema, use_legacy_format=False
)
@@ -622,7 +605,7 @@ async def test_create_in_v2_mode(tmp_path):
assert await is_in_v2_mode(tbl)
# Create empty table in v2 mode and add data
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_empty_v2", data=None, schema=schema, use_legacy_format=False
)
await tbl.add(make_table())
@@ -630,7 +613,7 @@ async def test_create_in_v2_mode(tmp_path):
assert await is_in_v2_mode(tbl)
# Create empty table uses v1 mode by default
tbl = await db.create_table(
tbl = await mem_db_async.create_table(
"test_empty_v2_default", data=None, schema=schema, data_storage_version="legacy"
)
await tbl.add(make_table())
@@ -638,18 +621,17 @@ async def test_create_in_v2_mode(tmp_path):
assert not await is_in_v2_mode(tbl)
def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
table = db.create_table(
def test_replace_index(mem_db: lancedb.DBConnection):
table = mem_db.create_table(
"test",
[
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
{"vector": np.random.rand(32), "item": "foo", "price": float(i)}
for i in range(512)
],
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_sub_vectors=2,
)
with pytest.raises(Exception):
@@ -660,27 +642,26 @@ def test_replace_index(tmp_path):
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_partitions=1,
num_sub_vectors=2,
replace=True,
index_cache_size=10,
)
def test_prefilter_with_index(tmp_path):
db = lancedb.connect(uri=tmp_path)
def test_prefilter_with_index(mem_db: lancedb.DBConnection):
data = [
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
for i in range(1000)
{"vector": np.random.rand(32), "item": "foo", "price": float(i)}
for i in range(512)
]
sample_key = data[100]["vector"]
table = db.create_table(
table = mem_db.create_table(
"test",
data,
)
table.create_index(
num_partitions=2,
num_sub_vectors=4,
num_sub_vectors=2,
)
table = (
table.search(sample_key)
@@ -691,13 +672,12 @@ def test_prefilter_with_index(tmp_path):
assert table.num_rows == 1
def test_create_table_with_invalid_names(tmp_path):
db = lancedb.connect(uri=tmp_path)
def test_create_table_with_invalid_names(tmp_db: lancedb.DBConnection):
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
with pytest.raises(ValueError):
db.create_table("foo/bar", data)
tmp_db.create_table("foo/bar", data)
with pytest.raises(ValueError):
db.create_table("foo bar", data)
tmp_db.create_table("foo bar", data)
with pytest.raises(ValueError):
db.create_table("foo$$bar", data)
db.create_table("foo.bar", data)
tmp_db.create_table("foo$$bar", data)
tmp_db.create_table("foo.bar", data)

View File

@@ -15,10 +15,12 @@ import random
from unittest import mock
import lancedb as ldb
from lancedb.db import DBConnection
from lancedb.index import FTS
import numpy as np
import pandas as pd
import pytest
from utils import exception_output
pytest.importorskip("lancedb.fts")
tantivy = pytest.importorskip("tantivy")
@@ -458,3 +460,44 @@ def test_syntax(table):
table.search('the cats OR dogs were not really "pets" at all').phrase_query().limit(
10
).to_list()
def test_language(mem_db: DBConnection):
sentences = [
"Il n'y a que trois routes qui traversent la ville.",
"Je veux prendre la route vers l'est.",
"Je te retrouve au café au bout de la route.",
]
data = [{"text": s} for s in sentences]
table = mem_db.create_table("test", data=data)
with pytest.raises(ValueError) as e:
table.create_fts_index("text", use_tantivy=False, language="klingon")
assert exception_output(e) == (
"ValueError: LanceDB does not support the requested language: 'klingon'\n"
"Supported languages: Arabic, Danish, Dutch, English, Finnish, French, "
"German, Greek, Hungarian, Italian, Norwegian, Portuguese, Romanian, "
"Russian, Spanish, Swedish, Tamil, Turkish"
)
table.create_fts_index(
"text",
use_tantivy=False,
language="French",
stem=True,
ascii_folding=True,
remove_stop_words=True,
)
# Can get "routes" and "route" from the same root
results = table.search("route", query_type="fts").limit(5).to_list()
assert len(results) == 3
# Can find "café", without needing to provide accent
results = table.search("cafe", query_type="fts").limit(5).to_list()
assert len(results) == 1
# Stop words -> no results
results = table.search("la", query_type="fts").limit(5).to_list()
assert len(results) == 0

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import pytest
def exception_output(e_info: pytest.ExceptionInfo):
import traceback
# skip traceback part, since it's not worth checking in tests
lines = traceback.format_exception_only(e_info.type, e_info.value)
return "".join(lines).strip()