diff --git a/docs/src/basic.md b/docs/src/basic.md index 89ce1b6d..1151dd1c 100644 --- a/docs/src/basic.md +++ b/docs/src/basic.md @@ -25,6 +25,10 @@ tbl = db.create_table("my_table", Under the hood, LanceDB is converting the input data into an Apache Arrow table and persisting it to disk in [Lance format](github.com/eto-ai/lance). +If the table already exists, LanceDB will raise an error by default. +If you want to overwrite the table, you can pass in `mode="overwrite"` +to the `create_table` method. + You can also pass in a pandas DataFrame directly: ```python import pandas as pd diff --git a/python/lancedb/db.py b/python/lancedb/db.py index cf408e21..f905aae3 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -55,7 +55,11 @@ class LanceDBConnection: return self.open_table(name) def create_table( - self, name: str, data: DATA = None, schema: pa.Schema = None + self, + name: str, + data: DATA = None, + schema: pa.Schema = None, + mode: str = "create", ) -> LanceTable: """Create a table in the database. @@ -67,6 +71,10 @@ class LanceDBConnection: The data to insert into the table. schema: pyarrow.Schema; optional The schema of the table. + mode: str; default "create" + The mode to use when creating the table. + By default, if the table already exists, an exception is raised. + If you want to overwrite the table, use mode="overwrite". Note ---- @@ -78,7 +86,7 @@ class LanceDBConnection: A LanceTable object representing the table. """ if data is not None: - tbl = LanceTable.create(self, name, data, schema) + tbl = LanceTable.create(self, name, data, schema, mode=mode) else: tbl = LanceTable(self, name) return tbl diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 5a79306c..1680a38c 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -156,10 +156,10 @@ class LanceTable: return LanceQueryBuilder(self, query) @classmethod - def create(cls, db, name, data, schema=None): + def create(cls, db, name, data, schema=None, mode="create"): tbl = LanceTable(db, name) data = _sanitize_data(data, schema) - lance.write_dataset(data, tbl._dataset_uri, mode="create") + lance.write_dataset(data, tbl._dataset_uri, mode=mode) return tbl diff --git a/python/tests/test_db.py b/python/tests/test_db.py index 207dd430..0a6aec53 100644 --- a/python/tests/test_db.py +++ b/python/tests/test_db.py @@ -13,6 +13,7 @@ import lancedb import pandas as pd +import pytest def test_basic(tmp_path): @@ -49,7 +50,13 @@ def test_ingest_pd(tmp_path): assert db.uri == str(tmp_path) assert db.table_names() == [] - data = pd.DataFrame({"vector": [[3.1, 4.1], [5.9, 26.5]], "item": ["foo", "bar"], "price": [10.0, 20.0]}) + data = pd.DataFrame( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["foo", "bar"], + "price": [10.0, 20.0], + } + ) table = db.create_table("test", data=data) rs = table.search([100, 100]).limit(1).to_df() assert len(rs) == 1 @@ -64,3 +71,28 @@ def test_ingest_pd(tmp_path): assert len(db) == 1 assert db.open_table("test").name == db["test"].name + + +def test_create_mode(tmp_path): + db = lancedb.connect(tmp_path) + data = pd.DataFrame( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["foo", "bar"], + "price": [10.0, 20.0], + } + ) + db.create_table("test", data=data) + + with pytest.raises(Exception): + db.create_table("test", data=data) + + new_data = pd.DataFrame( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["fizz", "buzz"], + "price": [10.0, 20.0], + } + ) + tbl = db.create_table("test", data=new_data, mode="overwrite") + assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]