From cbb56e25ab7c221f44847aff6842f50c80530ebd Mon Sep 17 00:00:00 2001 From: Rob Meng Date: Thu, 15 Jun 2023 18:57:52 -0400 Subject: [PATCH] port remote connection client into lancedb (#194) * to_df() is now async, added `to_df_blocking` to convenience * add remote lancedb client to public lancedb * make lancedb connection class understand url scheme `lancedb+://:`. --- python/lancedb/db.py | 51 +++++++++++++++- python/lancedb/query.py | 40 ++++++++++++- python/lancedb/remote/__init__.py | 61 +++++++++++++++++++ python/lancedb/remote/client.py | 79 +++++++++++++++++++++++++ python/lancedb/remote/errors.py | 16 +++++ python/lancedb/table.py | 2 - python/pyproject.toml | 4 +- python/tests/test_e2e_remote_db.py | 28 +++++++++ python/tests/test_query.py | 3 + python/tests/test_remote_client.py | 95 ++++++++++++++++++++++++++++++ python/tests/test_remote_db.py | 36 +++++++++++ python/tests/test_table.py | 5 ++ 12 files changed, 412 insertions(+), 8 deletions(-) create mode 100644 python/lancedb/remote/__init__.py create mode 100644 python/lancedb/remote/client.py create mode 100644 python/lancedb/remote/errors.py create mode 100644 python/tests/test_e2e_remote_db.py create mode 100644 python/tests/test_remote_client.py create mode 100644 python/tests/test_remote_db.py diff --git a/python/lancedb/db.py b/python/lancedb/db.py index f48f0130..ea1e9abf 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -12,6 +12,7 @@ # limitations under the License. from __future__ import annotations +import functools import os from pathlib import Path @@ -56,7 +57,16 @@ class LanceDBConnection: """ def __init__(self, uri: URI): - is_local = isinstance(uri, Path) or get_uri_scheme(uri) == "file" + if not isinstance(uri, Path): + scheme = get_uri_scheme(uri) + is_local = isinstance(uri, Path) or scheme == "file" + # managed lancedb remote uses schema like lancedb+[http|grpc|...]:// + self._is_managed_remote = not is_local and scheme.startswith("lancedb") + if self._is_managed_remote: + if len(scheme.split("+")) != 2: + raise ValueError( + f"Invalid LanceDB URI: {uri}, expected uri to have scheme like lancedb+://..." + ) if is_local: if isinstance(uri, str): uri = Path(uri) @@ -64,10 +74,49 @@ class LanceDBConnection: Path(uri).mkdir(parents=True, exist_ok=True) self._uri = str(uri) + self._entered = False + @property def uri(self) -> str: return self._uri + @functools.cached_property + def is_managed_remote(self) -> bool: + return self._is_managed_remote + + @functools.cached_property + def remote_flavor(self) -> str: + if not self.is_managed_remote: + raise ValueError( + "Not a managed remote LanceDB, there should be no server flavor" + ) + return get_uri_scheme(self.uri).split("+")[1] + + @functools.cached_property + def _client(self) -> "lancedb.remote.LanceDBClient": + if not self.is_managed_remote: + raise ValueError("Not a managed remote LanceDB, there should be no client") + + # don't import unless we are really using remote + from lancedb.remote.client import RestfulLanceDBClient + + if self.remote_flavor == "http": + return RestfulLanceDBClient(self._uri) + + raise ValueError("Unsupported remote flavor: " + self.remote_flavor) + + async def close(self): + if self._entered: + raise ValueError("Cannot re-enter the same LanceDBConnection twice") + self._entered = True + await self._client.close() + + async def __aenter__(self) -> LanceDBConnection: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + def table_names(self) -> list[str]: """Get the names of all tables in the database. diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 2ad4b85b..10dc6bc1 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -11,11 +11,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations -from typing import Literal +from typing import Awaitable, Literal import numpy as np import pandas as pd import pyarrow as pa +import asyncio from .common import VECTOR_COLUMN_NAME @@ -168,8 +169,28 @@ class LanceQueryBuilder: and also the "score" column which is the distance between the query vector and the returned vector. """ + + return self.to_arrow().to_pandas() + + def to_arrow(self) -> pa.Table: + """ + Execute the query and return the results as a arrow Table. + In addition to the selected columns, LanceDB also returns a vector + and also the "score" column which is the distance between the query + vector and the returned vector. + """ + if self._table._conn.is_managed_remote: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.get_event_loop() + result = self._table._conn._client.query( + self._table.name, self.to_remote_query() + ) + return loop.run_until_complete(result).to_arrow() + ds = self._table.to_lance() - tbl = ds.to_table( + return ds.to_table( columns=self._columns, filter=self._where, nearest={ @@ -181,7 +202,20 @@ class LanceQueryBuilder: "refine_factor": self._refine_factor, }, ) - return tbl.to_pandas() + + def to_remote_query(self) -> "VectorQuery": + # don't import unless we are connecting to remote + from lancedb.remote.client import VectorQuery + + return VectorQuery( + vector=self._query.tolist(), + filter=self._where, + k=self._limit, + _metric=self._metric, + columns=self._columns, + nprobes=self._nprobes, + refine_factor=self._refine_factor, + ) class LanceFtsQueryBuilder(LanceQueryBuilder): diff --git a/python/lancedb/remote/__init__.py b/python/lancedb/remote/__init__.py new file mode 100644 index 00000000..8932c666 --- /dev/null +++ b/python/lancedb/remote/__init__.py @@ -0,0 +1,61 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import List, Optional +import attr +import pandas as pd +import pyarrow as pa + +from pydantic import BaseModel + +__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"] + + +class VectorQuery(BaseModel): + # vector to search for + vector: List[float] + + # sql filter to refine the query with + filter: Optional[str] = None + + # top k results to return + k: int + + # # metrics + _metric: str = "L2" + + # which columns to return in the results + columns: Optional[List[str]] = None + + # optional query parameters for tuning the results, + # e.g. `{"nprobes": "10", "refine_factor": "10"}` + nprobes: int = 10 + + refine_factor: Optional[int] = None + + +@attr.define +class VectorQueryResult: + # for now the response is directly seralized into a pandas dataframe + tbl: pa.Table + + def to_arrow(self) -> pa.Table: + return self.tbl + + +class LanceDBClient(abc.ABC): + @abc.abstractmethod + def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: + """Query the LanceDB server for the given table and query.""" + pass diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py new file mode 100644 index 00000000..3970712d --- /dev/null +++ b/python/lancedb/remote/client.py @@ -0,0 +1,79 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import functools + +import aiohttp +import attr +import pyarrow as pa +import urllib.parse + +from lancedb.remote import VectorQuery, VectorQueryResult +from lancedb.remote.errors import LanceDBClientError + + +def _check_not_closed(f): + @functools.wraps(f) + def wrapped(self, *args, **kwargs): + if self.closed: + raise ValueError("Connection is closed") + return f(self, *args, **kwargs) + + return wrapped + + +@attr.define(slots=False) +class RestfulLanceDBClient: + url: str + closed: bool = attr.field(default=False, init=False) + + @functools.cached_property + def session(self) -> aiohttp.ClientSession: + parsed = urllib.parse.urlparse(self.url) + scheme = parsed.scheme + if not scheme.startswith("lancedb"): + raise ValueError( + f"Invalid scheme: {scheme}, must be like lancedb+://" + ) + flavor = scheme.split("+")[1] + url = f"{flavor}://{parsed.hostname}:{parsed.port}" + return aiohttp.ClientSession(url) + + async def close(self): + await self.session.close() + self.closed = True + + @_check_not_closed + async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: + async with self.session.post( + f"/table/{table_name}/", json=query.dict(exclude_none=True) + ) as resp: + resp: aiohttp.ClientResponse = resp + if 400 <= resp.status < 500: + raise LanceDBClientError( + f"Bad Request: {resp.status}, error: {await resp.text()}" + ) + if 500 <= resp.status < 600: + raise LanceDBClientError( + f"Internal Server Error: {resp.status}, error: {await resp.text()}" + ) + if resp.status != 200: + raise LanceDBClientError( + f"Unknown Error: {resp.status}, error: {await resp.text()}" + ) + + resp_body = await resp.read() + with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader: + tbl = reader.read_all() + return VectorQueryResult(tbl) diff --git a/python/lancedb/remote/errors.py b/python/lancedb/remote/errors.py new file mode 100644 index 00000000..a4d290dc --- /dev/null +++ b/python/lancedb/remote/errors.py @@ -0,0 +1,16 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class LanceDBClientError(RuntimeError): + pass diff --git a/python/lancedb/table.py b/python/lancedb/table.py index af794470..fb87eb44 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -14,7 +14,6 @@ from __future__ import annotations import os -import shutil from functools import cached_property from typing import List, Union @@ -27,7 +26,6 @@ from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME from .query import LanceFtsQueryBuilder, LanceQueryBuilder -from .util import get_uri_scheme def _sanitize_data(data, schema): diff --git a/python/pyproject.toml b/python/pyproject.toml index a0a09cde..ba3a0e95 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "lancedb" version = "0.1.8" -dependencies = ["pylance>=0.4.20", "ratelimiter", "retry", "tqdm"] +dependencies = ["pylance>=0.4.20", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic", "attr"] description = "lancedb" authors = [ { name = "LanceDB Devs", email = "dev@lancedb.com" }, @@ -37,7 +37,7 @@ repository = "https://github.com/lancedb/lancedb" [project.optional-dependencies] tests = [ - "pytest", "pytest-mock", "doctest" + "pytest", "pytest-mock", "doctest", "pytest-asyncio" ] dev = [ "ruff", "pre-commit", "black" diff --git a/python/tests/test_e2e_remote_db.py b/python/tests/test_e2e_remote_db.py new file mode 100644 index 00000000..3f5344c7 --- /dev/null +++ b/python/tests/test_e2e_remote_db.py @@ -0,0 +1,28 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lancedb import LanceDBConnection + +import numpy as np + +import pytest + +# TODO: setup integ test mark and script + + +@pytest.mark.skip(reason="Need to set up a local server") +def test_against_local_server(): + conn = LanceDBConnection("lancedb+http://localhost:10024") + table = conn.open_table("sift1m_ivf1024_pq16") + df = table.search(np.random.rand(128)).to_df() + assert len(df) == 10 diff --git a/python/tests/test_query.py b/python/tests/test_query.py index 675c3772..1c50ca63 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -19,10 +19,13 @@ import pyarrow as pa import pytest from lancedb.query import LanceQueryBuilder +from lancedb.db import LanceDBConnection + class MockTable: def __init__(self, tmp_path): self.uri = tmp_path + self._conn = LanceDBConnection("/tmp/lance/") def to_lance(self): return lance.dataset(self.uri) diff --git a/python/tests/test_remote_client.py b/python/tests/test_remote_client.py new file mode 100644 index 00000000..e9fd309a --- /dev/null +++ b/python/tests/test_remote_client.py @@ -0,0 +1,95 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import attr +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +from aiohttp import web + +from lancedb.remote.client import RestfulLanceDBClient, VectorQuery + + +@attr.define +class MockLanceDBServer: + runner: web.AppRunner = attr.field(init=False) + site: web.TCPSite = attr.field(init=False) + + async def query_handler(self, request: web.Request) -> web.Response: + table_name = request.match_info["table_name"] + assert table_name == "test_table" + + request_json = await request.json() + # TODO: do some matching + + vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector") + ids = pd.Series(range(10), name="id") + df = pd.DataFrame([vecs, ids]).T + + batch = pa.RecordBatch.from_pandas( + df, + schema=pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 128)), + pa.field("id", pa.int64()), + ] + ), + ) + + sink = pa.BufferOutputStream() + with pa.ipc.new_file(sink, batch.schema) as writer: + writer.write_batch(batch) + + return web.Response(body=sink.getvalue().to_pybytes()) + + async def setup(self): + app = web.Application() + app.add_routes([web.post("/table/{table_name}", self.query_handler)]) + self.runner = web.AppRunner(app) + await self.runner.setup() + self.site = web.TCPSite(self.runner, "localhost", 8111) + + async def start(self): + await self.site.start() + + async def stop(self): + await self.runner.cleanup() + + +@pytest.mark.skip(reason="flaky somehow, fix later") +@pytest.mark.asyncio +async def test_e2e_with_mock_server(): + mock_server = MockLanceDBServer() + await mock_server.setup() + await mock_server.start() + + try: + client = RestfulLanceDBClient("lancedb+http://localhost:8111") + df = ( + await client.query( + "test_table", + VectorQuery( + vector=np.random.rand(128).tolist(), + k=10, + _metric="L2", + columns=["id", "vector"], + ), + ) + ).to_df() + + assert "vector" in df.columns + assert "id" in df.columns + finally: + # make sure we don't leak resources + await mock_server.stop() diff --git a/python/tests/test_remote_db.py b/python/tests/test_remote_db.py new file mode 100644 index 00000000..b0e3d0c9 --- /dev/null +++ b/python/tests/test_remote_db.py @@ -0,0 +1,36 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa + + +from lancedb.remote.client import VectorQuery, VectorQueryResult +from lancedb.db import LanceDBConnection + + +class FakeLanceDBClient: + async def close(self): + pass + + async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: + assert table_name == "test" + t = pa.schema([]).empty_table() + return VectorQueryResult(t) + + +def test_remote_db(): + conn = LanceDBConnection("lancedb+http://client-will-be-injected") + setattr(conn, "_client", FakeLanceDBClient()) + + table = conn["test"] + table.search([1.0, 2.0]).to_df() diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 00eb8cc7..311e9b2c 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools from pathlib import Path import pandas as pd @@ -23,6 +24,10 @@ class MockDB: def __init__(self, uri: Path): self.uri = uri + @functools.cached_property + def is_managed_remote(self) -> bool: + return False + @pytest.fixture def db(tmp_path) -> MockDB: