mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-22 21:09:58 +00:00
chore: switch over to requtes for remote client
This commit is contained in:
@@ -15,7 +15,8 @@
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
import urllib.parse
|
||||
import attrs
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
@@ -37,8 +38,8 @@ def _check_not_closed(f):
|
||||
return wrapped
|
||||
|
||||
|
||||
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
|
||||
resp_body = await resp.read()
|
||||
def _read_ipc(resp: requests.Response) -> pa.Table:
|
||||
resp_body = resp.raw.read()
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
return reader.read_all()
|
||||
|
||||
@@ -53,15 +54,24 @@ class RestfulLanceDBClient:
|
||||
closed: bool = attrs.field(default=False, init=False)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
url = (
|
||||
def session(self) -> requests.Session:
|
||||
session = requests.session()
|
||||
session.stream = True
|
||||
|
||||
return session
|
||||
|
||||
@functools.cached_property
|
||||
def url(self) -> str:
|
||||
return (
|
||||
self.host_override
|
||||
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
)
|
||||
return aiohttp.ClientSession(url)
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
def _get_request_url(self, uri: str) -> str:
|
||||
return urllib.parse.urljoin(self.url, uri)
|
||||
|
||||
def close(self):
|
||||
self.session.close()
|
||||
self.closed = True
|
||||
|
||||
@functools.cached_property
|
||||
@@ -75,39 +85,25 @@ class RestfulLanceDBClient:
|
||||
headers["x-lancedb-database"] = self.db_name
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
async def _check_status(resp: aiohttp.ClientResponse):
|
||||
if resp.status == 404:
|
||||
raise LanceDBClientError(f"Not found: {await resp.text()}")
|
||||
elif 400 <= resp.status < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
elif 500 <= resp.status < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
elif resp.status != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
|
||||
@_check_not_closed
|
||||
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||
"""Send a GET request and returns the deserialized response payload."""
|
||||
if isinstance(params, BaseModel):
|
||||
params: Dict[str, Any] = params.dict(exclude_none=True)
|
||||
async with self.session.get(
|
||||
uri,
|
||||
|
||||
resp = self.session.get(
|
||||
self._get_request_url(uri),
|
||||
params=params,
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
await self._check_status(resp)
|
||||
return await resp.json()
|
||||
# 5s connect timeout, 30s read timeout
|
||||
timeout=(5.0, 30.0),
|
||||
)
|
||||
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
@_check_not_closed
|
||||
async def post(
|
||||
def post(
|
||||
self,
|
||||
uri: str,
|
||||
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
|
||||
@@ -139,31 +135,31 @@ class RestfulLanceDBClient:
|
||||
headers["content-type"] = content_type
|
||||
if request_id is not None:
|
||||
headers["x-request-id"] = request_id
|
||||
async with self.session.post(
|
||||
uri,
|
||||
headers=headers,
|
||||
|
||||
resp = self.session.post(
|
||||
self._get_request_url(uri),
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
headers=self.headers,
|
||||
# 5s connect timeout, 30s read timeout
|
||||
timeout=(5.0, 30.0),
|
||||
**req_kwargs,
|
||||
) as resp:
|
||||
resp: aiohttp.ClientResponse = resp
|
||||
await self._check_status(resp)
|
||||
return await deserialize(resp)
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
return deserialize(resp)
|
||||
|
||||
@_check_not_closed
|
||||
async def list_tables(
|
||||
def list_tables(
|
||||
self, limit: int, page_token: Optional[str] = None
|
||||
) -> Iterable[str]:
|
||||
"""List all tables in the database."""
|
||||
if page_token is None:
|
||||
page_token = ""
|
||||
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||
return json["tables"]
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query a table."""
|
||||
tbl = await self.post(
|
||||
f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc
|
||||
)
|
||||
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
||||
return VectorQueryResult(tbl)
|
||||
|
||||
@@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection):
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name, region, api_key, host_override
|
||||
)
|
||||
try:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
@@ -76,15 +72,13 @@ class RemoteDBConnection(DBConnection):
|
||||
An iterator of table names.
|
||||
"""
|
||||
while True:
|
||||
result = self._loop.run_until_complete(
|
||||
self._client.list_tables(limit, page_token)
|
||||
)
|
||||
if len(result) > 0:
|
||||
page_token = result[len(result) - 1]
|
||||
else:
|
||||
break
|
||||
result = self._client.list_tables(limit, page_token)
|
||||
for item in result:
|
||||
yield item
|
||||
if len(result) < limit:
|
||||
break
|
||||
else:
|
||||
page_token = result[len(result) - 1]
|
||||
|
||||
@override
|
||||
def open_table(self, name: str) -> Table:
|
||||
@@ -103,9 +97,7 @@ class RemoteDBConnection(DBConnection):
|
||||
|
||||
# check if table exists
|
||||
try:
|
||||
self._loop.run_until_complete(
|
||||
self._client.post(f"/v1/table/{name}/describe/")
|
||||
)
|
||||
self._client.post(f"/v1/table/{name}/describe/")
|
||||
except LanceDBClientError as err:
|
||||
if str(err).startswith("Not found"):
|
||||
logging.error(
|
||||
@@ -248,13 +240,11 @@ class RemoteDBConnection(DBConnection):
|
||||
data = to_ipc_binary(data)
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._loop.run_until_complete(
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/create/",
|
||||
data=data,
|
||||
request_id=request_id,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/create/",
|
||||
data=data,
|
||||
request_id=request_id,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
return RemoteTable(self, name)
|
||||
|
||||
@@ -267,13 +257,10 @@ class RemoteDBConnection(DBConnection):
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
self._loop.run_until_complete(
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/drop/",
|
||||
)
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/drop/",
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the database."""
|
||||
self._loop.close()
|
||||
await self._client.close()
|
||||
self._client.close()
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from functools import cached_property
|
||||
from typing import Dict, Optional, Union
|
||||
@@ -43,18 +42,14 @@ class RemoteTable(Table):
|
||||
of this Table
|
||||
|
||||
"""
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
)
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
schema = json_to_schema(resp["schema"])
|
||||
return schema
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version of the table"""
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
)
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
return resp["version"]
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
@@ -116,8 +111,8 @@ class RemoteTable(Table):
|
||||
"metric_type": metric,
|
||||
"index_cache_size": index_cache_size,
|
||||
}
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data)
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/create_index/", data=data
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -161,13 +156,11 @@ class RemoteTable(Table):
|
||||
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/insert/",
|
||||
data=payload,
|
||||
params={"request_id": request_id, "mode": mode},
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/insert/",
|
||||
data=payload,
|
||||
params={"request_id": request_id, "mode": mode},
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
def search(
|
||||
@@ -233,19 +226,17 @@ class RemoteTable(Table):
|
||||
and len(query.vector) > 0
|
||||
and not isinstance(query.vector[0], float)
|
||||
):
|
||||
futures = []
|
||||
result = []
|
||||
for v in query.vector:
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
futures.append(self._conn._client.query(self._name, q))
|
||||
result = self._conn._loop.run_until_complete(asyncio.gather(*futures))
|
||||
result.append(self._conn._client.query(self._name, q))
|
||||
return pa.concat_tables(
|
||||
[add_index(r.to_arrow(), i) for i, r in enumerate(result)]
|
||||
)
|
||||
else:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||
return self._conn._client.query(self._name, query).to_arrow()
|
||||
|
||||
def delete(self, predicate: str):
|
||||
"""Delete rows from the table.
|
||||
@@ -294,9 +285,7 @@ class RemoteTable(Table):
|
||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
"""
|
||||
payload = {"predicate": predicate}
|
||||
self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||
)
|
||||
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -356,9 +345,7 @@ class RemoteTable(Table):
|
||||
updates = [[k, v] for k, v in values_sql.items()]
|
||||
|
||||
payload = {"predicate": where, "updates": updates}
|
||||
self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||
)
|
||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||
|
||||
|
||||
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
||||
|
||||
@@ -7,7 +7,7 @@ dependencies = [
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.27.0",
|
||||
"aiohttp",
|
||||
"requests>=2.31,<3",
|
||||
"pydantic>=1.10",
|
||||
"attrs>=21.3.0",
|
||||
"semver>=3.0",
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lancedb import LanceDBConnection
|
||||
|
||||
# TODO: setup integ test mark and script
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Need to set up a local server")
|
||||
def test_against_local_server():
|
||||
conn = LanceDBConnection("lancedb+http://localhost:10024")
|
||||
table = conn.open_table("sift1m_ivf1024_pq16")
|
||||
df = table.search(np.random.rand(128)).to_pandas()
|
||||
assert len(df) == 10
|
||||
@@ -1,95 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import attrs
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
||||
|
||||
|
||||
@attrs.define
|
||||
class MockLanceDBServer:
|
||||
runner: web.AppRunner = attrs.field(init=False)
|
||||
site: web.TCPSite = attrs.field(init=False)
|
||||
|
||||
async def query_handler(self, request: web.Request) -> web.Response:
|
||||
table_name = request.match_info["table_name"]
|
||||
assert table_name == "test_table"
|
||||
|
||||
await request.json()
|
||||
# TODO: do some matching
|
||||
|
||||
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
|
||||
ids = pd.Series(range(10), name="id")
|
||||
df = pd.DataFrame([vecs, ids]).T
|
||||
|
||||
batch = pa.RecordBatch.from_pandas(
|
||||
df,
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 128)),
|
||||
pa.field("id", pa.int64()),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
sink = pa.BufferOutputStream()
|
||||
with pa.ipc.new_file(sink, batch.schema) as writer:
|
||||
writer.write_batch(batch)
|
||||
|
||||
return web.Response(body=sink.getvalue().to_pybytes())
|
||||
|
||||
async def setup(self):
|
||||
app = web.Application()
|
||||
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
|
||||
self.runner = web.AppRunner(app)
|
||||
await self.runner.setup()
|
||||
self.site = web.TCPSite(self.runner, "localhost", 8111)
|
||||
|
||||
async def start(self):
|
||||
await self.site.start()
|
||||
|
||||
async def stop(self):
|
||||
await self.runner.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky somehow, fix later")
|
||||
@pytest.mark.asyncio
|
||||
async def test_e2e_with_mock_server():
|
||||
mock_server = MockLanceDBServer()
|
||||
await mock_server.setup()
|
||||
await mock_server.start()
|
||||
|
||||
try:
|
||||
client = RestfulLanceDBClient("lancedb+http://localhost:8111")
|
||||
df = (
|
||||
await client.query(
|
||||
"test_table",
|
||||
VectorQuery(
|
||||
vector=np.random.rand(128).tolist(),
|
||||
k=10,
|
||||
_metric="L2",
|
||||
columns=["id", "vector"],
|
||||
),
|
||||
)
|
||||
).to_pandas()
|
||||
|
||||
assert "vector" in df.columns
|
||||
assert "id" in df.columns
|
||||
finally:
|
||||
# make sure we don't leak resources
|
||||
await mock_server.stop()
|
||||
Reference in New Issue
Block a user