Compare commits

...

1 Commit

Author: rmeng
SHA1: 60f6dc6a64
Message: chore: switch over to requests for remote client
Date: 2024-01-09 22:35:15 -05:00
6 changed files with 72 additions and 224 deletions

View File

@@ -15,7 +15,8 @@
 import functools
 from typing import Any, Callable, Dict, Iterable, Optional, Union

-import aiohttp
+import requests
+import urllib.parse
 import attrs
 import pyarrow as pa
 from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _check_not_closed(f):
     return wrapped


-async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
-    resp_body = await resp.read()
+def _read_ipc(resp: requests.Response) -> pa.Table:
+    resp_body = resp.raw.read()
     with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
         return reader.read_all()
@@ -53,15 +54,24 @@ class RestfulLanceDBClient:
     closed: bool = attrs.field(default=False, init=False)

     @functools.cached_property
-    def session(self) -> aiohttp.ClientSession:
-        url = (
+    def session(self) -> requests.Session:
+        session = requests.session()
+        session.stream = True
+        return session
+
+    @functools.cached_property
+    def url(self) -> str:
+        return (
             self.host_override
             or f"https://{self.db_name}.{self.region}.api.lancedb.com"
         )
-        return aiohttp.ClientSession(url)

-    async def close(self):
-        await self.session.close()
+    def _get_request_url(self, uri: str) -> str:
+        return urllib.parse.urljoin(self.url, uri)
+
+    def close(self):
+        self.session.close()
         self.closed = True

     @functools.cached_property
@@ -75,39 +85,25 @@ class RestfulLanceDBClient:
             headers["x-lancedb-database"] = self.db_name
         return headers

-    @staticmethod
-    async def _check_status(resp: aiohttp.ClientResponse):
-        if resp.status == 404:
-            raise LanceDBClientError(f"Not found: {await resp.text()}")
-        elif 400 <= resp.status < 500:
-            raise LanceDBClientError(
-                f"Bad Request: {resp.status}, error: {await resp.text()}"
-            )
-        elif 500 <= resp.status < 600:
-            raise LanceDBClientError(
-                f"Internal Server Error: {resp.status}, error: {await resp.text()}"
-            )
-        elif resp.status != 200:
-            raise LanceDBClientError(
-                f"Unknown Error: {resp.status}, error: {await resp.text()}"
-            )
-
     @_check_not_closed
-    async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
+    def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
         """Send a GET request and returns the deserialized response payload."""
         if isinstance(params, BaseModel):
             params: Dict[str, Any] = params.dict(exclude_none=True)
-        async with self.session.get(
-            uri,
+        resp = self.session.get(
+            self._get_request_url(uri),
             params=params,
             headers=self.headers,
-            timeout=aiohttp.ClientTimeout(total=30),
-        ) as resp:
-            await self._check_status(resp)
-            return await resp.json()
+            # 5s connect timeout, 30s read timeout
+            timeout=(5.0, 30.0),
+        )
+        resp.raise_for_status()
+        return resp.json()

     @_check_not_closed
-    async def post(
+    def post(
         self,
         uri: str,
         data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
@@ -139,31 +135,31 @@ class RestfulLanceDBClient:
             headers["content-type"] = content_type
         if request_id is not None:
             headers["x-request-id"] = request_id
-        async with self.session.post(
-            uri,
-            headers=headers,
+        resp = self.session.post(
+            self._get_request_url(uri),
             params=params,
-            timeout=aiohttp.ClientTimeout(total=30),
+            headers=self.headers,
+            # 5s connect timeout, 30s read timeout
+            timeout=(5.0, 30.0),
             **req_kwargs,
-        ) as resp:
-            resp: aiohttp.ClientResponse = resp
-            await self._check_status(resp)
-            return await deserialize(resp)
+        )
+        resp.raise_for_status()
+        return deserialize(resp)

     @_check_not_closed
-    async def list_tables(
+    def list_tables(
         self, limit: int, page_token: Optional[str] = None
     ) -> Iterable[str]:
         """List all tables in the database."""
         if page_token is None:
             page_token = ""
-        json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
+        json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
         return json["tables"]

     @_check_not_closed
-    async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
+    def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
         """Query a table."""
-        tbl = await self.post(
-            f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc
-        )
+        tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
         return VectorQueryResult(tbl)
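
For reference, the aiohttp-to-requests translation applied throughout this file boils down to the pattern below. This is a standalone sketch against a placeholder URL and header, not LanceDB's real endpoint; the (connect, read) timeout tuple and raise_for_status() are stock requests features that take over from aiohttp.ClientTimeout and the deleted _check_status helper.

import requests

# Session-level defaults mirror what the new `session` cached property sets up.
session = requests.Session()
session.stream = True  # don't eagerly download response bodies

resp = session.get(
    "https://example.com/v1/table/",  # placeholder URL, not a real LanceDB host
    params={"limit": 10},
    headers={"x-api-key": "..."},     # placeholder header value
    timeout=(5.0, 30.0),              # (connect timeout, read timeout) in seconds
)
resp.raise_for_status()               # raises requests.HTTPError on 4xx/5xx
payload = resp.json()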

View File

@@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection):
         self._client = RestfulLanceDBClient(
             self.db_name, region, api_key, host_override
         )
-        try:
-            self._loop = asyncio.get_running_loop()
-        except RuntimeError:
-            self._loop = asyncio.get_event_loop()

     def __repr__(self) -> str:
         return f"RemoteConnect(name={self.db_name})"
@@ -76,15 +72,13 @@ class RemoteDBConnection(DBConnection):
             An iterator of table names.
         """
         while True:
-            result = self._loop.run_until_complete(
-                self._client.list_tables(limit, page_token)
-            )
-            if len(result) > 0:
-                page_token = result[len(result) - 1]
-            else:
-                break
+            result = self._client.list_tables(limit, page_token)
             for item in result:
                 yield item
+            if len(result) < limit:
+                break
+            else:
+                page_token = result[len(result) - 1]

     @override
     def open_table(self, name: str) -> Table:
@@ -103,9 +97,7 @@ class RemoteDBConnection(DBConnection):
         # check if table exists
         try:
-            self._loop.run_until_complete(
-                self._client.post(f"/v1/table/{name}/describe/")
-            )
+            self._client.post(f"/v1/table/{name}/describe/")
         except LanceDBClientError as err:
             if str(err).startswith("Not found"):
                 logging.error(
@@ -248,14 +240,12 @@ class RemoteDBConnection(DBConnection):
         data = to_ipc_binary(data)
         request_id = uuid.uuid4().hex
-        self._loop.run_until_complete(
-            self._client.post(
-                f"/v1/table/{name}/create/",
-                data=data,
-                request_id=request_id,
-                content_type=ARROW_STREAM_CONTENT_TYPE,
-            )
-        )
+        self._client.post(
+            f"/v1/table/{name}/create/",
+            data=data,
+            request_id=request_id,
+            content_type=ARROW_STREAM_CONTENT_TYPE,
+        )
         return RemoteTable(self, name)

     @override
@@ -267,13 +257,10 @@ class RemoteDBConnection(DBConnection):
         name: str
             The name of the table.
         """
-        self._loop.run_until_complete(
-            self._client.post(
-                f"/v1/table/{name}/drop/",
-            )
-        )
+        self._client.post(
+            f"/v1/table/{name}/drop/",
+        )

     async def close(self):
         """Close the connection to the database."""
-        self._loop.close()
-        await self._client.close()
+        self._client.close()

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
 # limitations under the License.

-import asyncio
 import uuid
 from functools import cached_property
 from typing import Dict, Optional, Union
@@ -43,18 +42,14 @@ class RemoteTable(Table):
            of this Table
        """
-        resp = self._conn._loop.run_until_complete(
-            self._conn._client.post(f"/v1/table/{self._name}/describe/")
-        )
+        resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
         schema = json_to_schema(resp["schema"])
         return schema

     @property
     def version(self) -> int:
         """Get the current version of the table"""
-        resp = self._conn._loop.run_until_complete(
-            self._conn._client.post(f"/v1/table/{self._name}/describe/")
-        )
+        resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
         return resp["version"]

     def to_arrow(self) -> pa.Table:
@@ -116,8 +111,8 @@
             "metric_type": metric,
             "index_cache_size": index_cache_size,
         }
-        resp = self._conn._loop.run_until_complete(
-            self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data)
+        resp = self._conn._client.post(
+            f"/v1/table/{self._name}/create_index/", data=data
         )
         return resp
@@ -161,14 +156,12 @@ class RemoteTable(Table):
         request_id = uuid.uuid4().hex
-        self._conn._loop.run_until_complete(
-            self._conn._client.post(
-                f"/v1/table/{self._name}/insert/",
-                data=payload,
-                params={"request_id": request_id, "mode": mode},
-                content_type=ARROW_STREAM_CONTENT_TYPE,
-            )
-        )
+        self._conn._client.post(
+            f"/v1/table/{self._name}/insert/",
+            data=payload,
+            params={"request_id": request_id, "mode": mode},
+            content_type=ARROW_STREAM_CONTENT_TYPE,
+        )

     def search(
         self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
@@ -233,19 +226,17 @@ class RemoteTable(Table):
             and len(query.vector) > 0
             and not isinstance(query.vector[0], float)
         ):
-            futures = []
+            result = []
             for v in query.vector:
                 v = list(v)
                 q = query.copy()
                 q.vector = v
-                futures.append(self._conn._client.query(self._name, q))
-            result = self._conn._loop.run_until_complete(asyncio.gather(*futures))
+                result.append(self._conn._client.query(self._name, q))
             return pa.concat_tables(
                 [add_index(r.to_arrow(), i) for i, r in enumerate(result)]
             )
         else:
-            result = self._conn._client.query(self._name, query)
-            return self._conn._loop.run_until_complete(result).to_arrow()
+            return self._conn._client.query(self._name, query).to_arrow()

     def delete(self, predicate: str):
         """Delete rows from the table.
@@ -294,9 +285,7 @@ class RemoteTable(Table):
         0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
         """
         payload = {"predicate": predicate}
-        self._conn._loop.run_until_complete(
-            self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
-        )
+        self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)

     def update(
         self,
@@ -356,9 +345,7 @@ class RemoteTable(Table):
         updates = [[k, v] for k, v in values_sql.items()]
         payload = {"predicate": where, "updates": updates}
-        self._conn._loop.run_until_complete(
-            self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
-        )
+        self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)


 def add_index(tbl: pa.Table, i: int) -> pa.Table:
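
The multi-vector branch above now issues one blocking query per vector and stitches the per-query Arrow tables together. The sketch below isolates that stitching with toy tables; this add_index is only an assumed version of the helper whose body falls outside the hunk, included so the example runs.

import pyarrow as pa


def add_index(tbl: pa.Table, i: int) -> pa.Table:
    # Assumed behaviour: tag every row with the index of the query vector that
    # produced it (the real helper's body is not shown in this diff).
    return tbl.append_column("query_index", pa.array([i] * len(tbl), type=pa.int64()))


# Toy per-query results standing in for VectorQueryResult.to_arrow() outputs.
results = [
    pa.table({"id": [1, 2], "score": [0.10, 0.20]}),
    pa.table({"id": [3, 4], "score": [0.30, 0.40]}),
]

combined = pa.concat_tables([add_index(r, i) for i, r in enumerate(results)])
assert combined.column("query_index").to_pylist() == [0, 0, 1, 1]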

View File

@@ -7,7 +7,7 @@ dependencies = [
     "ratelimiter~=1.0",
     "retry>=0.9.2",
     "tqdm>=4.27.0",
-    "aiohttp",
+    "requests>=2.31,<3",
     "pydantic>=1.10",
     "attrs>=21.3.0",
     "semver>=3.0",

View File

@@ -1,27 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest

from lancedb import LanceDBConnection


# TODO: setup integ test mark and script
@pytest.mark.skip(reason="Need to set up a local server")
def test_against_local_server():
    conn = LanceDBConnection("lancedb+http://localhost:10024")
    table = conn.open_table("sift1m_ivf1024_pq16")
    df = table.search(np.random.rand(128)).to_pandas()
    assert len(df) == 10

View File

@@ -1,95 +0,0 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import attrs
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from aiohttp import web

from lancedb.remote.client import RestfulLanceDBClient, VectorQuery


@attrs.define
class MockLanceDBServer:
    runner: web.AppRunner = attrs.field(init=False)
    site: web.TCPSite = attrs.field(init=False)

    async def query_handler(self, request: web.Request) -> web.Response:
        table_name = request.match_info["table_name"]
        assert table_name == "test_table"

        await request.json()
        # TODO: do some matching

        vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
        ids = pd.Series(range(10), name="id")
        df = pd.DataFrame([vecs, ids]).T
        batch = pa.RecordBatch.from_pandas(
            df,
            schema=pa.schema(
                [
                    pa.field("vector", pa.list_(pa.float32(), 128)),
                    pa.field("id", pa.int64()),
                ]
            ),
        )

        sink = pa.BufferOutputStream()
        with pa.ipc.new_file(sink, batch.schema) as writer:
            writer.write_batch(batch)
        return web.Response(body=sink.getvalue().to_pybytes())

    async def setup(self):
        app = web.Application()
        app.add_routes([web.post("/table/{table_name}", self.query_handler)])
        self.runner = web.AppRunner(app)
        await self.runner.setup()
        self.site = web.TCPSite(self.runner, "localhost", 8111)

    async def start(self):
        await self.site.start()

    async def stop(self):
        await self.runner.cleanup()


@pytest.mark.skip(reason="flaky somehow, fix later")
@pytest.mark.asyncio
async def test_e2e_with_mock_server():
    mock_server = MockLanceDBServer()
    await mock_server.setup()
    await mock_server.start()

    try:
        client = RestfulLanceDBClient("lancedb+http://localhost:8111")
        df = (
            await client.query(
                "test_table",
                VectorQuery(
                    vector=np.random.rand(128).tolist(),
                    k=10,
                    _metric="L2",
                    columns=["id", "vector"],
                ),
            )
        ).to_pandas()
        assert "vector" in df.columns
        assert "id" in df.columns
    finally:
        # make sure we don't leak resources
        await mock_server.stop()