diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index d5d294ac..4bf66332 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -509,7 +509,7 @@ class AsyncConnection(object): return self._inner.__repr__() def __enter__(self): - self + return self def __exit__(self, *_): self.close() @@ -779,7 +779,7 @@ class AsyncConnection(object): name: str, storage_options: Optional[Dict[str, str]] = None, index_cache_size: Optional[int] = None, - ) -> Table: + ) -> AsyncTable: """Open a Lance Table in the database. Parameters diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index 40fc4998..027a14ef 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -296,6 +296,13 @@ async def test_close(tmp_path): await db.table_names() +@pytest.mark.asyncio +async def test_context_manager(tmp_path): + with await lancedb.connect_async(tmp_path) as db: + assert db.is_open() + assert not db.is_open() + + @pytest.mark.asyncio async def test_create_mode_async(tmp_path): db = await lancedb.connect_async(tmp_path)