diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 30b68e93..6f9bf292 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -13,9 +13,10 @@ import functools -from typing import Any, Callable, Dict, Iterable, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from urllib.parse import urljoin -import aiohttp +import requests import attrs import pyarrow as pa from pydantic import BaseModel @@ -37,8 +38,8 @@ def _check_not_closed(f): return wrapped -async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table: - resp_body = await resp.read() +def _read_ipc(resp: requests.Response) -> pa.Table: + resp_body = resp.content with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader: return reader.read_all() @@ -53,15 +54,18 @@ class RestfulLanceDBClient: closed: bool = attrs.field(default=False, init=False) @functools.cached_property - def session(self) -> aiohttp.ClientSession: - url = ( + def session(self) -> requests.Session: + return requests.Session() + + @property + def url(self) -> str: + return ( self.host_override or f"https://{self.db_name}.{self.region}.api.lancedb.com" ) - return aiohttp.ClientSession(url) - async def close(self): - await self.session.close() + def close(self): + self.session.close() self.closed = True @functools.cached_property @@ -76,38 +80,38 @@ class RestfulLanceDBClient: return headers @staticmethod - async def _check_status(resp: aiohttp.ClientResponse): - if resp.status == 404: - raise LanceDBClientError(f"Not found: {await resp.text()}") - elif 400 <= resp.status < 500: + def _check_status(resp: requests.Response): + if resp.status_code == 404: + raise LanceDBClientError(f"Not found: {resp.text}") + elif 400 <= resp.status_code < 500: raise LanceDBClientError( - f"Bad Request: {resp.status}, error: {await resp.text()}" + f"Bad Request: {resp.status_code}, error: {resp.text}" ) - elif 500 <= resp.status < 600: + elif 500 <= resp.status_code < 600: raise LanceDBClientError( - f"Internal Server Error: {resp.status}, error: {await resp.text()}" + f"Internal Server Error: {resp.status_code}, error: {resp.text}" ) - elif resp.status != 200: + elif resp.status_code != 200: raise LanceDBClientError( - f"Unknown Error: {resp.status}, error: {await resp.text()}" + f"Unknown Error: {resp.status_code}, error: {resp.text}" ) @_check_not_closed - async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None): + def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None): """Send a GET request and returns the deserialized response payload.""" if isinstance(params, BaseModel): params: Dict[str, Any] = params.dict(exclude_none=True) - async with self.session.get( - uri, + with self.session.get( + urljoin(self.url, uri), params=params, headers=self.headers, - timeout=aiohttp.ClientTimeout(total=30), + timeout=(5.0, 30.0), ) as resp: - await self._check_status(resp) - return await resp.json() + self._check_status(resp) + return resp.json() @_check_not_closed - async def post( + def post( self, uri: str, data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None, @@ -139,31 +143,26 @@ class RestfulLanceDBClient: headers["content-type"] = content_type if request_id is not None: headers["x-request-id"] = request_id - async with self.session.post( - uri, + with self.session.post( + urljoin(self.url, uri), headers=headers, params=params, - timeout=aiohttp.ClientTimeout(total=30), + timeout=(5.0, 30.0), **req_kwargs, ) as resp: - resp: aiohttp.ClientResponse = resp - await self._check_status(resp) - return await deserialize(resp) + self._check_status(resp) + return deserialize(resp) @_check_not_closed - async def list_tables( - self, limit: int, page_token: Optional[str] = None - ) -> Iterable[str]: + def list_tables(self, limit: int, page_token: Optional[str] = None) -> List[str]: """List all tables in the database.""" if page_token is None: page_token = "" - json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token}) + json = self.get("/v1/table/", {"limit": limit, "page_token": page_token}) return json["tables"] @_check_not_closed - async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: + def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: """Query a table.""" - tbl = await self.post( - f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc - ) + tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc) return VectorQueryResult(tbl) diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 0c8be4ca..337406db 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection): self._client = RestfulLanceDBClient( self.db_name, region, api_key, host_override ) - try: - self._loop = asyncio.get_running_loop() - except RuntimeError: - self._loop = asyncio.get_event_loop() def __repr__(self) -> str: return f"RemoteConnect(name={self.db_name})" @@ -76,9 +72,8 @@ class RemoteDBConnection(DBConnection): An iterator of table names. """ while True: - result = self._loop.run_until_complete( - self._client.list_tables(limit, page_token) - ) + result = self._client.list_tables(limit, page_token) + if len(result) > 0: page_token = result[len(result) - 1] else: @@ -103,9 +98,7 @@ class RemoteDBConnection(DBConnection): # check if table exists try: - self._loop.run_until_complete( - self._client.post(f"/v1/table/{name}/describe/") - ) + self._client.post(f"/v1/table/{name}/describe/") except LanceDBClientError as err: if str(err).startswith("Not found"): logging.error( @@ -248,14 +241,13 @@ class RemoteDBConnection(DBConnection): data = to_ipc_binary(data) request_id = uuid.uuid4().hex - self._loop.run_until_complete( - self._client.post( - f"/v1/table/{name}/create/", - data=data, - request_id=request_id, - content_type=ARROW_STREAM_CONTENT_TYPE, - ) + self._client.post( + f"/v1/table/{name}/create/", + data=data, + request_id=request_id, + content_type=ARROW_STREAM_CONTENT_TYPE, ) + return RemoteTable(self, name) @override @@ -267,13 +259,11 @@ class RemoteDBConnection(DBConnection): name: str The name of the table. """ - self._loop.run_until_complete( - self._client.post( - f"/v1/table/{name}/drop/", - ) + + self._client.post( + f"/v1/table/{name}/drop/", ) async def close(self): """Close the connection to the database.""" - self._loop.close() - await self._client.close() + self._client.close() diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index e09011a7..63572ebb 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -43,18 +43,14 @@ class RemoteTable(Table): of this Table """ - resp = self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/describe/") - ) + resp = self._conn._client.post(f"/v1/table/{self._name}/describe/") schema = json_to_schema(resp["schema"]) return schema @property def version(self) -> int: """Get the current version of the table""" - resp = self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/describe/") - ) + resp = self._conn._client.post(f"/v1/table/{self._name}/describe/") return resp["version"] def to_arrow(self) -> pa.Table: @@ -116,9 +112,10 @@ class RemoteTable(Table): "metric_type": metric, "index_cache_size": index_cache_size, } - resp = self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data) + resp = self._conn._client.post( + f"/v1/table/{self._name}/create_index/", data=data ) + return resp def add( @@ -161,13 +158,11 @@ class RemoteTable(Table): request_id = uuid.uuid4().hex - self._conn._loop.run_until_complete( - self._conn._client.post( - f"/v1/table/{self._name}/insert/", - data=payload, - params={"request_id": request_id, "mode": mode}, - content_type=ARROW_STREAM_CONTENT_TYPE, - ) + self._conn._client.post( + f"/v1/table/{self._name}/insert/", + data=payload, + params={"request_id": request_id, "mode": mode}, + content_type=ARROW_STREAM_CONTENT_TYPE, ) def search( @@ -233,19 +228,19 @@ class RemoteTable(Table): and len(query.vector) > 0 and not isinstance(query.vector[0], float) ): - futures = [] + results = [] for v in query.vector: v = list(v) q = query.copy() q.vector = v - futures.append(self._conn._client.query(self._name, q)) - result = self._conn._loop.run_until_complete(asyncio.gather(*futures)) + results.append(self._conn._client.query(self._name, q)) + return pa.concat_tables( - [add_index(r.to_arrow(), i) for i, r in enumerate(result)] + [add_index(r.to_arrow(), i) for i, r in enumerate(results)] ) else: result = self._conn._client.query(self._name, query) - return self._conn._loop.run_until_complete(result).to_arrow() + return result.to_arrow() def delete(self, predicate: str): """Delete rows from the table. @@ -294,9 +289,7 @@ class RemoteTable(Table): 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP """ payload = {"predicate": predicate} - self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload) - ) + self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload) def update( self, @@ -356,9 +349,7 @@ class RemoteTable(Table): updates = [[k, v] for k, v in values_sql.items()] payload = {"predicate": where, "updates": updates} - self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload) - ) + self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload) def add_index(tbl: pa.Table, i: int) -> pa.Table: diff --git a/python/pyproject.toml b/python/pyproject.toml index 179bfe76..9b90f698 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -7,7 +7,6 @@ dependencies = [ "ratelimiter~=1.0", "retry>=0.9.2", "tqdm>=4.27.0", - "aiohttp", "pydantic>=1.10", "attrs>=21.3.0", "semver>=3.0", @@ -49,7 +48,7 @@ classifiers = [ repository = "https://github.com/lancedb/lancedb" [project.optional-dependencies] -tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests", "duckdb", "pytz"] +tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz"] dev = ["ruff", "pre-commit", "black"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] diff --git a/python/tests/test_remote_db.py b/python/tests/test_remote_db.py index 00ee8c43..d4928c6a 100644 --- a/python/tests/test_remote_db.py +++ b/python/tests/test_remote_db.py @@ -18,15 +18,15 @@ from lancedb.remote.client import VectorQuery, VectorQueryResult class FakeLanceDBClient: - async def close(self): + def close(self): pass - async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: + def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: assert table_name == "test" t = pa.schema([]).empty_table() return VectorQueryResult(t) - async def post(self, path: str): + def post(self, path: str): pass