Compare commits

...

1 Commit

Author SHA1 Message Date
rmeng
60f6dc6a64 chore: switch over to requests for remote client 2024-01-09 22:35:15 -05:00
6 changed files with 72 additions and 224 deletions

View File

@@ -15,7 +15,8 @@
import functools
from typing import Any, Callable, Dict, Iterable, Optional, Union
import aiohttp
import requests
import urllib.parse
import attrs
import pyarrow as pa
from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _check_not_closed(f):
return wrapped
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
resp_body = await resp.read()
def _read_ipc(resp: requests.Response) -> pa.Table:
resp_body = resp.raw.read()
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
return reader.read_all()
@@ -53,15 +54,24 @@ class RestfulLanceDBClient:
closed: bool = attrs.field(default=False, init=False)
@functools.cached_property
def session(self) -> aiohttp.ClientSession:
url = (
def session(self) -> requests.Session:
session = requests.session()
session.stream = True
return session
@functools.cached_property
def url(self) -> str:
return (
self.host_override
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
)
return aiohttp.ClientSession(url)
async def close(self):
await self.session.close()
def _get_request_url(self, uri: str) -> str:
return urllib.parse.urljoin(self.url, uri)
def close(self):
self.session.close()
self.closed = True
@functools.cached_property
@@ -75,39 +85,25 @@ class RestfulLanceDBClient:
headers["x-lancedb-database"] = self.db_name
return headers
@staticmethod
async def _check_status(resp: aiohttp.ClientResponse):
if resp.status == 404:
raise LanceDBClientError(f"Not found: {await resp.text()}")
elif 400 <= resp.status < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status}, error: {await resp.text()}"
)
elif 500 <= resp.status < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
)
elif resp.status != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status}, error: {await resp.text()}"
)
@_check_not_closed
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
"""Send a GET request and returns the deserialized response payload."""
if isinstance(params, BaseModel):
params: Dict[str, Any] = params.dict(exclude_none=True)
async with self.session.get(
uri,
resp = self.session.get(
self._get_request_url(uri),
params=params,
headers=self.headers,
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
await self._check_status(resp)
return await resp.json()
# 5s connect timeout, 30s read timeout
timeout=(5.0, 30.0),
)
resp.raise_for_status()
return resp.json()
@_check_not_closed
async def post(
def post(
self,
uri: str,
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
@@ -139,31 +135,31 @@ class RestfulLanceDBClient:
headers["content-type"] = content_type
if request_id is not None:
headers["x-request-id"] = request_id
async with self.session.post(
uri,
headers=headers,
resp = self.session.post(
self._get_request_url(uri),
params=params,
timeout=aiohttp.ClientTimeout(total=30),
headers=self.headers,
# 5s connect timeout, 30s read timeout
timeout=(5.0, 30.0),
**req_kwargs,
) as resp:
resp: aiohttp.ClientResponse = resp
await self._check_status(resp)
return await deserialize(resp)
)
resp.raise_for_status()
return deserialize(resp)
@_check_not_closed
async def list_tables(
def list_tables(
self, limit: int, page_token: Optional[str] = None
) -> Iterable[str]:
"""List all tables in the database."""
if page_token is None:
page_token = ""
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
return json["tables"]
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query a table."""
tbl = await self.post(
f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc
)
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl)
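
For orientation, a minimal standalone sketch of the synchronous pattern this file moves to: a shared requests.Session with streaming enabled, urllib.parse.urljoin for building the request URL, a (connect, read) timeout tuple, raise_for_status() in place of the old async _check_status helper, and Arrow IPC decoding of the response body. The helper name post_ipc and the json= payload handling are illustrative only, not part of the commit.

import urllib.parse

import pyarrow as pa
import requests


def post_ipc(session: requests.Session, base_url: str, uri: str, payload: dict) -> pa.Table:
    # Build the absolute URL the same way the new _get_request_url() does.
    url = urllib.parse.urljoin(base_url, uri)
    # (connect timeout, read timeout), matching the diff's (5.0, 30.0).
    resp = session.post(url, json=payload, timeout=(5.0, 30.0))
    # requests raises HTTPError for 4xx/5xx responses.
    resp.raise_for_status()
    # Decode the Arrow IPC file format from the response bytes
    # (the diff reads from resp.raw on a streaming session; resp.content is the
    # simpler route for this sketch).
    with pa.ipc.open_file(pa.BufferReader(resp.content)) as reader:
        return reader.read_all()


# The shared session is configured once, as in the new cached_property:
session = requests.session()
session.stream = True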

View File

@@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection):
self._client = RestfulLanceDBClient(
self.db_name, region, api_key, host_override
)
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.get_event_loop()
def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})"
@@ -76,15 +72,13 @@ class RemoteDBConnection(DBConnection):
An iterator of table names.
"""
while True:
result = self._loop.run_until_complete(
self._client.list_tables(limit, page_token)
)
if len(result) > 0:
page_token = result[len(result) - 1]
else:
break
result = self._client.list_tables(limit, page_token)
for item in result:
yield item
if len(result) < limit:
break
else:
page_token = result[len(result) - 1]
@override
def open_table(self, name: str) -> Table:
@@ -103,9 +97,7 @@ class RemoteDBConnection(DBConnection):
# check if table exists
try:
self._loop.run_until_complete(
self._client.post(f"/v1/table/{name}/describe/")
)
except LanceDBClientError as err:
if str(err).startswith("Not found"):
logging.error(
@@ -248,14 +240,12 @@ class RemoteDBConnection(DBConnection):
data = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._loop.run_until_complete(
self._client.post(
f"/v1/table/{name}/create/",
data=data,
request_id=request_id,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
)
return RemoteTable(self, name)
@override
@@ -267,13 +257,10 @@ class RemoteDBConnection(DBConnection):
name: str
The name of the table.
"""
self._loop.run_until_complete(
self._client.post(
f"/v1/table/{name}/drop/",
)
)
async def close(self):
"""Close the connection to the database."""
self._loop.close()
await self._client.close()
self._client.close()
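
The table listing above becomes a plain synchronous generator instead of run_until_complete around an async call. A minimal sketch of that pagination loop, assuming a client object with the list_tables(limit, page_token) method shown in the first file; the function name iter_table_names is illustrative.

from typing import Iterator, Optional


def iter_table_names(client, limit: int, page_token: Optional[str] = None) -> Iterator[str]:
    while True:
        # One page per round trip; a short page means the listing is exhausted.
        page = client.list_tables(limit, page_token or "")
        yield from page
        if len(page) < limit:
            break
        # Continue the next request from the last name returned.
        page_token = page[-1]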

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from functools import cached_property
from typing import Dict, Optional, Union
@@ -43,18 +42,14 @@ class RemoteTable(Table):
of this Table
"""
resp = self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/describe/")
)
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
schema = json_to_schema(resp["schema"])
return schema
@property
def version(self) -> int:
"""Get the current version of the table"""
resp = self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/describe/")
)
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
return resp["version"]
def to_arrow(self) -> pa.Table:
@@ -116,8 +111,8 @@ class RemoteTable(Table):
"metric_type": metric,
"index_cache_size": index_cache_size,
}
resp = self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data)
resp = self._conn._client.post(
f"/v1/table/{self._name}/create_index/", data=data
)
return resp
@@ -161,14 +156,12 @@ class RemoteTable(Table):
request_id = uuid.uuid4().hex
self._conn._loop.run_until_complete(
self._conn._client.post(
f"/v1/table/{self._name}/insert/",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
)
)
def search(
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
@@ -233,19 +226,17 @@ class RemoteTable(Table):
and len(query.vector) > 0
and not isinstance(query.vector[0], float)
):
futures = []
result = []
for v in query.vector:
v = list(v)
q = query.copy()
q.vector = v
futures.append(self._conn._client.query(self._name, q))
result = self._conn._loop.run_until_complete(asyncio.gather(*futures))
result.append(self._conn._client.query(self._name, q))
return pa.concat_tables(
[add_index(r.to_arrow(), i) for i, r in enumerate(result)]
)
else:
result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow()
return self._conn._client.query(self._name, query).to_arrow()
def delete(self, predicate: str):
"""Delete rows from the table.
@@ -294,9 +285,7 @@ class RemoteTable(Table):
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
"""
payload = {"predicate": predicate}
self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
)
def update(
self,
@@ -356,9 +345,7 @@ class RemoteTable(Table):
updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates}
self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
)
def add_index(tbl: pa.Table, i: int) -> pa.Table:
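
In the multi-vector branch of search() above, the per-vector queries now run serially instead of being gathered as aiohttp futures. A minimal sketch of that flow; client and base_query stand in for the real RestfulLanceDBClient and VectorQuery, and the add_index() body is assumed since the hunk truncates it.

import pyarrow as pa


def add_index(tbl: pa.Table, i: int) -> pa.Table:
    # Assumed helper body: tag every row with the index of the query vector it answers.
    return tbl.append_column("query_index", pa.array([i] * len(tbl), pa.int64()))


def query_each_vector(client, table_name: str, base_query, vectors) -> pa.Table:
    # One synchronous round trip per query vector, replacing asyncio.gather().
    results = []
    for v in vectors:
        q = base_query.copy()  # pydantic model copy, as in the diff
        q.vector = list(v)
        results.append(client.query(table_name, q))
    # Stitch the per-vector results into a single table.
    return pa.concat_tables(
        [add_index(r.to_arrow(), i) for i, r in enumerate(results)]
    )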

View File

@@ -7,7 +7,7 @@ dependencies = [
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",
"aiohttp",
"requests>=2.31,<3",
"pydantic>=1.10",
"attrs>=21.3.0",
"semver>=3.0",

View File

@@ -1,27 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
from lancedb import LanceDBConnection
# TODO: setup integ test mark and script
@pytest.mark.skip(reason="Need to set up a local server")
def test_against_local_server():
conn = LanceDBConnection("lancedb+http://localhost:10024")
table = conn.open_table("sift1m_ivf1024_pq16")
df = table.search(np.random.rand(128)).to_pandas()
assert len(df) == 10

View File

@@ -1,95 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attrs
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from aiohttp import web
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
@attrs.define
class MockLanceDBServer:
runner: web.AppRunner = attrs.field(init=False)
site: web.TCPSite = attrs.field(init=False)
async def query_handler(self, request: web.Request) -> web.Response:
table_name = request.match_info["table_name"]
assert table_name == "test_table"
await request.json()
# TODO: do some matching
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
ids = pd.Series(range(10), name="id")
df = pd.DataFrame([vecs, ids]).T
batch = pa.RecordBatch.from_pandas(
df,
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 128)),
pa.field("id", pa.int64()),
]
),
)
sink = pa.BufferOutputStream()
with pa.ipc.new_file(sink, batch.schema) as writer:
writer.write_batch(batch)
return web.Response(body=sink.getvalue().to_pybytes())
async def setup(self):
app = web.Application()
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
self.runner = web.AppRunner(app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, "localhost", 8111)
async def start(self):
await self.site.start()
async def stop(self):
await self.runner.cleanup()
@pytest.mark.skip(reason="flaky somehow, fix later")
@pytest.mark.asyncio
async def test_e2e_with_mock_server():
mock_server = MockLanceDBServer()
await mock_server.setup()
await mock_server.start()
try:
client = RestfulLanceDBClient("lancedb+http://localhost:8111")
df = (
await client.query(
"test_table",
VectorQuery(
vector=np.random.rand(128).tolist(),
k=10,
_metric="L2",
columns=["id", "vector"],
),
)
).to_pandas()
assert "vector" in df.columns
assert "id" in df.columns
finally:
# make sure we don't leak resources
await mock_server.stop()