mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 05:49:57 +00:00
[Python] List tables from remote service (#262)
This commit is contained in:
@@ -13,11 +13,12 @@
|
||||
|
||||
|
||||
import functools
|
||||
from typing import Dict
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import aiohttp
|
||||
import attr
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from lancedb.common import Credential
|
||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||
@@ -34,6 +35,12 @@ def _check_not_closed(f):
|
||||
return wrapped
|
||||
|
||||
|
||||
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
|
||||
resp_body = await resp.read()
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
return reader.read_all()
|
||||
|
||||
|
||||
@attr.define(slots=False)
|
||||
class RestfulLanceDBClient:
|
||||
db_name: str
|
||||
@@ -56,28 +63,67 @@ class RestfulLanceDBClient:
|
||||
"x-api-key": self.api_key,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _check_status(resp: aiohttp.ClientResponse):
|
||||
if resp.status == 404:
|
||||
raise LanceDBClientError(f"Not found: {await resp.text()}")
|
||||
elif 400 <= resp.status < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
elif 500 <= resp.status < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
elif resp.status != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||
"""Send a GET request and returns the deserialized response payload."""
|
||||
if isinstance(params, BaseModel):
|
||||
params: Dict[str, Any] = params.dict(exclude_none=True)
|
||||
async with self.session.get(uri, params=params, headers=self.headers) as resp:
|
||||
await self._check_status(resp)
|
||||
return await resp.json()
|
||||
|
||||
@_check_not_closed
|
||||
async def post(
|
||||
self,
|
||||
uri: str,
|
||||
data: Union[Dict[str, Any], BaseModel],
|
||||
deserialize: Callable = lambda resp: resp.json(),
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a POST request and returns the deserialized response payload.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
The uri to send the POST request to.
|
||||
data: Union[Dict[str, Any], BaseModel]
|
||||
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data: Dict[str, Any] = data.dict(exclude_none=True)
|
||||
async with self.session.post(
|
||||
f"/1/table/{table_name}/",
|
||||
json=query.dict(exclude_none=True),
|
||||
uri,
|
||||
json=data,
|
||||
headers=self.headers,
|
||||
) as resp:
|
||||
resp: aiohttp.ClientResponse = resp
|
||||
if 400 <= resp.status < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
if 500 <= resp.status < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
if resp.status != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
await self._check_status(resp)
|
||||
return await deserialize(resp)
|
||||
|
||||
resp_body = await resp.read()
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
tbl = reader.read_all()
|
||||
@_check_not_closed
|
||||
async def list_tables(self):
|
||||
"""List all tables in the database."""
|
||||
json = await self.get("/1/table/", {})
|
||||
return json["tables"]
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query a table."""
|
||||
tbl = await self.post(f"/1/table/{table_name}/", query, deserialize=_read_ipc)
|
||||
return VectorQueryResult(tbl)
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -34,12 +35,18 @@ class RemoteDBConnection(DBConnection):
|
||||
self.db_name = parsed.netloc
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(self.db_name, region, api_key)
|
||||
try:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoveConnect(name={self.db_name})"
|
||||
|
||||
def table_names(self) -> List[str]:
|
||||
raise NotImplementedError
|
||||
"""List the names of all tables in the database."""
|
||||
result = self._loop.run_until_complete(self._client.list_tables())
|
||||
return result
|
||||
|
||||
def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
@@ -62,9 +61,5 @@ class RemoteTable(Table):
|
||||
return LanceQueryBuilder(self, query, vector_column)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return loop.run_until_complete(result).to_arrow()
|
||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||
|
||||
Reference in New Issue
Block a user