Add mode to overwrite table if already exists

This commit is contained in:
Chang She
2023-04-19 16:19:18 -07:00
parent ec197b1855
commit d7c5793803
4 changed files with 49 additions and 5 deletions

View File

@@ -25,6 +25,10 @@ tbl = db.create_table("my_table",
Under the hood, LanceDB is converting the input data into an Apache Arrow table
and persisting it to disk in [Lance format](https://github.com/eto-ai/lance).
If the table already exists, LanceDB will raise an error by default.
If you want to overwrite the table, you can pass in `mode="overwrite"`
to the `create_table` method.
You can also pass in a pandas DataFrame directly:
```python
import pandas as pd

View File

@@ -55,7 +55,11 @@ class LanceDBConnection:
return self.open_table(name)
def create_table(
self, name: str, data: DATA = None, schema: pa.Schema = None
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
) -> LanceTable:
"""Create a table in the database.
@@ -67,6 +71,10 @@ class LanceDBConnection:
The data to insert into the table.
schema: pyarrow.Schema; optional
The schema of the table.
mode: str; default "create"
The mode to use when creating the table.
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
Note
----
@@ -78,7 +86,7 @@ class LanceDBConnection:
A LanceTable object representing the table.
"""
if data is not None:
tbl = LanceTable.create(self, name, data, schema)
tbl = LanceTable.create(self, name, data, schema, mode=mode)
else:
tbl = LanceTable(self, name)
return tbl

View File

@@ -156,10 +156,10 @@ class LanceTable:
return LanceQueryBuilder(self, query)
@classmethod
def create(cls, db, name, data, schema=None, mode="create"):
    """Write ``data`` to disk as a Lance dataset and return its table.

    Parameters
    ----------
    db:
        The connection the new table is attached to.
    name: str
        The name of the table.
    data:
        The data to insert into the table. It is passed through
        ``_sanitize_data`` (together with ``schema``) before being written.
    schema: pyarrow.Schema; optional
        The schema of the table.
    mode: str; default "create"
        Forwarded unchanged to ``lance.write_dataset``.
        By default, if the table already exists, an exception is raised.
        If you want to overwrite the table, use mode="overwrite".

    Returns
    -------
    A LanceTable object representing the table.
    """
    tbl = LanceTable(db, name)
    data = _sanitize_data(data, schema)
    # Single write honouring the caller's mode; a second hard-coded
    # mode="create" write would fail on (or pre-empt) an overwrite.
    lance.write_dataset(data, tbl._dataset_uri, mode=mode)
    return tbl

View File

@@ -13,6 +13,7 @@
import lancedb
import pandas as pd
import pytest
def test_basic(tmp_path):
@@ -49,7 +50,13 @@ def test_ingest_pd(tmp_path):
assert db.uri == str(tmp_path)
assert db.table_names() == []
data = pd.DataFrame({"vector": [[3.1, 4.1], [5.9, 26.5]], "item": ["foo", "bar"], "price": [10.0, 20.0]})
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
table = db.create_table("test", data=data)
rs = table.search([100, 100]).limit(1).to_df()
assert len(rs) == 1
@@ -64,3 +71,28 @@ def test_ingest_pd(tmp_path):
assert len(db) == 1
assert db.open_table("test").name == db["test"].name
def test_create_mode(tmp_path):
    """Check create_table's default mode and its mode="overwrite" behaviour.

    Creating a table under a name that already exists must raise, while
    passing mode="overwrite" must return a table holding the new rows.
    """

    def make_frame(items):
        # Vectors and prices are the same in both frames; only the item
        # labels differ, which is what the final assertion inspects.
        return pd.DataFrame(
            {
                "vector": [[3.1, 4.1], [5.9, 26.5]],
                "item": items,
                "price": [10.0, 20.0],
            }
        )

    db = lancedb.connect(tmp_path)
    original = make_frame(["foo", "bar"])
    db.create_table("test", data=original)

    # Default mode: a second create under the same name has to fail.
    with pytest.raises(Exception):
        db.create_table("test", data=original)

    # Overwrite mode: the returned table exposes the replacement rows.
    replacement = make_frame(["fizz", "buzz"])
    overwritten = db.create_table("test", data=replacement, mode="overwrite")
    assert overwritten.to_pandas().item.tolist() == ["fizz", "buzz"]