mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 12:22:59 +00:00
[Python] Bug fixes in remote API (#314)
This commit is contained in:
22
python/lancedb/remote/arrow.py
Normal file
22
python/lancedb/remote/arrow.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
def to_ipc_binary(table: pa.Table) -> bytes:
|
||||
"""Serialize a PyArrow Table to IPC binary."""
|
||||
sink = pa.BufferOutputStream()
|
||||
with pa.ipc.new_stream(sink, table.schema) as writer:
|
||||
writer.write_table(table)
|
||||
return sink.getvalue().to_pybytes()
|
||||
@@ -93,7 +93,7 @@ class RestfulLanceDBClient:
|
||||
async def post(
|
||||
self,
|
||||
uri: str,
|
||||
data: Union[Dict[str, Any], BaseModel],
|
||||
data: Union[Dict[str, Any], BaseModel, bytes],
|
||||
deserialize: Callable = lambda resp: resp.json(),
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a POST request and returns the deserialized response payload.
|
||||
@@ -107,10 +107,14 @@ class RestfulLanceDBClient:
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data: Dict[str, Any] = data.dict(exclude_none=True)
|
||||
if isinstance(data, bytes):
|
||||
req_kwargs = {"data": data}
|
||||
else:
|
||||
req_kwargs = {"json": data}
|
||||
async with self.session.post(
|
||||
uri,
|
||||
json=data,
|
||||
headers=self.headers,
|
||||
**req_kwargs,
|
||||
) as resp:
|
||||
resp: aiohttp.ClientResponse = resp
|
||||
await self._check_status(resp)
|
||||
@@ -119,11 +123,11 @@ class RestfulLanceDBClient:
|
||||
@_check_not_closed
|
||||
async def list_tables(self):
|
||||
"""List all tables in the database."""
|
||||
json = await self.get("/1/table/", {})
|
||||
json = await self.get("/v1/table/", {})
|
||||
return json["tables"]
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query a table."""
|
||||
tbl = await self.post(f"/1/table/{table_name}/", query, deserialize=_read_ipc)
|
||||
tbl = await self.post(f"/v1/table/{table_name}/", query, deserialize=_read_ipc)
|
||||
return VectorQueryResult(tbl)
|
||||
|
||||
@@ -22,6 +22,7 @@ from lancedb.db import DBConnection
|
||||
from lancedb.schema import schema_to_json
|
||||
from lancedb.table import Table, _sanitize_data
|
||||
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import RestfulLanceDBClient
|
||||
|
||||
|
||||
@@ -89,10 +90,8 @@ class RemoteDBConnection(DBConnection):
|
||||
|
||||
from .table import RemoteTable
|
||||
|
||||
payload = {
|
||||
"name": name,
|
||||
"schema": schema_to_json(data.schema),
|
||||
"records": data.to_pydict(),
|
||||
}
|
||||
self._loop.run_until_complete(self._client.create_table("/table/", payload))
|
||||
data = to_ipc_binary(data)
|
||||
self._loop.run_until_complete(
|
||||
self._client.post(f"/v1/table/{name}/create", data=data)
|
||||
)
|
||||
return RemoteTable(self, name)
|
||||
|
||||
@@ -36,10 +36,10 @@ class RemoteTable(Table):
|
||||
def schema(self) -> pa.Schema:
|
||||
"""Return the schema of the table."""
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.get(f"/table/{self._name}/describe")
|
||||
self._conn._client.get(f"/v1/table/{self._name}/describe")
|
||||
)
|
||||
schema = json_to_schema(resp["schema"])
|
||||
raise schema
|
||||
return schema
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
"""Schema related utilities."""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Type
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
Reference in New Issue
Block a user