feat(python): add exist_ok option to create table (#813)

This mimics CREATE TABLE IF NOT EXISTS behavior.
We add `db.create_table(..., exist_ok=True)` parameter.
By default it is set to False, so trying to create
a table with the same name will raise an exception.
If set to True, then it only opens the table if it
already exists. If you pass in a schema, it will
be checked against the existing table to make sure
you get what you want. If you pass in data, it will
NOT be added to the existing table.
This commit is contained in:
Chang She
2024-01-15 11:09:18 -08:00
committed by Weston Pace
parent 340fd98b42
commit 72b39432e8
4 changed files with 81 additions and 6 deletions

View File

@@ -56,6 +56,7 @@ class DBConnection(EnforceOverrides):
data: Optional[DATA] = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
mode: str = "create",
exist_ok: bool = False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
@@ -86,6 +87,11 @@ class DBConnection(EnforceOverrides):
Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
exist_ok: bool, default False
If a table by the same name already exists, then raise an exception
if exist_ok=False. If exist_ok=True, then open the existing table;
it will not add the provided data but will validate against any
schema that's specified.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
@@ -319,6 +325,7 @@ class LanceDBConnection(DBConnection):
data: Optional[DATA] = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
mode: str = "create",
exist_ok: bool = False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
@@ -338,6 +345,7 @@ class LanceDBConnection(DBConnection):
data,
schema,
mode=mode,
exist_ok=exist_ok,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
embedding_functions=embedding_functions,

View File

@@ -966,6 +966,7 @@ class LanceTable(Table):
data=None,
schema=None,
mode="create",
exist_ok=False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: List[EmbeddingFunctionConfig] = None,
@@ -1005,6 +1006,10 @@ class LanceTable(Table):
mode: str, default "create"
The mode to use when writing the data. Valid values are
"create", "overwrite", and "append".
exist_ok: bool, default False
If the table already exists then raise an error if False,
otherwise just open the table, it will not add the provided
data but will validate against any schema that's specified.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
@@ -1055,13 +1060,23 @@ class LanceTable(Table):
schema = schema.with_metadata(metadata)
empty = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode)
table = LanceTable(db, name)
try:
lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode)
except OSError as err:
if "Dataset already exists" in str(err) and exist_ok:
if tbl.schema != schema:
raise ValueError(
f"Table {name} already exists with a different schema"
)
return tbl
raise
new_table = LanceTable(db, name)
if data is not None:
table.add(data)
new_table.add(data)
return table
return new_table
@classmethod
def open(cls, db, name):

View File

@@ -190,6 +190,48 @@ def test_create_mode(tmp_path):
assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
def test_create_exist_ok(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
tbl = db.create_table("test", data=data)
with pytest.raises(OSError):
db.create_table("test", data=data)
# open the table but don't add more rows
tbl2 = db.create_table("test", data=data, exist_ok=True)
assert tbl.name == tbl2.name
assert tbl.schema == tbl2.schema
assert len(tbl) == len(tbl2)
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float64()),
]
)
tbl3 = db.create_table("test", schema=schema, exist_ok=True)
assert tbl3.schema == schema
bad_schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
pa.field("item", pa.utf8()),
pa.field("price", pa.float64()),
pa.field("extra", pa.float32()),
]
)
with pytest.raises(ValueError):
db.create_table("test", schema=bad_schema, exist_ok=True)
def test_delete_table(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(