mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 04:12:59 +00:00
Use requests instead of aiohttp for the underlying HTTP client (#803)
Instead of starting and stopping the current thread's event loop on every HTTP call, just make a synchronous HTTP call.
This commit is contained in:
committed by
Weston Pace
parent
7581cbb38f
commit
4aa7f58a07
@@ -13,9 +13,10 @@
|
|||||||
|
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import Any, Callable, Dict, Iterable, Optional, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import aiohttp
|
import requests
|
||||||
import attrs
|
import attrs
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -37,8 +38,8 @@ def _check_not_closed(f):
|
|||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
|
def _read_ipc(resp: requests.Response) -> pa.Table:
|
||||||
resp_body = await resp.read()
|
resp_body = resp.content
|
||||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||||
return reader.read_all()
|
return reader.read_all()
|
||||||
|
|
||||||
@@ -53,15 +54,18 @@ class RestfulLanceDBClient:
|
|||||||
closed: bool = attrs.field(default=False, init=False)
|
closed: bool = attrs.field(default=False, init=False)
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def session(self) -> aiohttp.ClientSession:
|
def session(self) -> requests.Session:
|
||||||
url = (
|
return requests.Session()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def url(self) -> str:
|
||||||
|
return (
|
||||||
self.host_override
|
self.host_override
|
||||||
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
or f"https://{self.db_name}.{self.region}.api.lancedb.com"
|
||||||
)
|
)
|
||||||
return aiohttp.ClientSession(url)
|
|
||||||
|
|
||||||
async def close(self):
|
def close(self):
|
||||||
await self.session.close()
|
self.session.close()
|
||||||
self.closed = True
|
self.closed = True
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
@@ -76,38 +80,38 @@ class RestfulLanceDBClient:
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _check_status(resp: aiohttp.ClientResponse):
|
def _check_status(resp: requests.Response):
|
||||||
if resp.status == 404:
|
if resp.status_code == 404:
|
||||||
raise LanceDBClientError(f"Not found: {await resp.text()}")
|
raise LanceDBClientError(f"Not found: {resp.text}")
|
||||||
elif 400 <= resp.status < 500:
|
elif 400 <= resp.status_code < 500:
|
||||||
raise LanceDBClientError(
|
raise LanceDBClientError(
|
||||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
f"Bad Request: {resp.status_code}, error: {resp.text}"
|
||||||
)
|
)
|
||||||
elif 500 <= resp.status < 600:
|
elif 500 <= resp.status_code < 600:
|
||||||
raise LanceDBClientError(
|
raise LanceDBClientError(
|
||||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
f"Internal Server Error: {resp.status_code}, error: {resp.text}"
|
||||||
)
|
)
|
||||||
elif resp.status != 200:
|
elif resp.status_code != 200:
|
||||||
raise LanceDBClientError(
|
raise LanceDBClientError(
|
||||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
f"Unknown Error: {resp.status_code}, error: {resp.text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@_check_not_closed
|
@_check_not_closed
|
||||||
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
|
||||||
"""Send a GET request and returns the deserialized response payload."""
|
"""Send a GET request and returns the deserialized response payload."""
|
||||||
if isinstance(params, BaseModel):
|
if isinstance(params, BaseModel):
|
||||||
params: Dict[str, Any] = params.dict(exclude_none=True)
|
params: Dict[str, Any] = params.dict(exclude_none=True)
|
||||||
async with self.session.get(
|
with self.session.get(
|
||||||
uri,
|
urljoin(self.url, uri),
|
||||||
params=params,
|
params=params,
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
timeout=aiohttp.ClientTimeout(total=30),
|
timeout=(5.0, 30.0),
|
||||||
) as resp:
|
) as resp:
|
||||||
await self._check_status(resp)
|
self._check_status(resp)
|
||||||
return await resp.json()
|
return resp.json()
|
||||||
|
|
||||||
@_check_not_closed
|
@_check_not_closed
|
||||||
async def post(
|
def post(
|
||||||
self,
|
self,
|
||||||
uri: str,
|
uri: str,
|
||||||
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
|
data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None,
|
||||||
@@ -139,31 +143,26 @@ class RestfulLanceDBClient:
|
|||||||
headers["content-type"] = content_type
|
headers["content-type"] = content_type
|
||||||
if request_id is not None:
|
if request_id is not None:
|
||||||
headers["x-request-id"] = request_id
|
headers["x-request-id"] = request_id
|
||||||
async with self.session.post(
|
with self.session.post(
|
||||||
uri,
|
urljoin(self.url, uri),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
params=params,
|
params=params,
|
||||||
timeout=aiohttp.ClientTimeout(total=30),
|
timeout=(5.0, 30.0),
|
||||||
**req_kwargs,
|
**req_kwargs,
|
||||||
) as resp:
|
) as resp:
|
||||||
resp: aiohttp.ClientResponse = resp
|
self._check_status(resp)
|
||||||
await self._check_status(resp)
|
return deserialize(resp)
|
||||||
return await deserialize(resp)
|
|
||||||
|
|
||||||
@_check_not_closed
|
@_check_not_closed
|
||||||
async def list_tables(
|
def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]:
|
||||||
self, limit: int, page_token: Optional[str] = None
|
|
||||||
) -> Iterable[str]:
|
|
||||||
"""List all tables in the database."""
|
"""List all tables in the database."""
|
||||||
if page_token is None:
|
if page_token is None:
|
||||||
page_token = ""
|
page_token = ""
|
||||||
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
json = self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||||
return json["tables"]
|
return json["tables"]
|
||||||
|
|
||||||
@_check_not_closed
|
@_check_not_closed
|
||||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||||
"""Query a table."""
|
"""Query a table."""
|
||||||
tbl = await self.post(
|
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
||||||
f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc
|
|
||||||
)
|
|
||||||
return VectorQueryResult(tbl)
|
return VectorQueryResult(tbl)
|
||||||
|
|||||||
@@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection):
|
|||||||
self._client = RestfulLanceDBClient(
|
self._client = RestfulLanceDBClient(
|
||||||
self.db_name, region, api_key, host_override
|
self.db_name, region, api_key, host_override
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
self._loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
self._loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"RemoteConnect(name={self.db_name})"
|
return f"RemoteConnect(name={self.db_name})"
|
||||||
@@ -76,9 +72,8 @@ class RemoteDBConnection(DBConnection):
|
|||||||
An iterator of table names.
|
An iterator of table names.
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
result = self._loop.run_until_complete(
|
result = self._client.list_tables(limit, page_token)
|
||||||
self._client.list_tables(limit, page_token)
|
|
||||||
)
|
|
||||||
if len(result) > 0:
|
if len(result) > 0:
|
||||||
page_token = result[len(result) - 1]
|
page_token = result[len(result) - 1]
|
||||||
else:
|
else:
|
||||||
@@ -103,9 +98,7 @@ class RemoteDBConnection(DBConnection):
|
|||||||
|
|
||||||
# check if table exists
|
# check if table exists
|
||||||
try:
|
try:
|
||||||
self._loop.run_until_complete(
|
self._client.post(f"/v1/table/{name}/describe/")
|
||||||
self._client.post(f"/v1/table/{name}/describe/")
|
|
||||||
)
|
|
||||||
except LanceDBClientError as err:
|
except LanceDBClientError as err:
|
||||||
if str(err).startswith("Not found"):
|
if str(err).startswith("Not found"):
|
||||||
logging.error(
|
logging.error(
|
||||||
@@ -248,14 +241,13 @@ class RemoteDBConnection(DBConnection):
|
|||||||
data = to_ipc_binary(data)
|
data = to_ipc_binary(data)
|
||||||
request_id = uuid.uuid4().hex
|
request_id = uuid.uuid4().hex
|
||||||
|
|
||||||
self._loop.run_until_complete(
|
self._client.post(
|
||||||
self._client.post(
|
f"/v1/table/{name}/create/",
|
||||||
f"/v1/table/{name}/create/",
|
data=data,
|
||||||
data=data,
|
request_id=request_id,
|
||||||
request_id=request_id,
|
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return RemoteTable(self, name)
|
return RemoteTable(self, name)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@@ -267,13 +259,11 @@ class RemoteDBConnection(DBConnection):
|
|||||||
name: str
|
name: str
|
||||||
The name of the table.
|
The name of the table.
|
||||||
"""
|
"""
|
||||||
self._loop.run_until_complete(
|
|
||||||
self._client.post(
|
self._client.post(
|
||||||
f"/v1/table/{name}/drop/",
|
f"/v1/table/{name}/drop/",
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Close the connection to the database."""
|
"""Close the connection to the database."""
|
||||||
self._loop.close()
|
self._client.close()
|
||||||
await self._client.close()
|
|
||||||
|
|||||||
@@ -43,18 +43,14 @@ class RemoteTable(Table):
|
|||||||
of this Table
|
of this Table
|
||||||
|
|
||||||
"""
|
"""
|
||||||
resp = self._conn._loop.run_until_complete(
|
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
|
||||||
)
|
|
||||||
schema = json_to_schema(resp["schema"])
|
schema = json_to_schema(resp["schema"])
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def version(self) -> int:
|
def version(self) -> int:
|
||||||
"""Get the current version of the table"""
|
"""Get the current version of the table"""
|
||||||
resp = self._conn._loop.run_until_complete(
|
resp = self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
|
||||||
)
|
|
||||||
return resp["version"]
|
return resp["version"]
|
||||||
|
|
||||||
def to_arrow(self) -> pa.Table:
|
def to_arrow(self) -> pa.Table:
|
||||||
@@ -116,9 +112,10 @@ class RemoteTable(Table):
|
|||||||
"metric_type": metric,
|
"metric_type": metric,
|
||||||
"index_cache_size": index_cache_size,
|
"index_cache_size": index_cache_size,
|
||||||
}
|
}
|
||||||
resp = self._conn._loop.run_until_complete(
|
resp = self._conn._client.post(
|
||||||
self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data)
|
f"/v1/table/{self._name}/create_index/", data=data
|
||||||
)
|
)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
@@ -161,13 +158,11 @@ class RemoteTable(Table):
|
|||||||
|
|
||||||
request_id = uuid.uuid4().hex
|
request_id = uuid.uuid4().hex
|
||||||
|
|
||||||
self._conn._loop.run_until_complete(
|
self._conn._client.post(
|
||||||
self._conn._client.post(
|
f"/v1/table/{self._name}/insert/",
|
||||||
f"/v1/table/{self._name}/insert/",
|
data=payload,
|
||||||
data=payload,
|
params={"request_id": request_id, "mode": mode},
|
||||||
params={"request_id": request_id, "mode": mode},
|
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
@@ -233,19 +228,19 @@ class RemoteTable(Table):
|
|||||||
and len(query.vector) > 0
|
and len(query.vector) > 0
|
||||||
and not isinstance(query.vector[0], float)
|
and not isinstance(query.vector[0], float)
|
||||||
):
|
):
|
||||||
futures = []
|
results = []
|
||||||
for v in query.vector:
|
for v in query.vector:
|
||||||
v = list(v)
|
v = list(v)
|
||||||
q = query.copy()
|
q = query.copy()
|
||||||
q.vector = v
|
q.vector = v
|
||||||
futures.append(self._conn._client.query(self._name, q))
|
results.append(self._conn._client.query(self._name, q))
|
||||||
result = self._conn._loop.run_until_complete(asyncio.gather(*futures))
|
|
||||||
return pa.concat_tables(
|
return pa.concat_tables(
|
||||||
[add_index(r.to_arrow(), i) for i, r in enumerate(result)]
|
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = self._conn._client.query(self._name, query)
|
result = self._conn._client.query(self._name, query)
|
||||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
return result.to_arrow()
|
||||||
|
|
||||||
def delete(self, predicate: str):
|
def delete(self, predicate: str):
|
||||||
"""Delete rows from the table.
|
"""Delete rows from the table.
|
||||||
@@ -294,9 +289,7 @@ class RemoteTable(Table):
|
|||||||
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP
|
||||||
"""
|
"""
|
||||||
payload = {"predicate": predicate}
|
payload = {"predicate": predicate}
|
||||||
self._conn._loop.run_until_complete(
|
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||||
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@@ -356,9 +349,7 @@ class RemoteTable(Table):
|
|||||||
updates = [[k, v] for k, v in values_sql.items()]
|
updates = [[k, v] for k, v in values_sql.items()]
|
||||||
|
|
||||||
payload = {"predicate": where, "updates": updates}
|
payload = {"predicate": where, "updates": updates}
|
||||||
self._conn._loop.run_until_complete(
|
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ dependencies = [
|
|||||||
"ratelimiter~=1.0",
|
"ratelimiter~=1.0",
|
||||||
"retry>=0.9.2",
|
"retry>=0.9.2",
|
||||||
"tqdm>=4.27.0",
|
"tqdm>=4.27.0",
|
||||||
"aiohttp",
|
|
||||||
"pydantic>=1.10",
|
"pydantic>=1.10",
|
||||||
"attrs>=21.3.0",
|
"attrs>=21.3.0",
|
||||||
"semver>=3.0",
|
"semver>=3.0",
|
||||||
@@ -46,7 +45,7 @@ classifiers = [
|
|||||||
repository = "https://github.com/lancedb/lancedb"
|
repository = "https://github.com/lancedb/lancedb"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"]
|
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz"]
|
||||||
dev = ["ruff", "pre-commit", "black"]
|
dev = ["ruff", "pre-commit", "black"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
clip = ["torch", "pillow", "open-clip"]
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
|
|||||||
@@ -18,15 +18,15 @@ from lancedb.remote.client import VectorQuery, VectorQueryResult
|
|||||||
|
|
||||||
|
|
||||||
class FakeLanceDBClient:
|
class FakeLanceDBClient:
|
||||||
async def close(self):
|
def close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||||
assert table_name == "test"
|
assert table_name == "test"
|
||||||
t = pa.schema([]).empty_table()
|
t = pa.schema([]).empty_table()
|
||||||
return VectorQueryResult(t)
|
return VectorQueryResult(t)
|
||||||
|
|
||||||
async def post(self, path: str):
|
def post(self, path: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user