From 8ff5f88916963545968774d7759503b807578eda Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Sun, 16 Jul 2023 11:09:19 -0700 Subject: [PATCH] [Python] Bug fixes in remote API (#314) --- python/lancedb/remote/arrow.py | 22 ++++++++++++++++++++++ python/lancedb/remote/client.py | 12 ++++++++---- python/lancedb/remote/db.py | 11 +++++------ python/lancedb/remote/table.py | 4 ++-- python/lancedb/schema.py | 1 - 5 files changed, 37 insertions(+), 13 deletions(-) create mode 100644 python/lancedb/remote/arrow.py diff --git a/python/lancedb/remote/arrow.py b/python/lancedb/remote/arrow.py new file mode 100644 index 00000000..753087cf --- /dev/null +++ b/python/lancedb/remote/arrow.py @@ -0,0 +1,22 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa + + +def to_ipc_binary(table: pa.Table) -> bytes: + """Serialize a PyArrow Table to IPC binary.""" + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, table.schema) as writer: + writer.write_table(table) + return sink.getvalue().to_pybytes() diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 6d17001d..7d3d965d 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -93,7 +93,7 @@ class RestfulLanceDBClient: async def post( self, uri: str, - data: Union[Dict[str, Any], BaseModel], + data: Union[Dict[str, Any], BaseModel, bytes], deserialize: Callable = lambda resp: resp.json(), ) -> Dict[str, Any]: """Send a POST request and returns the deserialized response payload. @@ -107,10 +107,14 @@ class RestfulLanceDBClient: """ if isinstance(data, BaseModel): data: Dict[str, Any] = data.dict(exclude_none=True) + if isinstance(data, bytes): + req_kwargs = {"data": data} + else: + req_kwargs = {"json": data} async with self.session.post( uri, - json=data, headers=self.headers, + **req_kwargs, ) as resp: resp: aiohttp.ClientResponse = resp await self._check_status(resp) @@ -119,11 +123,11 @@ class RestfulLanceDBClient: @_check_not_closed async def list_tables(self): """List all tables in the database.""" - json = await self.get("/1/table/", {}) + json = await self.get("/v1/table/", {}) return json["tables"] @_check_not_closed async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: """Query a table.""" - tbl = await self.post(f"/1/table/{table_name}/", query, deserialize=_read_ipc) + tbl = await self.post(f"/v1/table/{table_name}/", query, deserialize=_read_ipc) return VectorQueryResult(tbl) diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 16672daf..a447a805 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -22,6 +22,7 @@ from lancedb.db import DBConnection from lancedb.schema import schema_to_json from lancedb.table import Table, _sanitize_data +from .arrow import to_ipc_binary from .client import RestfulLanceDBClient @@ -89,10 +90,8 @@ class RemoteDBConnection(DBConnection): from .table import RemoteTable - payload = { - "name": name, - "schema": schema_to_json(data.schema), - "records": data.to_pydict(), - } - self._loop.run_until_complete(self._client.create_table("/table/", payload)) + data = to_ipc_binary(data) + self._loop.run_until_complete( + self._client.post(f"/v1/table/{name}/create", data=data) + ) return RemoteTable(self, name) diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index 1d79cecb..8ee631b8 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -36,10 +36,10 @@ class RemoteTable(Table): def schema(self) -> pa.Schema: """Return the schema of the table.""" resp = self._conn._loop.run_until_complete( - self._conn._client.get(f"/table/{self._name}/describe") + self._conn._client.get(f"/v1/table/{self._name}/describe") ) schema = json_to_schema(resp["schema"]) - raise schema + return schema def to_arrow(self) -> pa.Table: raise NotImplementedError diff --git a/python/lancedb/schema.py b/python/lancedb/schema.py index d89c4301..8d8a77a4 100644 --- a/python/lancedb/schema.py +++ b/python/lancedb/schema.py @@ -13,7 +13,6 @@ """Schema related utilities.""" -import json from typing import Any, Dict, Type import pyarrow as pa