use requests instead of aiohttp for underlying http client (#803)

instead of starting and stopping the current thread's event loop on
every http call, just make an http call.
This commit is contained in:
Sebastian Law
2024-01-09 21:07:50 -08:00
committed by Andrew Miracle
parent 91d64d86e0
commit eda4c587fc
5 changed files with 72 additions and 93 deletions

View File

@@ -13,9 +13,10 @@
import functools
from typing import Any, Callable, Dict, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from urllib.parse import urljoin
import aiohttp
import requests
import attrs
import pyarrow as pa
from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _check_not_closed(f):
return wrapped
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
resp_body = await resp.read()
def _read_ipc(resp: requests.Response) -> pa.Table:
resp_body = resp.content
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
return reader.read_all()
@@ -53,15 +54,18 @@ class RestfulLanceDBClient:
closed: bool = attrs.field(default=False, init=False)
@functools.cached_property
def session(self) -> aiohttp.ClientSession:
url = (
def session(self) -> requests.Session:
return requests.Session()
@property
def url(self) -> str:
return (
self.host_override
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
)
return aiohttp.ClientSession(url)
async def close(self):
await self.session.close()
def close(self):
self.session.close()
self.closed = True
@functools.cached_property
@@ -76,38 +80,38 @@ class RestfulLanceDBClient:
return headers
@staticmethod
async def _check_status(resp: aiohttp.ClientResponse):
if resp.status == 404:
raise LanceDBClientError(f"Not found: {await resp.text()}")
elif 400 <= resp.status < 500:
def _check_status(resp: requests.Response):
if resp.status_code == 404:
raise LanceDBClientError(f"Not found: {resp.text}")
elif 400 <= resp.status_code < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status}, error: {await resp.text()}"
f"Bad Request: {resp.status_code}, error: {resp.text}"
)
elif 500 <= resp.status < 600:
elif 500 <= resp.status_code < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
f"Internal Server Error: {resp.status_code}, error: {resp.text}"
)
elif resp.status != 200:
elif resp.status_code != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status}, error: {await resp.text()}"
f"Unknown Error: {resp.status_code}, error: {resp.text}"
)
@_check_not_closed
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
"""Send a GET request and returns the deserialized response payload."""
if isinstance(params, BaseModel):
params: Dict[str, Any] = params.dict(exclude_none=True)
async with self.session.get(
uri,
with self.session.get(
urljoin(self.url, uri),
params=params,
headers=self.headers,
timeout=aiohttp.ClientTimeout(total=30),
timeout=(5.0, 30.0),
) as resp:
await self._check_status(resp)
return await resp.json()
self._check_status(resp)
return resp.json()
@_check_not_closed
async def post(
def post(
self,
uri: str,
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
@@ -139,31 +143,26 @@ class RestfulLanceDBClient:
headers["content-type"] = content_type
if request_id is not None:
headers["x-request-id"] = request_id
async with self.session.post(
uri,
with self.session.post(
urljoin(self.url, uri),
headers=headers,
params=params,
timeout=aiohttp.ClientTimeout(total=30),
timeout=(5.0, 30.0),
**req_kwargs,
) as resp:
resp: aiohttp.ClientResponse = resp
await self._check_status(resp)
return await deserialize(resp)
self._check_status(resp)
return deserialize(resp)
@_check_not_closed
async def list_tables(
self, limit: int, page_token: Optional[str] = None
) -> Iterable[str]:
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
"""List all tables in the database."""
if page_token is None:
page_token = ""
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
return json["tables"]
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query a table."""
tbl = await self.post(
f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc
)
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl)

View File

@@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection):
self._client = RestfulLanceDBClient(
self.db_name, region, api_key, host_override
)
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.get_event_loop()
def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})"
@@ -76,9 +72,8 @@ class RemoteDBConnection(DBConnection):
An iterator of table names.
"""
while True:
result = self._loop.run_until_complete(
self._client.list_tables(limit, page_token)
)
result = self._client.list_tables(limit, page_token)
if len(result) > 0:
page_token = result[len(result) - 1]
else:
@@ -103,9 +98,7 @@ class RemoteDBConnection(DBConnection):
# check if table exists
try:
self._loop.run_until_complete(
self._client.post(f"/v1/table/{name}/describe/")
)
self._client.post(f"/v1/table/{name}/describe/")
except LanceDBClientError as err:
if str(err).startswith("Not found"):
logging.error(
@@ -248,14 +241,13 @@ class RemoteDBConnection(DBConnection):
data = to_ipc_binary(data)
request_id = uuid.uuid4().hex
self._loop.run_until_complete(
self._client.post(
f"/v1/table/{name}/create/",
data=data,
request_id=request_id,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
self._client.post(
f"/v1/table/{name}/create/",
data=data,
request_id=request_id,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
return RemoteTable(self, name)
@override
@@ -267,13 +259,11 @@ class RemoteDBConnection(DBConnection):
name: str
The name of the table.
"""
self._loop.run_until_complete(
self._client.post(
f"/v1/table/{name}/drop/",
)
self._client.post(
f"/v1/table/{name}/drop/",
)
async def close(self):
"""Close the connection to the database."""
self._loop.close()
await self._client.close()
self._client.close()

View File

@@ -43,18 +43,14 @@ class RemoteTable(Table):
of this Table
"""
resp = self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/describe/")
)
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
schema = json_to_schema(resp["schema"])
return schema
@property
def version(self) -> int:
"""Get the current version of the table"""
resp = self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/describe/")
)
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
return resp["version"]
def to_arrow(self) -> pa.Table:
@@ -116,9 +112,10 @@ class RemoteTable(Table):
"metric_type": metric,
"index_cache_size": index_cache_size,
}
resp = self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data)
resp = self._conn._client.post(
f"/v1/table/{self._name}/create_index/", data=data
)
return resp
def add(
@@ -161,13 +158,11 @@ class RemoteTable(Table):
request_id = uuid.uuid4().hex
self._conn._loop.run_until_complete(
self._conn._client.post(
f"/v1/table/{self._name}/insert/",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
)
self._conn._client.post(
f"/v1/table/{self._name}/insert/",
data=payload,
params={"request_id": request_id, "mode": mode},
content_type=ARROW_STREAM_CONTENT_TYPE,
)
def search(
@@ -233,19 +228,19 @@ class RemoteTable(Table):
and len(query.vector) > 0
and not isinstance(query.vector[0], float)
):
futures = []
results = []
for v in query.vector:
v = list(v)
q = query.copy()
q.vector = v
futures.append(self._conn._client.query(self._name, q))
result = self._conn._loop.run_until_complete(asyncio.gather(*futures))
results.append(self._conn._client.query(self._name, q))
return pa.concat_tables(
[add_index(r.to_arrow(), i) for i, r in enumerate(result)]
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
)
else:
result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow()
return result.to_arrow()
def delete(self, predicate: str):
"""Delete rows from the table.
@@ -294,9 +289,7 @@ class RemoteTable(Table):
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
"""
payload = {"predicate": predicate}
self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
)
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
def update(
self,
@@ -356,9 +349,7 @@ class RemoteTable(Table):
updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates}
self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
)
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
def add_index(tbl: pa.Table, i: int) -> pa.Table:

View File

@@ -7,7 +7,6 @@ dependencies = [
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",
"aiohttp",
"pydantic>=1.10",
"attrs>=21.3.0",
"semver>=3.0",
@@ -49,7 +48,7 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz"]
dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]

View File

@@ -18,15 +18,15 @@ from lancedb.remote.client import VectorQuery, VectorQueryResult
class FakeLanceDBClient:
async def close(self):
def close(self):
pass
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
assert table_name == "test"
t = pa.schema([]).empty_table()
return VectorQueryResult(t)
async def post(self, path: str):
def post(self, path: str):
pass