[Python] Bug fixes in remote API (#314)

This commit is contained in:
Lei Xu
2023-07-16 11:09:19 -07:00
committed by GitHub
parent 028a6e433d
commit 8ff5f88916
5 changed files with 37 additions and 13 deletions

View File

@@ -0,0 +1,22 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow as pa
def to_ipc_binary(table: pa.Table) -> bytes:
"""Serialize a PyArrow Table to IPC binary."""
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
writer.write_table(table)
return sink.getvalue().to_pybytes()

View File

@@ -93,7 +93,7 @@ class RestfulLanceDBClient:
async def post(
self,
uri: str,
data: Union[Dict[str, Any], BaseModel],
data: Union[Dict[str, Any], BaseModel, bytes],
deserialize: Callable = lambda resp: resp.json(),
) -> Dict[str, Any]:
"""Send a POST request and returns the deserialized response payload.
@@ -107,10 +107,14 @@ class RestfulLanceDBClient:
"""
if isinstance(data, BaseModel):
data: Dict[str, Any] = data.dict(exclude_none=True)
if isinstance(data, bytes):
req_kwargs = {"data": data}
else:
req_kwargs = {"json": data}
async with self.session.post(
uri,
json=data,
headers=self.headers,
**req_kwargs,
) as resp:
resp: aiohttp.ClientResponse = resp
await self._check_status(resp)
@@ -119,11 +123,11 @@ class RestfulLanceDBClient:
@_check_not_closed
async def list_tables(self):
"""List all tables in the database."""
json = await self.get("/1/table/", {})
json = await self.get("/v1/table/", {})
return json["tables"]
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query a table."""
tbl = await self.post(f"/1/table/{table_name}/", query, deserialize=_read_ipc)
tbl = await self.post(f"/v1/table/{table_name}/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl)

View File

@@ -22,6 +22,7 @@ from lancedb.db import DBConnection
from lancedb.schema import schema_to_json
from lancedb.table import Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import RestfulLanceDBClient
@@ -89,10 +90,8 @@ class RemoteDBConnection(DBConnection):
from .table import RemoteTable
payload = {
"name": name,
"schema": schema_to_json(data.schema),
"records": data.to_pydict(),
}
self._loop.run_until_complete(self._client.create_table("/table/", payload))
data = to_ipc_binary(data)
self._loop.run_until_complete(
self._client.post(f"/v1/table/{name}/create", data=data)
)
return RemoteTable(self, name)

View File

@@ -36,10 +36,10 @@ class RemoteTable(Table):
def schema(self) -> pa.Schema:
"""Return the schema of the table."""
resp = self._conn._loop.run_until_complete(
self._conn._client.get(f"/table/{self._name}/describe")
self._conn._client.get(f"/v1/table/{self._name}/describe")
)
schema = json_to_schema(resp["schema"])
raise schema
return schema
def to_arrow(self) -> pa.Table:
raise NotImplementedError

View File

@@ -13,7 +13,6 @@
"""Schema related utilities."""
import json
from typing import Any, Dict, Type
import pyarrow as pa