diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 18d4076c..278ed0b2 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -18,12 +18,12 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union from lancedb.embeddings.registry import EmbeddingFunctionRegistry -from overrides import EnforceOverrides, override +from overrides import EnforceOverrides, override # type: ignore from lancedb.common import data_to_reader, sanitize_uri, validate_schema from lancedb.background_loop import LOOP -from ._lancedb import connect as lancedb_connect +from ._lancedb import connect as lancedb_connect # type: ignore from .table import ( AsyncTable, LanceTable, @@ -503,13 +503,7 @@ class LanceDBConnection(DBConnection): ignore_missing: bool, default False If True, ignore if the table does not exist. """ - try: - LOOP.run(self._conn.drop_table(name)) - except ValueError as e: - if not ignore_missing: - raise e - if f"Table '{name}' was not found" not in str(e): - raise e + LOOP.run(self._conn.drop_table(name, ignore_missing=ignore_missing)) @override def drop_database(self): @@ -886,15 +880,23 @@ class AsyncConnection(object): """ await self._inner.rename_table(old_name, new_name) - async def drop_table(self, name: str): + async def drop_table(self, name: str, *, ignore_missing: bool = False): """Drop a table from the database. Parameters ---------- name: str The name of the table. + ignore_missing: bool, default False + If True, ignore if the table does not exist. """ - await self._inner.drop_table(name) + try: + await self._inner.drop_table(name) + except ValueError as e: + if not ignore_missing: + raise e + if f"Table '{name}' was not found" not in str(e): + raise e async def drop_database(self): """ diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 394956c8..176c0110 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -508,6 +508,32 @@ def test_delete_table(tmp_db: lancedb.DBConnection): tmp_db.drop_table("does_not_exist", ignore_missing=True) +@pytest.mark.asyncio +async def test_delete_table_async(tmp_db: lancedb.DBConnection): + data = pd.DataFrame( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "item": ["foo", "bar"], + "price": [10.0, 20.0], + } + ) + + tmp_db.create_table("test", data=data) + + with pytest.raises(Exception): + tmp_db.create_table("test", data=data) + + assert tmp_db.table_names() == ["test"] + + tmp_db.drop_table("test") + assert tmp_db.table_names() == [] + + tmp_db.create_table("test", data=data) + assert tmp_db.table_names() == ["test"] + + tmp_db.drop_table("does_not_exist", ignore_missing=True) + + def test_drop_database(tmp_db: lancedb.DBConnection): data = pd.DataFrame( {