diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 30b68e93..4f767d5c 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -15,7 +15,8 @@ import functools from typing import Any, Callable, Dict, Iterable, Optional, Union -import aiohttp +import requests +import urllib.parse import attrs import pyarrow as pa from pydantic import BaseModel @@ -37,8 +38,8 @@ def _check_not_closed(f): return wrapped -async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table: - resp_body = await resp.read() +def _read_ipc(resp: requests.Response) -> pa.Table: + resp_body = resp.raw.read() with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader: return reader.read_all() @@ -53,15 +54,24 @@ class RestfulLanceDBClient: closed: bool = attrs.field(default=False, init=False) @functools.cached_property - def session(self) -> aiohttp.ClientSession: - url = ( + def session(self) -> requests.Session: + session = requests.session() + session.stream = True + + return session + + @functools.cached_property + def url(self) -> str: + return ( self.host_override or f"https://{self.db_name}.{self.region}.api.lancedb.com" ) - return aiohttp.ClientSession(url) - async def close(self): - await self.session.close() + def _get_request_url(self, uri: str) -> str: + return urllib.parse.urljoin(self.url, uri) + + def close(self): + self.session.close() self.closed = True @functools.cached_property @@ -75,39 +85,25 @@ class RestfulLanceDBClient: headers["x-lancedb-database"] = self.db_name return headers - @staticmethod - async def _check_status(resp: aiohttp.ClientResponse): - if resp.status == 404: - raise LanceDBClientError(f"Not found: {await resp.text()}") - elif 400 <= resp.status < 500: - raise LanceDBClientError( - f"Bad Request: {resp.status}, error: {await resp.text()}" - ) - elif 500 <= resp.status < 600: - raise LanceDBClientError( - f"Internal Server Error: {resp.status}, error: {await resp.text()}" - ) - elif resp.status != 200: - raise LanceDBClientError( - f"Unknown Error: {resp.status}, error: {await resp.text()}" - ) - @_check_not_closed - async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None): + def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None): """Send a GET request and returns the deserialized response payload.""" if isinstance(params, BaseModel): params: Dict[str, Any] = params.dict(exclude_none=True) - async with self.session.get( - uri, + + resp = self.session.get( + self._get_request_url(uri), params=params, headers=self.headers, - timeout=aiohttp.ClientTimeout(total=30), - ) as resp: - await self._check_status(resp) - return await resp.json() + # 5s connect timeout, 30s read timeout + timeout=(5.0, 30.0), + ) + + resp.raise_for_status() + return resp.json() @_check_not_closed - async def post( + def post( self, uri: str, data: Optional[Union[Dict[str, Any], BaseModel, bytes]] = None, @@ -139,31 +135,31 @@ class RestfulLanceDBClient: headers["content-type"] = content_type if request_id is not None: headers["x-request-id"] = request_id - async with self.session.post( - uri, - headers=headers, + + resp = self.session.post( + self._get_request_url(uri), params=params, - timeout=aiohttp.ClientTimeout(total=30), + headers=self.headers, + # 5s connect timeout, 30s read timeout + timeout=(5.0, 30.0), **req_kwargs, - ) as resp: - resp: aiohttp.ClientResponse = resp - await self._check_status(resp) - return await deserialize(resp) + ) + resp.raise_for_status() + + return deserialize(resp) @_check_not_closed - async def list_tables( + def list_tables( self, limit: int, page_token: Optional[str] = None ) -> Iterable[str]: """List all tables in the database.""" if page_token is None: page_token = "" - json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token}) + json = self.get("/v1/table/", {"limit": limit, "page_token": page_token}) return json["tables"] @_check_not_closed - async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: + def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: """Query a table.""" - tbl = await self.post( - f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc - ) + tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc) return VectorQueryResult(tbl) diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py index 0c8be4ca..7b596943 100644 --- a/python/lancedb/remote/db.py +++ b/python/lancedb/remote/db.py @@ -50,10 +50,6 @@ class RemoteDBConnection(DBConnection): self._client = RestfulLanceDBClient( self.db_name, region, api_key, host_override ) - try: - self._loop = asyncio.get_running_loop() - except RuntimeError: - self._loop = asyncio.get_event_loop() def __repr__(self) -> str: return f"RemoteConnect(name={self.db_name})" @@ -76,15 +72,13 @@ class RemoteDBConnection(DBConnection): An iterator of table names. """ while True: - result = self._loop.run_until_complete( - self._client.list_tables(limit, page_token) - ) - if len(result) > 0: - page_token = result[len(result) - 1] - else: - break + result = self._client.list_tables(limit, page_token) for item in result: yield item + if len(result) < limit: + break + else: + page_token = result[len(result) - 1] @override def open_table(self, name: str) -> Table: @@ -103,9 +97,7 @@ class RemoteDBConnection(DBConnection): # check if table exists try: - self._loop.run_until_complete( - self._client.post(f"/v1/table/{name}/describe/") - ) + self._client.post(f"/v1/table/{name}/describe/") except LanceDBClientError as err: if str(err).startswith("Not found"): logging.error( @@ -248,13 +240,11 @@ class RemoteDBConnection(DBConnection): data = to_ipc_binary(data) request_id = uuid.uuid4().hex - self._loop.run_until_complete( - self._client.post( - f"/v1/table/{name}/create/", - data=data, - request_id=request_id, - content_type=ARROW_STREAM_CONTENT_TYPE, - ) + self._client.post( + f"/v1/table/{name}/create/", + data=data, + request_id=request_id, + content_type=ARROW_STREAM_CONTENT_TYPE, ) return RemoteTable(self, name) @@ -267,13 +257,10 @@ class RemoteDBConnection(DBConnection): name: str The name of the table. """ - self._loop.run_until_complete( - self._client.post( - f"/v1/table/{name}/drop/", - ) + self._client.post( + f"/v1/table/{name}/drop/", ) async def close(self): """Close the connection to the database.""" - self._loop.close() - await self._client.close() + self._client.close() diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py index e09011a7..26c0bdbe 100644 --- a/python/lancedb/remote/table.py +++ b/python/lancedb/remote/table.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import uuid from functools import cached_property from typing import Dict, Optional, Union @@ -43,18 +42,14 @@ class RemoteTable(Table): of this Table """ - resp = self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/describe/") - ) + resp = self._conn._client.post(f"/v1/table/{self._name}/describe/") schema = json_to_schema(resp["schema"]) return schema @property def version(self) -> int: """Get the current version of the table""" - resp = self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/describe/") - ) + resp = self._conn._client.post(f"/v1/table/{self._name}/describe/") return resp["version"] def to_arrow(self) -> pa.Table: @@ -116,8 +111,8 @@ class RemoteTable(Table): "metric_type": metric, "index_cache_size": index_cache_size, } - resp = self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data) + resp = self._conn._client.post( + f"/v1/table/{self._name}/create_index/", data=data ) return resp @@ -161,13 +156,11 @@ class RemoteTable(Table): request_id = uuid.uuid4().hex - self._conn._loop.run_until_complete( - self._conn._client.post( - f"/v1/table/{self._name}/insert/", - data=payload, - params={"request_id": request_id, "mode": mode}, - content_type=ARROW_STREAM_CONTENT_TYPE, - ) + self._conn._client.post( + f"/v1/table/{self._name}/insert/", + data=payload, + params={"request_id": request_id, "mode": mode}, + content_type=ARROW_STREAM_CONTENT_TYPE, ) def search( @@ -233,19 +226,17 @@ class RemoteTable(Table): and len(query.vector) > 0 and not isinstance(query.vector[0], float) ): - futures = [] + result = [] for v in query.vector: v = list(v) q = query.copy() q.vector = v - futures.append(self._conn._client.query(self._name, q)) - result = self._conn._loop.run_until_complete(asyncio.gather(*futures)) + result.append(self._conn._client.query(self._name, q)) return pa.concat_tables( [add_index(r.to_arrow(), i) for i, r in enumerate(result)] ) else: - result = self._conn._client.query(self._name, query) - return self._conn._loop.run_until_complete(result).to_arrow() + return self._conn._client.query(self._name, query).to_arrow() def delete(self, predicate: str): """Delete rows from the table. @@ -294,9 +285,7 @@ class RemoteTable(Table): 0 2 [3.0, 4.0] 85.0 # doctest: +SKIP """ payload = {"predicate": predicate} - self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload) - ) + self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload) def update( self, @@ -356,9 +345,7 @@ class RemoteTable(Table): updates = [[k, v] for k, v in values_sql.items()] payload = {"predicate": where, "updates": updates} - self._conn._loop.run_until_complete( - self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload) - ) + self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload) def add_index(tbl: pa.Table, i: int) -> pa.Table: diff --git a/python/pyproject.toml b/python/pyproject.toml index 179bfe76..9028633a 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -7,7 +7,7 @@ dependencies = [ "ratelimiter~=1.0", "retry>=0.9.2", "tqdm>=4.27.0", - "aiohttp", + "requests>=2.31,<3", "pydantic>=1.10", "attrs>=21.3.0", "semver>=3.0", diff --git a/python/tests/test_e2e_remote_db.py b/python/tests/test_e2e_remote_db.py deleted file mode 100644 index e9e69c48..00000000 --- a/python/tests/test_e2e_remote_db.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pytest - -from lancedb import LanceDBConnection - -# TODO: setup integ test mark and script - - -@pytest.mark.skip(reason="Need to set up a local server") -def test_against_local_server(): - conn = LanceDBConnection("lancedb+http://localhost:10024") - table = conn.open_table("sift1m_ivf1024_pq16") - df = table.search(np.random.rand(128)).to_pandas() - assert len(df) == 10 diff --git a/python/tests/test_remote_client.py b/python/tests/test_remote_client.py deleted file mode 100644 index 73ebf153..00000000 --- a/python/tests/test_remote_client.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import attrs -import numpy as np -import pandas as pd -import pyarrow as pa -import pytest -from aiohttp import web - -from lancedb.remote.client import RestfulLanceDBClient, VectorQuery - - -@attrs.define -class MockLanceDBServer: - runner: web.AppRunner = attrs.field(init=False) - site: web.TCPSite = attrs.field(init=False) - - async def query_handler(self, request: web.Request) -> web.Response: - table_name = request.match_info["table_name"] - assert table_name == "test_table" - - await request.json() - # TODO: do some matching - - vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector") - ids = pd.Series(range(10), name="id") - df = pd.DataFrame([vecs, ids]).T - - batch = pa.RecordBatch.from_pandas( - df, - schema=pa.schema( - [ - pa.field("vector", pa.list_(pa.float32(), 128)), - pa.field("id", pa.int64()), - ] - ), - ) - - sink = pa.BufferOutputStream() - with pa.ipc.new_file(sink, batch.schema) as writer: - writer.write_batch(batch) - - return web.Response(body=sink.getvalue().to_pybytes()) - - async def setup(self): - app = web.Application() - app.add_routes([web.post("/table/{table_name}", self.query_handler)]) - self.runner = web.AppRunner(app) - await self.runner.setup() - self.site = web.TCPSite(self.runner, "localhost", 8111) - - async def start(self): - await self.site.start() - - async def stop(self): - await self.runner.cleanup() - - -@pytest.mark.skip(reason="flaky somehow, fix later") -@pytest.mark.asyncio -async def test_e2e_with_mock_server(): - mock_server = MockLanceDBServer() - await mock_server.setup() - await mock_server.start() - - try: - client = RestfulLanceDBClient("lancedb+http://localhost:8111") - df = ( - await client.query( - "test_table", - VectorQuery( - vector=np.random.rand(128).tolist(), - k=10, - _metric="L2", - columns=["id", "vector"], - ), - ) - ).to_pandas() - - assert "vector" in df.columns - assert "id" in df.columns - finally: - # make sure we don't leak resources - await mock_server.stop()