mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-09 13:22:58 +00:00
port remote connection client into lancedb (#194)
* to_df() is now async, added `to_df_blocking` to convenience * add remote lancedb client to public lancedb * make lancedb connection class understand url scheme `lancedb+<connection_type>://<host>:<port>`.
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -56,7 +57,16 @@ class LanceDBConnection:
|
||||
"""
|
||||
|
||||
def __init__(self, uri: URI):
|
||||
is_local = isinstance(uri, Path) or get_uri_scheme(uri) == "file"
|
||||
if not isinstance(uri, Path):
|
||||
scheme = get_uri_scheme(uri)
|
||||
is_local = isinstance(uri, Path) or scheme == "file"
|
||||
# managed lancedb remote uses schema like lancedb+[http|grpc|...]://
|
||||
self._is_managed_remote = not is_local and scheme.startswith("lancedb")
|
||||
if self._is_managed_remote:
|
||||
if len(scheme.split("+")) != 2:
|
||||
raise ValueError(
|
||||
f"Invalid LanceDB URI: {uri}, expected uri to have scheme like lancedb+<flavor>://..."
|
||||
)
|
||||
if is_local:
|
||||
if isinstance(uri, str):
|
||||
uri = Path(uri)
|
||||
@@ -64,10 +74,49 @@ class LanceDBConnection:
|
||||
Path(uri).mkdir(parents=True, exist_ok=True)
|
||||
self._uri = str(uri)
|
||||
|
||||
self._entered = False
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
return self._uri
|
||||
|
||||
@functools.cached_property
|
||||
def is_managed_remote(self) -> bool:
|
||||
return self._is_managed_remote
|
||||
|
||||
@functools.cached_property
|
||||
def remote_flavor(self) -> str:
|
||||
if not self.is_managed_remote:
|
||||
raise ValueError(
|
||||
"Not a managed remote LanceDB, there should be no server flavor"
|
||||
)
|
||||
return get_uri_scheme(self.uri).split("+")[1]
|
||||
|
||||
@functools.cached_property
|
||||
def _client(self) -> "lancedb.remote.LanceDBClient":
|
||||
if not self.is_managed_remote:
|
||||
raise ValueError("Not a managed remote LanceDB, there should be no client")
|
||||
|
||||
# don't import unless we are really using remote
|
||||
from lancedb.remote.client import RestfulLanceDBClient
|
||||
|
||||
if self.remote_flavor == "http":
|
||||
return RestfulLanceDBClient(self._uri)
|
||||
|
||||
raise ValueError("Unsupported remote flavor: " + self.remote_flavor)
|
||||
|
||||
async def close(self):
|
||||
if self._entered:
|
||||
raise ValueError("Cannot re-enter the same LanceDBConnection twice")
|
||||
self._entered = True
|
||||
await self._client.close()
|
||||
|
||||
async def __aenter__(self) -> LanceDBConnection:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
await self.close()
|
||||
|
||||
def table_names(self) -> list[str]:
|
||||
"""Get the names of all tables in the database.
|
||||
|
||||
|
||||
@@ -11,11 +11,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
from typing import Literal
|
||||
from typing import Awaitable, Literal
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import asyncio
|
||||
|
||||
from .common import VECTOR_COLUMN_NAME
|
||||
|
||||
@@ -168,8 +169,28 @@ class LanceQueryBuilder:
|
||||
and also the "score" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as a arrow Table.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "score" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if self._table._conn.is_managed_remote:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.get_event_loop()
|
||||
result = self._table._conn._client.query(
|
||||
self._table.name, self.to_remote_query()
|
||||
)
|
||||
return loop.run_until_complete(result).to_arrow()
|
||||
|
||||
ds = self._table.to_lance()
|
||||
tbl = ds.to_table(
|
||||
return ds.to_table(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
nearest={
|
||||
@@ -181,7 +202,20 @@ class LanceQueryBuilder:
|
||||
"refine_factor": self._refine_factor,
|
||||
},
|
||||
)
|
||||
return tbl.to_pandas()
|
||||
|
||||
def to_remote_query(self) -> "VectorQuery":
|
||||
# don't import unless we are connecting to remote
|
||||
from lancedb.remote.client import VectorQuery
|
||||
|
||||
return VectorQuery(
|
||||
vector=self._query.tolist(),
|
||||
filter=self._where,
|
||||
k=self._limit,
|
||||
_metric=self._metric,
|
||||
columns=self._columns,
|
||||
nprobes=self._nprobes,
|
||||
refine_factor=self._refine_factor,
|
||||
)
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
61
python/lancedb/remote/__init__.py
Normal file
61
python/lancedb/remote/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import List, Optional
|
||||
import attr
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
|
||||
|
||||
|
||||
class VectorQuery(BaseModel):
|
||||
# vector to search for
|
||||
vector: List[float]
|
||||
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
|
||||
# top k results to return
|
||||
k: int
|
||||
|
||||
# # metrics
|
||||
_metric: str = "L2"
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[List[str]] = None
|
||||
|
||||
# optional query parameters for tuning the results,
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
nprobes: int = 10
|
||||
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class VectorQueryResult:
|
||||
# for now the response is directly seralized into a pandas dataframe
|
||||
tbl: pa.Table
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
return self.tbl
|
||||
|
||||
|
||||
class LanceDBClient(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query the LanceDB server for the given table and query."""
|
||||
pass
|
||||
79
python/lancedb/remote/client.py
Normal file
79
python/lancedb/remote/client.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import functools
|
||||
|
||||
import aiohttp
|
||||
import attr
|
||||
import pyarrow as pa
|
||||
import urllib.parse
|
||||
|
||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||
from lancedb.remote.errors import LanceDBClientError
|
||||
|
||||
|
||||
def _check_not_closed(f):
|
||||
@functools.wraps(f)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if self.closed:
|
||||
raise ValueError("Connection is closed")
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
@attr.define(slots=False)
|
||||
class RestfulLanceDBClient:
|
||||
url: str
|
||||
closed: bool = attr.field(default=False, init=False)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
parsed = urllib.parse.urlparse(self.url)
|
||||
scheme = parsed.scheme
|
||||
if not scheme.startswith("lancedb"):
|
||||
raise ValueError(
|
||||
f"Invalid scheme: {scheme}, must be like lancedb+<flavor>://"
|
||||
)
|
||||
flavor = scheme.split("+")[1]
|
||||
url = f"{flavor}://{parsed.hostname}:{parsed.port}"
|
||||
return aiohttp.ClientSession(url)
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
self.closed = True
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
async with self.session.post(
|
||||
f"/table/{table_name}/", json=query.dict(exclude_none=True)
|
||||
) as resp:
|
||||
resp: aiohttp.ClientResponse = resp
|
||||
if 400 <= resp.status < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
if 500 <= resp.status < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
if resp.status != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
|
||||
resp_body = await resp.read()
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
tbl = reader.read_all()
|
||||
return VectorQueryResult(tbl)
|
||||
16
python/lancedb/remote/errors.py
Normal file
16
python/lancedb/remote/errors.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class LanceDBClientError(RuntimeError):
|
||||
pass
|
||||
@@ -14,7 +14,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from functools import cached_property
|
||||
from typing import List, Union
|
||||
|
||||
@@ -27,7 +26,6 @@ from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .query import LanceFtsQueryBuilder, LanceQueryBuilder
|
||||
from .util import get_uri_scheme
|
||||
|
||||
|
||||
def _sanitize_data(data, schema):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.1.8"
|
||||
dependencies = ["pylance>=0.4.20", "ratelimiter", "retry", "tqdm"]
|
||||
dependencies = ["pylance>=0.4.20", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic", "attr"]
|
||||
description = "lancedb"
|
||||
authors = [
|
||||
{ name = "LanceDB Devs", email = "dev@lancedb.com" },
|
||||
@@ -37,7 +37,7 @@ repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = [
|
||||
"pytest", "pytest-mock", "doctest"
|
||||
"pytest", "pytest-mock", "doctest", "pytest-asyncio"
|
||||
]
|
||||
dev = [
|
||||
"ruff", "pre-commit", "black"
|
||||
|
||||
28
python/tests/test_e2e_remote_db.py
Normal file
28
python/tests/test_e2e_remote_db.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lancedb import LanceDBConnection
|
||||
|
||||
import numpy as np
|
||||
|
||||
import pytest
|
||||
|
||||
# TODO: setup integ test mark and script
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Need to set up a local server")
|
||||
def test_against_local_server():
|
||||
conn = LanceDBConnection("lancedb+http://localhost:10024")
|
||||
table = conn.open_table("sift1m_ivf1024_pq16")
|
||||
df = table.search(np.random.rand(128)).to_df()
|
||||
assert len(df) == 10
|
||||
@@ -19,10 +19,13 @@ import pyarrow as pa
|
||||
import pytest
|
||||
from lancedb.query import LanceQueryBuilder
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
|
||||
|
||||
class MockTable:
|
||||
def __init__(self, tmp_path):
|
||||
self.uri = tmp_path
|
||||
self._conn = LanceDBConnection("/tmp/lance/")
|
||||
|
||||
def to_lance(self):
|
||||
return lance.dataset(self.uri)
|
||||
|
||||
95
python/tests/test_remote_client.py
Normal file
95
python/tests/test_remote_client.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import attr
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
||||
|
||||
|
||||
@attr.define
|
||||
class MockLanceDBServer:
|
||||
runner: web.AppRunner = attr.field(init=False)
|
||||
site: web.TCPSite = attr.field(init=False)
|
||||
|
||||
async def query_handler(self, request: web.Request) -> web.Response:
|
||||
table_name = request.match_info["table_name"]
|
||||
assert table_name == "test_table"
|
||||
|
||||
request_json = await request.json()
|
||||
# TODO: do some matching
|
||||
|
||||
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
|
||||
ids = pd.Series(range(10), name="id")
|
||||
df = pd.DataFrame([vecs, ids]).T
|
||||
|
||||
batch = pa.RecordBatch.from_pandas(
|
||||
df,
|
||||
schema=pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 128)),
|
||||
pa.field("id", pa.int64()),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
sink = pa.BufferOutputStream()
|
||||
with pa.ipc.new_file(sink, batch.schema) as writer:
|
||||
writer.write_batch(batch)
|
||||
|
||||
return web.Response(body=sink.getvalue().to_pybytes())
|
||||
|
||||
async def setup(self):
|
||||
app = web.Application()
|
||||
app.add_routes([web.post("/table/{table_name}", self.query_handler)])
|
||||
self.runner = web.AppRunner(app)
|
||||
await self.runner.setup()
|
||||
self.site = web.TCPSite(self.runner, "localhost", 8111)
|
||||
|
||||
async def start(self):
|
||||
await self.site.start()
|
||||
|
||||
async def stop(self):
|
||||
await self.runner.cleanup()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky somehow, fix later")
|
||||
@pytest.mark.asyncio
|
||||
async def test_e2e_with_mock_server():
|
||||
mock_server = MockLanceDBServer()
|
||||
await mock_server.setup()
|
||||
await mock_server.start()
|
||||
|
||||
try:
|
||||
client = RestfulLanceDBClient("lancedb+http://localhost:8111")
|
||||
df = (
|
||||
await client.query(
|
||||
"test_table",
|
||||
VectorQuery(
|
||||
vector=np.random.rand(128).tolist(),
|
||||
k=10,
|
||||
_metric="L2",
|
||||
columns=["id", "vector"],
|
||||
),
|
||||
)
|
||||
).to_df()
|
||||
|
||||
assert "vector" in df.columns
|
||||
assert "id" in df.columns
|
||||
finally:
|
||||
# make sure we don't leak resources
|
||||
await mock_server.stop()
|
||||
36
python/tests/test_remote_db.py
Normal file
36
python/tests/test_remote_db.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
from lancedb.remote.client import VectorQuery, VectorQueryResult
|
||||
from lancedb.db import LanceDBConnection
|
||||
|
||||
|
||||
class FakeLanceDBClient:
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
assert table_name == "test"
|
||||
t = pa.schema([]).empty_table()
|
||||
return VectorQueryResult(t)
|
||||
|
||||
|
||||
def test_remote_db():
|
||||
conn = LanceDBConnection("lancedb+http://client-will-be-injected")
|
||||
setattr(conn, "_client", FakeLanceDBClient())
|
||||
|
||||
table = conn["test"]
|
||||
table.search([1.0, 2.0]).to_df()
|
||||
@@ -11,6 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
@@ -23,6 +24,10 @@ class MockDB:
|
||||
def __init__(self, uri: Path):
|
||||
self.uri = uri
|
||||
|
||||
@functools.cached_property
|
||||
def is_managed_remote(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(tmp_path) -> MockDB:
|
||||
|
||||
Reference in New Issue
Block a user