[Python] Add records to remote (#315)

This commit is contained in:
Lei Xu
2023-07-16 13:24:38 -07:00
committed by GitHub
parent 8ff5f88916
commit 7a57cddb2c
3 changed files with 45 additions and 8 deletions

View File

@@ -13,7 +13,7 @@
import functools
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Optional, Union
import aiohttp
import attr
@@ -24,6 +24,8 @@ from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.errors import LanceDBClientError
ARROW_STREAM_CONTENT_TYPE = "application/vnd.apache.arrow.stream"
def _check_not_closed(f):
@functools.wraps(f)
@@ -59,9 +61,12 @@ class RestfulLanceDBClient:
@functools.cached_property
def headers(self) -> Dict[str, str]:
return {
headers = {
"x-api-key": self.api_key,
}
if self.region == "local": # Local test mode
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
return headers
@staticmethod
async def _check_status(resp: aiohttp.ClientResponse):
@@ -94,6 +99,8 @@ class RestfulLanceDBClient:
self,
uri: str,
data: Union[Dict[str, Any], BaseModel, bytes],
params: Optional[Dict[str, Any]] = None,
content_type: Optional[str] = None,
deserialize: Callable = lambda resp: resp.json(),
) -> Dict[str, Any]:
"""Send a POST request and returns the deserialized response payload.
@@ -111,9 +118,14 @@ class RestfulLanceDBClient:
req_kwargs = {"data": data}
else:
req_kwargs = {"json": data}
headers = self.headers.copy()
if content_type is not None:
headers["content-type"] = content_type
async with self.session.post(
uri,
headers=self.headers,
headers=headers,
params=params,
**req_kwargs,
) as resp:
resp: aiohttp.ClientResponse = resp

View File

@@ -12,6 +12,7 @@
# limitations under the License.
import asyncio
import uuid
from typing import List
from urllib.parse import urlparse
@@ -23,7 +24,7 @@ from lancedb.schema import schema_to_json
from lancedb.table import Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import RestfulLanceDBClient
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
class RemoteDBConnection(DBConnection):
@@ -73,7 +74,6 @@ class RemoteDBConnection(DBConnection):
name: str,
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> Table:
@@ -91,7 +91,14 @@ class RemoteDBConnection(DBConnection):
from .table import RemoteTable
data = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._loop.run_until_complete(
self._client.post(f"/v1/table/{name}/create", data=data)
self._client.post(
f"/v1/table/{name}/create",
data=data,
params={"request_id": request_id},
content_type=ARROW_STREAM_CONTENT_TYPE,
)
)
return RemoteTable(self, name)

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from functools import cached_property
from typing import Union
@@ -20,7 +21,9 @@ from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceQueryBuilder, Query
from ..schema import json_to_schema
from ..table import Query, Table
from ..table import Query, Table, _sanitize_data
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection
@@ -61,7 +64,22 @@ class RemoteTable(Table):
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
raise NotImplementedError
data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
payload = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._conn._loop.run_until_complete(
self._conn._client.post(
f"/v1/table/{self._name}/insert",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
)
)
return len(data)
def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME