diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index 6c1fb0a4c..41ec1ba3c 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -384,6 +384,7 @@ class RemoteDBConnection(DBConnection): on_bad_vectors: str = "error", fill_value: float = 0.0, mode: Optional[str] = None, + exist_ok: bool = False, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, *, namespace: Optional[List[str]] = None, @@ -412,6 +413,12 @@ class RemoteDBConnection(DBConnection): - pyarrow.Schema - [LanceModel][lancedb.pydantic.LanceModel] + mode: str, default "create" + The mode to use when creating the table. + Can be either "create", "overwrite", or "exist_ok". + exist_ok: bool, default False + If exist_ok is True, and mode is None or "create", mode will be changed + to "exist_ok". on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. One of "error", "drop", "fill". @@ -483,6 +490,11 @@ class RemoteDBConnection(DBConnection): LanceTable(table4) """ + if exist_ok: + if mode == "create": + mode = "exist_ok" + elif not mode: + mode = "exist_ok" if namespace is None: namespace = [] validate_table_name(name) diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 759b67214..566e1fba4 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -168,6 +168,42 @@ def test_table_len_sync(): assert len(table) == 1 +def test_create_table_exist_ok(): + def handler(request): + if request.path == "/v1/table/test/create/?mode=exist_ok": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"{}") + else: + request.send_response(404) + request.end_headers() + + with mock_lancedb_connection(handler) as db: + table = db.create_table("test", [{"id": 1}], exist_ok=True) + assert table is not None + + with mock_lancedb_connection(handler) as db: + table = db.create_table("test", [{"id": 1}], mode="create", exist_ok=True) + assert table is not None + + +def test_create_table_exist_ok_with_mode_overwrite(): + def handler(request): + if request.path == "/v1/table/test/create/?mode=overwrite": + request.send_response(200) + request.send_header("Content-Type", "application/json") + request.end_headers() + request.wfile.write(b"{}") + else: + request.send_response(404) + request.end_headers() + + with mock_lancedb_connection(handler) as db: + table = db.create_table("test", [{"id": 1}], mode="overwrite", exist_ok=True) + assert table is not None + + @pytest.mark.asyncio async def test_http_error(): request_id_holder = {"request_id": None}