use requests instead of aiohttp for underlying http client (#803)

instead of starting and stopping the current thread's event loop on
every http call, just make an http call.
This commit is contained in:
Sebastian Law
2024-01-09 21:07:50 -08:00
committed by Weston Pace
parent 7581cbb38f
commit 4aa7f58a07
5 changed files with 72 additions and 93 deletions

View File

@@ -13,9 +13,10 @@
import functools import functools
from typing import Any, Callable, Dict, Iterable, Optional, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from urllib.parse import urljoin
import aiohttp import requests
import attrs import attrs
import pyarrow as pa import pyarrow as pa
from pydantic import BaseModel from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _check_not_closed(f):
return wrapped return wrapped
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table: def _read_ipc(resp: requests.Response) -> pa.Table:
resp_body = await resp.read() resp_body = resp.content
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader: with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
return reader.read_all() return reader.read_all()
@@ -53,15 +54,18 @@ class RestfulLanceDBClient:
closed: bool = attrs.field(default=False, init=False) closed: bool = attrs.field(default=False, init=False)
@functools.cached_property @functools.cached_property
def session(self) -> aiohttp.ClientSession: def session(self) -> requests.Session:
url = ( return requests.Session()
@property
def url(self) -> str:
return (
self.host_override self.host_override
or f"https://{self.db_name}.{self.region}.api.lancedb.com" or f"https://{self.db_name}.{self.region}.api.lancedb.com"
) )
return aiohttp.ClientSession(url)
async def close(self): def close(self):
await self.session.close() self.session.close()
self.closed = True self.closed = True
@functools.cached_property @functools.cached_property
@@ -76,38 +80,38 @@ class RestfulLanceDBClient:
return headers return headers
@staticmethod @staticmethod
async def _check_status(resp: aiohttp.ClientResponse): def _check_status(resp: requests.Response):
if resp.status == 404: if resp.status_code == 404:
raise LanceDBClientError(f"Not found: {await resp.text()}") raise LanceDBClientError(f"Not found: {resp.text}")
elif 400 <= resp.status < 500: elif 400 <= resp.status_code < 500:
raise LanceDBClientError( raise LanceDBClientError(
f"Bad Request: {resp.status}, error: {await resp.text()}" f"Bad Request: {resp.status_code}, error: {resp.text}"
) )
elif 500 <= resp.status < 600: elif 500 <= resp.status_code < 600:
raise LanceDBClientError( raise LanceDBClientError(
f"Internal Server Error: {resp.status}, error: {await resp.text()}" f"Internal Server Error: {resp.status_code}, error: {resp.text}"
) )
elif resp.status != 200: elif resp.status_code != 200:
raise LanceDBClientError( raise LanceDBClientError(
f"Unknown Error: {resp.status}, error: {await resp.text()}" f"Unknown Error: {resp.status_code}, error: {resp.text}"
) )
@_check_not_closed @_check_not_closed
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None): def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
"""Send a GET request and returns the deserialized response payload.""" """Send a GET request and returns the deserialized response payload."""
if isinstance(params, BaseModel): if isinstance(params, BaseModel):
params: Dict[str, Any] = params.dict(exclude_none=True) params: Dict[str, Any] = params.dict(exclude_none=True)
async with self.session.get( with self.session.get(
uri, urljoin(self.url, uri),
params=params, params=params,
headers=self.headers, headers=self.headers,
timeout=aiohttp.ClientTimeout(total=30), timeout=(5.0, 30.0),
) as resp: ) as resp:
await self._check_status(resp) self._check_status(resp)
return await resp.json() return resp.json()
@_check_not_closed @_check_not_closed
async def post( def post(
self, self,
uri: str, uri: str,
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None, data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
@@ -139,31 +143,26 @@ class RestfulLanceDBClient:
headers["content-type"] = content_type headers["content-type"] = content_type
if request_id is not None: if request_id is not None:
headers["x-request-id"] = request_id headers["x-request-id"] = request_id
async with self.session.post( with self.session.post(
uri, urljoin(self.url, uri),
headers=headers, headers=headers,
params=params, params=params,
timeout=aiohttp.ClientTimeout(total=30), timeout=(5.0, 30.0),
**req_kwargs, **req_kwargs,
) as resp: ) as resp:
resp: aiohttp.ClientResponse = resp self._check_status(resp)
await self._check_status(resp) return deserialize(resp)
return await deserialize(resp)
@_check_not_closed @_check_not_closed
async def list_tables( def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
self, limit: int, page_token: Optional[str] = None
) -> Iterable[str]:
"""List all tables in the database.""" """List all tables in the database."""
if page_token is None: if page_token is None:
page_token = "" page_token = ""
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token}) json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
return json["tables"] return json["tables"]
@_check_not_closed @_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query a table.""" """Query a table."""
tbl = await self.post( tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc
)
return VectorQueryResult(tbl) return VectorQueryResult(tbl)

View File

@@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection):
self._client = RestfulLanceDBClient( self._client = RestfulLanceDBClient(
self.db_name, region, api_key, host_override self.db_name, region, api_key, host_override
) )
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.get_event_loop()
def __repr__(self) -> str: def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})" return f"RemoteConnect(name={self.db_name})"
@@ -76,9 +72,8 @@ class RemoteDBConnection(DBConnection):
An iterator of table names. An iterator of table names.
""" """
while True: while True:
result = self._loop.run_until_complete( result = self._client.list_tables(limit, page_token)
self._client.list_tables(limit, page_token)
)
if len(result) > 0: if len(result) > 0:
page_token = result[len(result) - 1] page_token = result[len(result) - 1]
else: else:
@@ -103,9 +98,7 @@ class RemoteDBConnection(DBConnection):
# check if table exists # check if table exists
try: try:
self._loop.run_until_complete( self._client.post(f"/v1/table/{name}/describe/")
self._client.post(f"/v1/table/{name}/describe/")
)
except LanceDBClientError as err: except LanceDBClientError as err:
if str(err).startswith("Not found"): if str(err).startswith("Not found"):
logging.error( logging.error(
@@ -248,14 +241,13 @@ class RemoteDBConnection(DBConnection):
data = to_ipc_binary(data) data = to_ipc_binary(data)
request_id = uuid.uuid4().hex request_id = uuid.uuid4().hex
self._loop.run_until_complete( self._client.post(
self._client.post( f"/v1/table/{name}/create/",
f"/v1/table/{name}/create/", data=data,
data=data, request_id=request_id,
request_id=request_id, content_type=ARROW_STREAM_CONTENT_TYPE,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
) )
return RemoteTable(self, name) return RemoteTable(self, name)
@override @override
@@ -267,13 +259,11 @@ class RemoteDBConnection(DBConnection):
name: str name: str
The name of the table. The name of the table.
""" """
self._loop.run_until_complete(
self._client.post( self._client.post(
f"/v1/table/{name}/drop/", f"/v1/table/{name}/drop/",
)
) )
async def close(self): async def close(self):
"""Close the connection to the database.""" """Close the connection to the database."""
self._loop.close() self._client.close()
await self._client.close()

View File

@@ -43,18 +43,14 @@ class RemoteTable(Table):
of this Table of this Table
""" """
resp = self._conn._loop.run_until_complete( resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
self._conn._client.post(f"/v1/table/{self._name}/describe/")
)
schema = json_to_schema(resp["schema"]) schema = json_to_schema(resp["schema"])
return schema return schema
@property @property
def version(self) -> int: def version(self) -> int:
"""Get the current version of the table""" """Get the current version of the table"""
resp = self._conn._loop.run_until_complete( resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
self._conn._client.post(f"/v1/table/{self._name}/describe/")
)
return resp["version"] return resp["version"]
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
@@ -116,9 +112,10 @@ class RemoteTable(Table):
"metric_type": metric, "metric_type": metric,
"index_cache_size": index_cache_size, "index_cache_size": index_cache_size,
} }
resp = self._conn._loop.run_until_complete( resp = self._conn._client.post(
self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data) f"/v1/table/{self._name}/create_index/", data=data
) )
return resp return resp
def add( def add(
@@ -161,13 +158,11 @@ class RemoteTable(Table):
request_id = uuid.uuid4().hex request_id = uuid.uuid4().hex
self._conn._loop.run_until_complete( self._conn._client.post(
self._conn._client.post( f"/v1/table/{self._name}/insert/",
f"/v1/table/{self._name}/insert/", data=payload,
data=payload, params={"request_id": request_id, "mode": mode},
params={"request_id": request_id, "mode": mode}, content_type=ARROW_STREAM_CONTENT_TYPE,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
) )
def search( def search(
@@ -233,19 +228,19 @@ class RemoteTable(Table):
and len(query.vector) > 0 and len(query.vector) > 0
and not isinstance(query.vector[0], float) and not isinstance(query.vector[0], float)
): ):
futures = [] results = []
for v in query.vector: for v in query.vector:
v = list(v) v = list(v)
q = query.copy() q = query.copy()
q.vector = v q.vector = v
futures.append(self._conn._client.query(self._name, q)) results.append(self._conn._client.query(self._name, q))
result = self._conn._loop.run_until_complete(asyncio.gather(*futures))
return pa.concat_tables( return pa.concat_tables(
[add_index(r.to_arrow(), i) for i, r in enumerate(result)] [add_index(r.to_arrow(), i) for i, r in enumerate(results)]
) )
else: else:
result = self._conn._client.query(self._name, query) result = self._conn._client.query(self._name, query)
return self._conn._loop.run_until_complete(result).to_arrow() return result.to_arrow()
def delete(self, predicate: str): def delete(self, predicate: str):
"""Delete rows from the table. """Delete rows from the table.
@@ -294,9 +289,7 @@ class RemoteTable(Table):
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
""" """
payload = {"predicate": predicate} payload = {"predicate": predicate}
self._conn._loop.run_until_complete( self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
)
def update( def update(
self, self,
@@ -356,9 +349,7 @@ class RemoteTable(Table):
updates = [[k, v] for k, v in values_sql.items()] updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates} payload = {"predicate": where, "updates": updates}
self._conn._loop.run_until_complete( self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
)
def add_index(tbl: pa.Table, i: int) -> pa.Table: def add_index(tbl: pa.Table, i: int) -> pa.Table:

View File

@@ -7,7 +7,6 @@ dependencies = [
"ratelimiter~=1.0", "ratelimiter~=1.0",
"retry>=0.9.2", "retry>=0.9.2",
"tqdm>=4.27.0", "tqdm>=4.27.0",
"aiohttp",
"pydantic>=1.10", "pydantic>=1.10",
"attrs>=21.3.0", "attrs>=21.3.0",
"semver>=3.0", "semver>=3.0",
@@ -46,7 +45,7 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies] [project.optional-dependencies]
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"] tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz"]
dev = ["ruff", "pre-commit", "black"] dev = ["ruff", "pre-commit", "black"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"] clip = ["torch", "pillow", "open-clip"]

View File

@@ -18,15 +18,15 @@ from lancedb.remote.client import VectorQuery, VectorQueryResult
class FakeLanceDBClient: class FakeLanceDBClient:
async def close(self): def close(self):
pass pass
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
assert table_name == "test" assert table_name == "test"
t = pa.schema([]).empty_table() t = pa.schema([]).empty_table()
return VectorQueryResult(t) return VectorQueryResult(t)
async def post(self, path: str): def post(self, path: str):
pass pass