[Python] Initial support of cloud API (#260)

Support connect with remote database, and implement Search API
This commit is contained in:
Lei Xu
2023-07-07 15:41:15 -07:00
committed by GitHub
parent a5eb665b7d
commit e6c6da6104
11 changed files with 558 additions and 156 deletions

View File

@@ -13,12 +13,13 @@
import functools
import urllib.parse
from typing import Dict
import aiohttp
import attr
import pyarrow as pa
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.errors import LanceDBClientError
@@ -35,29 +36,32 @@ def _check_not_closed(f):
@attr.define(slots=False)
class RestfulLanceDBClient:
url: str
db_name: str
region: str
api_key: Credential
closed: bool = attr.field(default=False, init=False)
@functools.cached_property
def session(self) -> aiohttp.ClientSession:
parsed = urllib.parse.urlparse(self.url)
scheme = parsed.scheme
if not scheme.startswith("lancedb"):
raise ValueError(
f"Invalid scheme: {scheme}, must be like lancedb+<flavor>://"
)
flavor = scheme.split("+")[1]
url = f"{flavor}://{parsed.hostname}:{parsed.port}"
url = f"https://{self.db_name}.{self.region}.api.lancedb.com"
return aiohttp.ClientSession(url)
async def close(self):
await self.session.close()
self.closed = True
@functools.cached_property
def headers(self) -> Dict[str, str]:
return {
"x-api-key": self.api_key,
}
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
async with self.session.post(
f"/table/{table_name}/", json=query.dict(exclude_none=True)
f"/1/table/{table_name}/",
json=query.dict(exclude_none=True),
headers=self.headers,
) as resp:
resp: aiohttp.ClientResponse = resp
if 400 <= resp.status < 500:

View File

@@ -0,0 +1,71 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from urllib.parse import urlparse
import pyarrow as pa
from lancedb.common import DATA
from lancedb.db import DBConnection
from lancedb.table import Table
from .client import RestfulLanceDBClient
class RemoteDBConnection(DBConnection):
"""A connection to a remote LanceDB database."""
def __init__(self, db_url: str, api_key: str, region: str):
"""Connect to a remote LanceDB database."""
parsed = urlparse(db_url)
if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(self.db_name, region, api_key)
def __repr__(self) -> str:
return f"RemoveConnect(name={self.db_name})"
def table_names(self) -> List[str]:
raise NotImplementedError
def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
Returns
-------
A LanceTable object representing the table.
"""
from .table import RemoteTable
# TODO: check if table exists
return RemoteTable(self, name)
def create_table(
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> Table:
raise NotImplementedError

View File

@@ -0,0 +1,70 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Union
import pyarrow as pa
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceQueryBuilder, Query
from ..table import Query, Table
from .db import RemoteDBConnection
class RemoteTable(Table):
def __init__(self, conn: RemoteDBConnection, name: str):
self._conn = conn
self._name = name
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self.name})"
def schema(self) -> pa.Schema:
raise NotImplementedError
def to_arrow(self) -> pa.Table:
raise NotImplementedError
def create_index(
self,
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
):
raise NotImplementedError
def add(
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
raise NotImplementedError
def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
) -> LanceQueryBuilder:
return LanceQueryBuilder(self, query, vector_column)
def _execute_query(self, query: Query) -> pa.Table:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()
result = self._conn._client.query(self._name, query)
return loop.run_until_complete(result).to_arrow()