diff --git a/python/python/lancedb/remote/client.py b/python/python/lancedb/remote/client.py index 4975e39d..5ad9a2d0 100644 --- a/python/python/lancedb/remote/client.py +++ b/python/python/lancedb/remote/client.py @@ -79,6 +79,13 @@ class RestfulLanceDBClient: or f"https://{self.db_name}.{self.region}.api.lancedb.com" ) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + return False # Do not suppress exceptions + def close(self): self.session.close() self.closed = True diff --git a/python/python/tests/test_remote_client.py b/python/python/tests/test_remote_client.py index e9a2b19f..f5874953 100644 --- a/python/python/tests/test_remote_client.py +++ b/python/python/tests/test_remote_client.py @@ -74,21 +74,23 @@ async def test_e2e_with_mock_server(): await mock_server.start() try: - client = RestfulLanceDBClient("lancedb+http://localhost:8111") - df = ( - await client.query( - "test_table", - VectorQuery( - vector=np.random.rand(128).tolist(), - k=10, - _metric="L2", - columns=["id", "vector"], - ), - ) - ).to_pandas() + with RestfulLanceDBClient("lancedb+http://localhost:8111") as client: + df = ( + await client.query( + "test_table", + VectorQuery( + vector=np.random.rand(128).tolist(), + k=10, + _metric="L2", + columns=["id", "vector"], + ), + ) + ).to_pandas() - assert "vector" in df.columns - assert "id" in df.columns + assert "vector" in df.columns + assert "id" in df.columns + + assert client.closed finally: # make sure we don't leak resources await mock_server.stop()