mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-16 16:52:57 +00:00
port remote connection client into lancedb (#194)
* to_df() is now async, added `to_df_blocking` to convenience * add remote lancedb client to public lancedb * make lancedb connection class understand url scheme `lancedb+<connection_type>://<host>:<port>`.
This commit is contained in:
61
python/lancedb/remote/__init__.py
Normal file
61
python/lancedb/remote/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import List, Optional
|
||||
import attr
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
|
||||
|
||||
|
||||
class VectorQuery(BaseModel):
|
||||
# vector to search for
|
||||
vector: List[float]
|
||||
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
|
||||
# top k results to return
|
||||
k: int
|
||||
|
||||
# # metrics
|
||||
_metric: str = "L2"
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[List[str]] = None
|
||||
|
||||
# optional query parameters for tuning the results,
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
nprobes: int = 10
|
||||
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class VectorQueryResult:
|
||||
# for now the response is directly seralized into a pandas dataframe
|
||||
tbl: pa.Table
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
return self.tbl
|
||||
|
||||
|
||||
class LanceDBClient(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
"""Query the LanceDB server for the given table and query."""
|
||||
pass
|
||||
79
python/lancedb/remote/client.py
Normal file
79
python/lancedb/remote/client.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import functools
|
||||
|
||||
import aiohttp
|
||||
import attr
|
||||
import pyarrow as pa
|
||||
import urllib.parse
|
||||
|
||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||
from lancedb.remote.errors import LanceDBClientError
|
||||
|
||||
|
||||
def _check_not_closed(f):
|
||||
@functools.wraps(f)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if self.closed:
|
||||
raise ValueError("Connection is closed")
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
@attr.define(slots=False)
|
||||
class RestfulLanceDBClient:
|
||||
url: str
|
||||
closed: bool = attr.field(default=False, init=False)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> aiohttp.ClientSession:
|
||||
parsed = urllib.parse.urlparse(self.url)
|
||||
scheme = parsed.scheme
|
||||
if not scheme.startswith("lancedb"):
|
||||
raise ValueError(
|
||||
f"Invalid scheme: {scheme}, must be like lancedb+<flavor>://"
|
||||
)
|
||||
flavor = scheme.split("+")[1]
|
||||
url = f"{flavor}://{parsed.hostname}:{parsed.port}"
|
||||
return aiohttp.ClientSession(url)
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
self.closed = True
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
async with self.session.post(
|
||||
f"/table/{table_name}/", json=query.dict(exclude_none=True)
|
||||
) as resp:
|
||||
resp: aiohttp.ClientResponse = resp
|
||||
if 400 <= resp.status < 500:
|
||||
raise LanceDBClientError(
|
||||
f"Bad Request: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
if 500 <= resp.status < 600:
|
||||
raise LanceDBClientError(
|
||||
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
if resp.status != 200:
|
||||
raise LanceDBClientError(
|
||||
f"Unknown Error: {resp.status}, error: {await resp.text()}"
|
||||
)
|
||||
|
||||
resp_body = await resp.read()
|
||||
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
|
||||
tbl = reader.read_all()
|
||||
return VectorQueryResult(tbl)
|
||||
16
python/lancedb/remote/errors.py
Normal file
16
python/lancedb/remote/errors.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class LanceDBClientError(RuntimeError):
|
||||
pass
|
||||
Reference in New Issue
Block a user