[Python] List tables from remote service (#262)

This commit is contained in:
Lei Xu
2023-07-09 23:58:03 -07:00
committed by GitHub
parent 97364a2514
commit 9ef846929b
3 changed files with 74 additions and 26 deletions

View File

@@ -13,11 +13,12 @@
import functools
from typing import Dict
from typing import Any, Callable, Dict, Union
import aiohttp
import attr
import pyarrow as pa
from pydantic import BaseModel
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
@@ -34,6 +35,12 @@ def _check_not_closed(f):
return wrapped
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
    """Decode the body of ``resp`` as an Arrow IPC file.

    The whole response body is read into memory, opened with
    ``pa.ipc.open_file`` and returned as a single table via ``read_all()``.
    """
    payload = pa.BufferReader(await resp.read())
    with pa.ipc.open_file(payload) as ipc_reader:
        table = ipc_reader.read_all()
    return table
@attr.define(slots=False)
class RestfulLanceDBClient:
db_name: str
@@ -56,28 +63,67 @@ class RestfulLanceDBClient:
"x-api-key": self.api_key,
}
@staticmethod
async def _check_status(resp: aiohttp.ClientResponse):
    """Raise ``LanceDBClientError`` unless ``resp.status`` is 200.

    The message prefix depends on the status: 404 is reported as
    "Not found", other 4xx as "Bad Request", 5xx as "Internal Server Error"
    and anything else as "Unknown Error". The response text is included in
    every message.
    """
    code = resp.status
    if code == 200:
        # Success: the body is left unread for the caller to deserialize.
        return

    detail = await resp.text()
    if code == 404:
        message = f"Not found: {detail}"
    elif 400 <= code < 500:
        message = f"Bad Request: {code}, error: {detail}"
    elif 500 <= code < 600:
        message = f"Internal Server Error: {code}, error: {detail}"
    else:
        message = f"Unknown Error: {code}, error: {detail}"
    raise LanceDBClientError(message)
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
    """Issue a GET request and return the JSON-decoded response body.

    Parameters
    ----------
    uri : str
        The uri to send the GET request to.
    params : Union[Dict[str, Any], BaseModel], optional
        Query parameters. A pydantic model is converted with
        ``.dict(exclude_none=True)`` before being sent; anything else is
        passed through unchanged.

    Raises
    ------
    Whatever ``self._check_status`` raises for the response.
    """
    query_params = (
        params.dict(exclude_none=True) if isinstance(params, BaseModel) else params
    )
    request = self.session.get(uri, params=query_params, headers=self.headers)
    async with request as resp:
        await self._check_status(resp)
        return await resp.json()
@_check_not_closed
# POST helper: sends `data` as the JSON body to `uri` with self.headers,
# validates the response through self._check_status, then returns whatever
# `deserialize(resp)` yields once awaited.
# NOTE(review): this span is taken from a diff view whose +/- markers and
# indentation are missing; several lines below look like the pre-change side
# of the hunk and are flagged individually — confirm against the commit.
async def post(
self,
uri: str,
data: Union[Dict[str, Any], BaseModel],
# The default deserializer returns the resp.json() coroutine, which is
# awaited at the `return await deserialize(resp)` line below.
deserialize: Callable = lambda resp: resp.json(),
# NOTE(review): the annotated return type is a dict, but a caller-supplied
# `deserialize` decides the actual type (query() passes _read_ipc).
) -> Dict[str, Any]:
"""Send a POST request and returns the deserialized response payload.
Parameters
----------
uri : str
The uri to send the POST request to.
data: Union[Dict[str, Any], BaseModel]
"""
# A pydantic model is flattened to a plain dict via .dict(exclude_none=True).
if isinstance(data, BaseModel):
data: Dict[str, Any] = data.dict(exclude_none=True)
async with self.session.post(
# NOTE(review): the next two argument lines reference `table_name` and
# `query`, which are not parameters of post(); presumably the removed
# arguments of the old query() call — confirm.
f"/1/table/{table_name}/",
json=query.dict(exclude_none=True),
uri,
json=data,
headers=self.headers,
) as resp:
resp: aiohttp.ClientResponse = resp
# NOTE(review): the three inline status checks below repeat the 4xx / 5xx /
# non-200 branches of _check_status, which is called right after them;
# presumably removed lines — confirm.
if 400 <= resp.status < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status}, error: {await resp.text()}"
)
if 500 <= resp.status < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
)
if resp.status != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status}, error: {await resp.text()}"
)
await self._check_status(resp)
return await deserialize(resp)
# NOTE(review): the three lines after the return mirror the body of
# _read_ipc (read body, open as Arrow IPC file, read_all); presumably the
# removed decoding code of the old query() — confirm.
resp_body = await resp.read()
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
tbl = reader.read_all()
@_check_not_closed
async def list_tables(self):
    """Return the ``"tables"`` entry of the ``/1/table/`` GET response.

    The request is sent through ``self.get`` with an empty parameter dict.
    """
    listing = await self.get("/1/table/", {})
    tables = listing["tables"]
    return tables
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
    """Query a table.

    Posts ``query`` to ``/1/table/{table_name}/`` through ``self.post``,
    decodes the response with ``_read_ipc`` and wraps the decoded table in
    a ``VectorQueryResult``.
    """
    endpoint = f"/1/table/{table_name}/"
    arrow_table = await self.post(endpoint, query, deserialize=_read_ipc)
    return VectorQueryResult(arrow_table)

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import List
from urllib.parse import urlparse
@@ -34,12 +35,18 @@ class RemoteDBConnection(DBConnection):
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(self.db_name, region, api_key)
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.get_event_loop()
def __repr__(self) -> str:
    """Return a short description of the connection naming its database."""
    # The prefix previously read "RemoveConnect", a typo for "Remote".
    return f"RemoteConnect(name={self.db_name})"
# Runs self._client.list_tables() to completion on self._loop and returns
# its result as the list of table names.
def table_names(self) -> List[str]:
# NOTE(review): this bare raise sits directly in front of the implementation;
# in this marker-less diff view it is presumably the removed pre-change line.
# Taken literally it would make everything below unreachable — confirm.
raise NotImplementedError
"""List the names of all tables in the database."""
# Blocks the calling thread until the list_tables() coroutine finishes.
result = self._loop.run_until_complete(self._client.list_tables())
return result
def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Union
import pyarrow as pa
@@ -62,9 +61,5 @@ class RemoteTable(Table):
return LanceQueryBuilder(self, query, vector_column)
# Sends `query` for table self._name through the connection's client and
# returns the result converted with .to_arrow().
# NOTE(review): this span comes from a diff view without +/- markers; the
# local-loop lines and the first return look like the pre-change side, with
# the final return (using self._conn._loop) as their replacement — confirm.
def _execute_query(self, query: Query) -> pa.Table:
# Prefer the loop already running in this thread; fall back to
# get_event_loop() when none is running (get_running_loop raises RuntimeError).
# NOTE(review): run_until_complete() on a loop that is already running raises
# RuntimeError, so the get_running_loop() branch looks unusable — verify.
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()
# _client.query(...) is an async method, so `result` is an un-awaited coroutine here.
result = self._conn._client.query(self._name, query)
return loop.run_until_complete(result).to_arrow()
return self._conn._loop.run_until_complete(result).to_arrow()