diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py
index 3cccba40..6d17001d 100644
--- a/python/lancedb/remote/client.py
+++ b/python/lancedb/remote/client.py
@@ -13,11 +13,12 @@
 
 
 import functools
-from typing import Dict
+from typing import Any, Callable, Dict, Optional, Union
 
 import aiohttp
 import attr
 import pyarrow as pa
+from pydantic import BaseModel
 
 from lancedb.common import Credential
 from lancedb.remote import VectorQuery, VectorQueryResult
@@ -34,6 +35,12 @@ def _check_not_closed(f):
     return wrapped
 
 
+async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
+    resp_body = await resp.read()
+    with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
+        return reader.read_all()
+
+
 @attr.define(slots=False)
 class RestfulLanceDBClient:
     db_name: str
@@ -56,28 +63,90 @@ class RestfulLanceDBClient:
             "x-api-key": self.api_key,
         }
 
+    @staticmethod
+    async def _check_status(resp: aiohttp.ClientResponse) -> None:
+        """Raise LanceDBClientError unless the response status is 200."""
+        if resp.status == 404:
+            raise LanceDBClientError(f"Not found: {await resp.text()}")
+        elif 400 <= resp.status < 500:
+            raise LanceDBClientError(
+                f"Bad Request: {resp.status}, error: {await resp.text()}"
+            )
+        elif 500 <= resp.status < 600:
+            raise LanceDBClientError(
+                f"Internal Server Error: {resp.status}, error: {await resp.text()}"
+            )
+        elif resp.status != 200:
+            raise LanceDBClientError(
+                f"Unknown Error: {resp.status}, error: {await resp.text()}"
+            )
+
     @_check_not_closed
-    async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
+    async def get(
+        self,
+        uri: str,
+        params: Optional[Union[Dict[str, Any], BaseModel]] = None,
+    ):
+        """Send a GET request and return the deserialized JSON payload.
+
+        Parameters
+        ----------
+        uri : str
+            The uri to send the GET request to.
+        params : Optional[Union[Dict[str, Any], BaseModel]]
+            Query parameters; a pydantic model is converted with
+            ``.dict(exclude_none=True)`` first.
+        """
+        if isinstance(params, BaseModel):
+            params = params.dict(exclude_none=True)
+        async with self.session.get(uri, params=params, headers=self.headers) as resp:
+            await self._check_status(resp)
+            return await resp.json()
+
+    @_check_not_closed
+    async def post(
+        self,
+        uri: str,
+        data: Union[Dict[str, Any], BaseModel],
+        deserialize: Callable = lambda resp: resp.json(),
+    ) -> Any:
+        """Send a POST request and return the deserialized response payload.
+
+        Parameters
+        ----------
+        uri : str
+            The uri to send the POST request to.
+        data : Union[Dict[str, Any], BaseModel]
+            The JSON request body; a pydantic model is converted with
+            ``.dict(exclude_none=True)`` first.
+        deserialize : Callable
+            Called with the response and awaited to produce the return
+            value. Defaults to decoding the body as JSON.
+
+        Returns
+        -------
+        Any
+            Whatever ``deserialize`` produces for the response.
+        """
+        if isinstance(data, BaseModel):
+            data = data.dict(exclude_none=True)
         async with self.session.post(
-            f"/1/table/{table_name}/",
-            json=query.dict(exclude_none=True),
+            uri,
+            json=data,
             headers=self.headers,
         ) as resp:
             resp: aiohttp.ClientResponse = resp
-            if 400 <= resp.status < 500:
-                raise LanceDBClientError(
-                    f"Bad Request: {resp.status}, error: {await resp.text()}"
-                )
-            if 500 <= resp.status < 600:
-                raise LanceDBClientError(
-                    f"Internal Server Error: {resp.status}, error: {await resp.text()}"
-                )
-            if resp.status != 200:
-                raise LanceDBClientError(
-                    f"Unknown Error: {resp.status}, error: {await resp.text()}"
-                )
+            await self._check_status(resp)
+            return await deserialize(resp)
 
-            resp_body = await resp.read()
-            with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
-                tbl = reader.read_all()
+    @_check_not_closed
+    async def list_tables(self):
+        """List all tables in the database."""
+        payload = await self.get("/1/table/", {})
+        return payload["tables"]
+
+    @_check_not_closed
+    async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
+        """Query a table."""
+        tbl = await self.post(f"/1/table/{table_name}/", query, deserialize=_read_ipc)
         return VectorQueryResult(tbl)
diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py
index 7b721662..d4ddc7bc 100644
--- a/python/lancedb/remote/db.py
+++ b/python/lancedb/remote/db.py
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 from typing import List
 from urllib.parse import urlparse
 
@@ -34,12 +35,21 @@ class RemoteDBConnection(DBConnection):
         self.db_name = parsed.netloc
         self.api_key = api_key
         self._client = RestfulLanceDBClient(self.db_name, region, api_key)
+        # NOTE: the loop is driven with run_until_complete(), which raises
+        # RuntimeError when the loop is already running, so the synchronous
+        # methods cannot be called from inside a running event loop.
+        try:
+            self._loop = asyncio.get_running_loop()
+        except RuntimeError:
+            self._loop = asyncio.get_event_loop()
 
     def __repr__(self) -> str:
         return f"RemoveConnect(name={self.db_name})"
 
     def table_names(self) -> List[str]:
-        raise NotImplementedError
+        """List the names of all tables in the database."""
+        result = self._loop.run_until_complete(self._client.list_tables())
+        return result
 
     def open_table(self, name: str) -> Table:
         """Open a Lance Table in the database.
diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py
index e4152e6b..08d7a055 100644
--- a/python/lancedb/remote/table.py
+++ b/python/lancedb/remote/table.py
@@ -11,7 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 from typing import Union
 
 import pyarrow as pa
@@ -62,9 +61,5 @@ class RemoteTable(Table):
         return LanceQueryBuilder(self, query, vector_column)
 
     def _execute_query(self, query: Query) -> pa.Table:
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = asyncio.get_event_loop()
         result = self._conn._client.query(self._name, query)
-        return loop.run_until_complete(result).to_arrow()
+        return self._conn._loop.run_until_complete(result).to_arrow()