diff --git a/docs/src/guides/tables.md b/docs/src/guides/tables.md index 5808e49a..9cf1ba0c 100644 --- a/docs/src/guides/tables.md +++ b/docs/src/guides/tables.md @@ -31,13 +31,23 @@ This guide will show how to create tables, insert data into them, and update the ``` !!! info "Note" - If the table already exists, LanceDB will raise an error by default. If you want to overwrite the table, you can pass in mode="overwrite" to the createTable function. + If the table already exists, LanceDB will raise an error by default. + + `create_table` supports an optional `exist_ok` parameter. When set to True + and the table exists, then it simply opens the existing table. The data you + passed in will NOT be appended to the table in that case. + + ```python + db.create_table("name", data, exist_ok=True) + ``` + + Sometimes you want to make sure that you start fresh. If you want to + overwrite the table, you can pass in mode="overwrite" to the `create_table` function. ```python db.create_table("name", data, mode="overwrite") ``` - ### From pandas DataFrame ```python diff --git a/python/lancedb/db.py b/python/lancedb/db.py index e2204164..eb09a538 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -56,6 +56,7 @@ class DBConnection(EnforceOverrides): data: Optional[DATA] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None, mode: str = "create", + exist_ok: bool = False, on_bad_vectors: str = "error", fill_value: float = 0.0, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, @@ -86,6 +87,11 @@ class DBConnection(EnforceOverrides): Can be either "create" or "overwrite". By default, if the table already exists, an exception is raised. If you want to overwrite the table, use mode="overwrite". + exist_ok: bool, default False + If a table by the same name already exists, then raise an exception + if exist_ok=False. 
If exist_ok=True, then open the existing table; + it will not add the provided data but will validate against any + schema that's specified. on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". @@ -319,6 +325,7 @@ class LanceDBConnection(DBConnection): data: Optional[DATA] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None, mode: str = "create", + exist_ok: bool = False, on_bad_vectors: str = "error", fill_value: float = 0.0, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, @@ -338,6 +345,7 @@ class LanceDBConnection(DBConnection): data, schema, mode=mode, + exist_ok=exist_ok, on_bad_vectors=on_bad_vectors, fill_value=fill_value, embedding_functions=embedding_functions, diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 2e01469e..f914838d 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -966,6 +966,7 @@ class LanceTable(Table): data=None, schema=None, mode="create", + exist_ok=False, on_bad_vectors: str = "error", fill_value: float = 0.0, embedding_functions: List[EmbeddingFunctionConfig] = None, @@ -1005,6 +1006,10 @@ class LanceTable(Table): mode: str, default "create" The mode to use when writing the data. Valid values are "create", "overwrite", and "append". + exist_ok: bool, default False + If the table already exists then raise an error if False, + otherwise just open the table; it will not add the provided + data but will validate against any schema that's specified. on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". 
@@ -1055,13 +1060,23 @@ class LanceTable(Table): schema = schema.with_metadata(metadata) empty = pa.Table.from_pylist([], schema=schema) - lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode) - table = LanceTable(db, name) + try: + lance.write_dataset(empty, tbl._dataset_uri, schema=schema, mode=mode) + except OSError as err: + if "Dataset already exists" in str(err) and exist_ok: + if tbl.schema != schema: + raise ValueError( + f"Table {name} already exists with a different schema" + ) + return tbl + raise + + new_table = LanceTable(db, name) if data is not None: - table.add(data) + new_table.add(data) - return table + return new_table @classmethod def open(cls, db, name): diff --git a/python/tests/test_db.py b/python/tests/test_db.py index a7afa56e..700b34d3 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -190,6 +190,48 @@ def test_create_mode(tmp_path): assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"] +def test_create_exist_ok(tmp_path): + db = lancedb.connect(tmp_path) + data = pd.DataFrame( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["foo", "bar"], + "price": [10.0, 20.0], + } + ) + tbl = db.create_table("test", data=data) + + with pytest.raises(OSError): + db.create_table("test", data=data) + + # open the table but don't add more rows + tbl2 = db.create_table("test", data=data, exist_ok=True) + assert tbl.name == tbl2.name + assert tbl.schema == tbl2.schema + assert len(tbl) == len(tbl2) + + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), list_size=2)), + pa.field("item", pa.utf8()), + pa.field("price", pa.float64()), + ] + ) + tbl3 = db.create_table("test", schema=schema, exist_ok=True) + assert tbl3.schema == schema + + bad_schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), list_size=2)), + pa.field("item", pa.utf8()), + pa.field("price", pa.float64()), + pa.field("extra", pa.float32()), + ] + ) + with pytest.raises(ValueError): + db.create_table("test", 
schema=bad_schema, exist_ok=True) + + def test_delete_table(tmp_path): db = lancedb.connect(tmp_path) data = pd.DataFrame(