diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 7d3d965d..17edc2e4 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -13,7 +13,7 @@ import functools -from typing import Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Optional, Union import aiohttp import attr @@ -24,6 +24,8 @@ from lancedb.common import Credential from lancedb.remote import VectorQuery, VectorQueryResult from lancedb.remote.errors import LanceDBClientError +ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream" + def _check_not_closed(f): @functools.wraps(f) @@ -59,9 +61,12 @@ class RestfulLanceDBClient: @functools.cached_property def headers(self) -> Dict[str, str]: - return { + headers = { "x-api-key": self.api_key, } + if self.region == "local": # Local test mode + headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com" + return headers @staticmethod async def _check_status(resp: aiohttp.ClientResponse): @@ -94,6 +99,8 @@ class RestfulLanceDBClient: self, uri: str, data: Union[Dict[str, Any], BaseModel, bytes], + params: Optional[Dict[str, Any]] = None, + content_type: Optional[str] = None, deserialize: Callable = lambda resp: resp.json(), ) -> Dict[str, Any]: """Send a POST request and returns the deserialized response payload. @@ -111,9 +118,14 @@ class RestfulLanceDBClient: req_kwargs = {"data": data} else: req_kwargs = {"json": data} + + headers = self.headers.copy() + if content_type is not None: + headers["content-type"] = content_type async with self.session.post( uri, - headers=self.headers, + headers=headers, + params=params, **req_kwargs, ) as resp: resp: aiohttp.ClientResponse = resp diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index a447a805..f8a93e9c 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -12,6 +12,7 @@ # limitations under the License. import asyncio +import uuid from typing import List from urllib.parse import urlparse @@ -23,7 +24,7 @@ from lancedb.schema import schema_to_json from lancedb.table import Table, _sanitize_data from .arrow import to_ipc_binary -from .client import RestfulLanceDBClient +from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient class RemoteDBConnection(DBConnection): @@ -73,7 +74,6 @@ class RemoteDBConnection(DBConnection): name: str, data: DATA = None, schema: pa.Schema = None, - mode: str = "create", on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> Table: @@ -91,7 +91,14 @@ class RemoteDBConnection(DBConnection): from .table import RemoteTable data = to_ipc_binary(data) + request_id = uuid.uuid4().hex + self._loop.run_until_complete( - self._client.post(f"/v1/table/{name}/create", data=data) + self._client.post( + f"/v1/table/{name}/create", + data=data, + params={"request_id": request_id}, + content_type=ARROW_STREAM_CONTENT_TYPE, + ) ) return RemoteTable(self, name) diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index 8ee631b8..3c7070a2 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import uuid from functools import cached_property from typing import Union @@ -20,7 +21,9 @@ from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from ..query import LanceQueryBuilder, Query from ..schema import json_to_schema -from ..table import Query, Table +from ..table import Query, Table, _sanitize_data +from .arrow import to_ipc_binary +from .client import ARROW_STREAM_CONTENT_TYPE from .db import RemoteDBConnection @@ -61,7 +64,22 @@ class RemoteTable(Table): on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> int: - raise NotImplementedError + data = _sanitize_data( + data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) + payload = to_ipc_binary(data) + + request_id = uuid.uuid4().hex + + self._conn._loop.run_until_complete( + self._conn._client.post( + f"/v1/table/{self._name}/insert", + data=payload, + params={"request_id": request_id, "mode": mode}, + content_type=ARROW_STREAM_CONTENT_TYPE, + ) + ) + return len(data) def search( self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME