mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 14:49:57 +00:00
use requests instead of aiohttp for underlying http client (#803)
instead of starting and stopping the current thread's event loop on every http call, just make an http call.
This commit is contained in:
committed by
Andrew Miracle
parent
91d64d86e0
commit
eda4c587fc
@@ -13,9 +13,10 @@
|
||||
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, Iterable, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
import attrs
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
@@ -37,8 +38,8 @@ def _check_not_closed(f):
|
||||
return wrapped
|
||||
|
||||
|
||||
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
|
||||
resp_body = await resp.read()
|
||||
def _read_ipc(resp: requests.Response) -> pa.Table:
|
||||
resp_body = resp.content
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
return reader.read_all()
|
||||
|
||||
@@ -53,15 +54,18 @@ class RestfulLanceDBClient:
|
||||
closed: bool = attrs.field(default=False, init=False)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
url = (
|
||||
def session(self) -> requests.Session:
|
||||
return requests.Session()
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return (
|
||||
self.host_override
|
||||
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||
)
|
||||
return aiohttp.ClientSession(url)
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
def close(self):
|
||||
self.session.close()
|
||||
self.closed = True
|
||||
|
||||
@functools.cached_property
|
||||
@@ -76,38 +80,38 @@ class RestfulLanceDBClient:
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
async def _check_status(resp: aiohttp.ClientResponse):
|
||||
if resp.status == 404:
|
||||
raise LanceDBClientError(f"Not found: {await resp.text()}")
|
||||
elif 400 <= resp.status < 500:
|
||||
def _check_status(resp: requests.Response):
|
||||
if resp.status_code == 404:
|
||||
raise LanceDBClientError(f"Not found: {resp.text}")
|
||||
elif 400 <= resp.status_code < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
||||
f"Bad Request: {resp.status_code}, error: {resp.text}"
|
||||
)
|
||||
elif 500 <= resp.status < 600:
|
||||
elif 500 <= resp.status_code < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
||||
f"Internal Server Error: {resp.status_code}, error: {resp.text}"
|
||||
)
|
||||
elif resp.status != 200:
|
||||
elif resp.status_code != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
||||
f"Unknown Error: {resp.status_code}, error: {resp.text}"
|
||||
)
|
||||
|
||||
@_check_not_closed
|
||||
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||
"""Send a GET request and returns the deserialized response payload."""
|
||||
if isinstance(params, BaseModel):
|
||||
params: Dict[str, Any] = params.dict(exclude_none=True)
|
||||
async with self.session.get(
|
||||
uri,
|
||||
with self.session.get(
|
||||
urljoin(self.url, uri),
|
||||
params=params,
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
timeout=(5.0, 30.0),
|
||||
) as resp:
|
||||
await self._check_status(resp)
|
||||
return await resp.json()
|
||||
self._check_status(resp)
|
||||
return resp.json()
|
||||
|
||||
@_check_not_closed
|
||||
async def post(
|
||||
def post(
|
||||
self,
|
||||
uri: str,
|
||||
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
|
||||
@@ -139,31 +143,26 @@ class RestfulLanceDBClient:
|
||||
headers["content-type"] = content_type
|
||||
if request_id is not None:
|
||||
headers["x-request-id"] = request_id
|
||||
async with self.session.post(
|
||||
uri,
|
||||
with self.session.post(
|
||||
urljoin(self.url, uri),
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
timeout=(5.0, 30.0),
|
||||
**req_kwargs,
|
||||
) as resp:
|
||||
resp: aiohttp.ClientResponse = resp
|
||||
await self._check_status(resp)
|
||||
return await deserialize(resp)
|
||||
self._check_status(resp)
|
||||
return deserialize(resp)
|
||||
|
||||
@_check_not_closed
|
||||
async def list_tables(
|
||||
self, limit: int, page_token: Optional[str] = None
|
||||
) -> Iterable[str]:
|
||||
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
|
||||
"""List all tables in the database."""
|
||||
if page_token is None:
|
||||
page_token = ""
|
||||
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||
return json["tables"]
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query a table."""
|
||||
tbl = await self.post(
|
||||
f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc
|
||||
)
|
||||
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
||||
return VectorQueryResult(tbl)
|
||||
|
||||
@@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection):
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name, region, api_key, host_override
|
||||
)
|
||||
try:
|
||||
self._loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
@@ -76,9 +72,8 @@ class RemoteDBConnection(DBConnection):
|
||||
An iterator of table names.
|
||||
"""
|
||||
while True:
|
||||
result = self._loop.run_until_complete(
|
||||
self._client.list_tables(limit, page_token)
|
||||
)
|
||||
result = self._client.list_tables(limit, page_token)
|
||||
|
||||
if len(result) > 0:
|
||||
page_token = result[len(result) - 1]
|
||||
else:
|
||||
@@ -103,9 +98,7 @@ class RemoteDBConnection(DBConnection):
|
||||
|
||||
# check if table exists
|
||||
try:
|
||||
self._loop.run_until_complete(
|
||||
self._client.post(f"/v1/table/{name}/describe/")
|
||||
)
|
||||
self._client.post(f"/v1/table/{name}/describe/")
|
||||
except LanceDBClientError as err:
|
||||
if str(err).startswith("Not found"):
|
||||
logging.error(
|
||||
@@ -248,14 +241,13 @@ class RemoteDBConnection(DBConnection):
|
||||
data = to_ipc_binary(data)
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._loop.run_until_complete(
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/create/",
|
||||
data=data,
|
||||
request_id=request_id,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/create/",
|
||||
data=data,
|
||||
request_id=request_id,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
return RemoteTable(self, name)
|
||||
|
||||
@override
|
||||
@@ -267,13 +259,11 @@ class RemoteDBConnection(DBConnection):
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
self._loop.run_until_complete(
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/drop/",
|
||||
)
|
||||
|
||||
self._client.post(
|
||||
f"/v1/table/{name}/drop/",
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the database."""
|
||||
self._loop.close()
|
||||
await self._client.close()
|
||||
self._client.close()
|
||||
|
||||
@@ -43,18 +43,14 @@ class RemoteTable(Table):
|
||||
of this Table
|
||||
|
||||
"""
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
)
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
schema = json_to_schema(resp["schema"])
|
||||
return schema
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
"""Get the current version of the table"""
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
)
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||
return resp["version"]
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
@@ -116,9 +112,10 @@ class RemoteTable(Table):
|
||||
"metric_type": metric,
|
||||
"index_cache_size": index_cache_size,
|
||||
}
|
||||
resp = self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data)
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/create_index/", data=data
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
def add(
|
||||
@@ -161,13 +158,11 @@ class RemoteTable(Table):
|
||||
|
||||
request_id = uuid.uuid4().hex
|
||||
|
||||
self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/insert/",
|
||||
data=payload,
|
||||
params={"request_id": request_id, "mode": mode},
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/insert/",
|
||||
data=payload,
|
||||
params={"request_id": request_id, "mode": mode},
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
def search(
|
||||
@@ -233,19 +228,19 @@ class RemoteTable(Table):
|
||||
and len(query.vector) > 0
|
||||
and not isinstance(query.vector[0], float)
|
||||
):
|
||||
futures = []
|
||||
results = []
|
||||
for v in query.vector:
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
futures.append(self._conn._client.query(self._name, q))
|
||||
result = self._conn._loop.run_until_complete(asyncio.gather(*futures))
|
||||
results.append(self._conn._client.query(self._name, q))
|
||||
|
||||
return pa.concat_tables(
|
||||
[add_index(r.to_arrow(), i) for i, r in enumerate(result)]
|
||||
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
|
||||
)
|
||||
else:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||
return result.to_arrow()
|
||||
|
||||
def delete(self, predicate: str):
|
||||
"""Delete rows from the table.
|
||||
@@ -294,9 +289,7 @@ class RemoteTable(Table):
|
||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||
"""
|
||||
payload = {"predicate": predicate}
|
||||
self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||
)
|
||||
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -356,9 +349,7 @@ class RemoteTable(Table):
|
||||
updates = [[k, v] for k, v in values_sql.items()]
|
||||
|
||||
payload = {"predicate": where, "updates": updates}
|
||||
self._conn._loop.run_until_complete(
|
||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||
)
|
||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||
|
||||
|
||||
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
||||
|
||||
@@ -7,7 +7,6 @@ dependencies = [
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.27.0",
|
||||
"aiohttp",
|
||||
"pydantic>=1.10",
|
||||
"attrs>=21.3.0",
|
||||
"semver>=3.0",
|
||||
@@ -49,7 +48,7 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz"]
|
||||
dev = ["ruff", "pre-commit", "black"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
|
||||
@@ -18,15 +18,15 @@ from lancedb.remote.client import VectorQuery, VectorQueryResult
|
||||
|
||||
|
||||
class FakeLanceDBClient:
|
||||
async def close(self):
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
assert table_name == "test"
|
||||
t = pa.schema([]).empty_table()
|
||||
return VectorQueryResult(t)
|
||||
|
||||
async def post(self, path: str):
|
||||
def post(self, path: str):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user