[Python] Initial support of cloud API (#260)

Support connect with remote database, and implement Search API
This commit is contained in:
Lei Xu
2023-07-07 15:41:15 -07:00
committed by GitHub
parent a5eb665b7d
commit e6c6da6104
11 changed files with 558 additions and 156 deletions

View File

@@ -10,14 +10,16 @@ pip install lancedb
::: lancedb.connect ::: lancedb.connect
::: lancedb.LanceDBConnection ::: lancedb.db.DBConnection
## Table ## Table
::: lancedb.table.LanceTable ::: lancedb.table.Table
## Querying ## Querying
::: lancedb.query.Query
::: lancedb.query.LanceQueryBuilder ::: lancedb.query.LanceQueryBuilder
::: lancedb.query.LanceFtsQueryBuilder ::: lancedb.query.LanceFtsQueryBuilder

View File

@@ -11,16 +11,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .db import URI, LanceDBConnection from typing import Optional
from .db import URI, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
def connect(uri: URI) -> LanceDBConnection: def connect(
"""Connect to a LanceDB instance at the given URI uri: URI, *, api_key: Optional[str] = None, region: str = "us-west-2"
) -> DBConnection:
"""Connect to a LanceDB database.
Parameters Parameters
---------- ----------
uri: str or Path uri: str or Path
The uri of the database. The uri of the database.
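api_key: str, optional
API key for LanceDB cloud. Required when `uri` starts with `db://`.
For any other uri, the database is opened on the file system or cloud storage.
region: str, default "us-west-2"
The LanceDB cloud region to connect to. Only used when `uri` starts with `db://`.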
Examples Examples
-------- --------
@@ -34,9 +42,17 @@ def connect(uri: URI) -> LanceDBConnection:
>>> db = lancedb.connect("s3://my-bucket/lancedb") >>> db = lancedb.connect("s3://my-bucket/lancedb")
Connect to LanceDB cloud:
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
Returns Returns
------- -------
conn : LanceDBConnection conn : DBConnection
A connection to a LanceDB database. A connection to a LanceDB database.
""" """
if isinstance(uri, str) and uri.startswith("db://"):
if api_key is None:
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
return RemoteDBConnection(uri, api_key, region)
return LanceDBConnection(uri) return LanceDBConnection(uri)

View File

@@ -23,3 +23,13 @@ URI = Union[str, Path]
# TODO support generator # TODO support generator
DATA = Union[List[dict], dict, pd.DataFrame] DATA = Union[List[dict], dict, pd.DataFrame]
VECTOR_COLUMN_NAME = "vector" VECTOR_COLUMN_NAME = "vector"
class Credential(str):
    """String subclass for secrets that renders as a fixed mask.

    Both ``str()`` and ``repr()`` return ``"********"`` instead of the
    wrapped text. Everything else is inherited from ``str`` unchanged, so
    comparison, hashing and length still operate on the real value.
    """

    def __str__(self) -> str:
        return "********"

    def __repr__(self) -> str:
        return "********"

View File

@@ -15,17 +15,161 @@ from __future__ import annotations
import functools import functools
import os import os
from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
import pyarrow as pa import pyarrow as pa
from pyarrow import fs from pyarrow import fs
from .common import DATA, URI from .common import DATA, URI
from .table import LanceTable from .table import LanceTable, Table
from .util import get_uri_location, get_uri_scheme from .util import get_uri_location, get_uri_scheme
class LanceDBConnection: class DBConnection(ABC):
"""An active LanceDB connection interface."""
@abstractmethod
def table_names(self) -> list[str]:
"""List all table names in the database."""
pass
@abstractmethod
def create_table(
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
Parameters
----------
name: str
The name of the table.
data: list, tuple, dict, pd.DataFrame; optional
The data to insert into the table.
schema: pyarrow.Schema; optional
The schema of the table.
mode: str; default "create"
The mode to use when creating the table. Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
Note
----
The vector index won't be created by default.
To create the index, call the `create_index` method on the table.
Returns
-------
LanceTable
A reference to the newly created table.
Examples
--------
Can create with list of tuples or dictionaries:
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data)
LanceTable(my_table)
>>> db["my_table"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
child 0, item: float
lat: double
long: double
----
vector: [[[1.1,1.2],[0.2,1.8]]]
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]
You can also pass a pandas DataFrame:
>>> import pandas as pd
>>> data = pd.DataFrame({
... "vector": [[1.1, 1.2], [0.2, 1.8]],
... "lat": [45.5, 40.1],
... "long": [-122.7, -74.1]
... })
>>> db.create_table("table2", data)
LanceTable(table2)
>>> db["table2"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
child 0, item: float
lat: double
long: double
----
vector: [[[1.1,1.2],[0.2,1.8]]]
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]
Data is converted to Arrow before being written to disk. For maximum
control over how data is saved, either provide the PyArrow schema to
convert to or else provide a PyArrow table directly.
>>> custom_schema = pa.schema([
... pa.field("vector", pa.list_(pa.float32(), 2)),
... pa.field("lat", pa.float32()),
... pa.field("long", pa.float32())
... ])
>>> db.create_table("table3", data, schema = custom_schema)
LanceTable(table3)
>>> db["table3"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
child 0, item: float
lat: float
long: float
----
vector: [[[1.1,1.2],[0.2,1.8]]]
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]
"""
raise NotImplementedError
def __getitem__(self, name: str) -> LanceTable:
return self.open_table(name)
def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
Returns
-------
A LanceTable object representing the table.
"""
raise NotImplementedError
def drop_table(self, name: str):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
"""
raise NotImplementedError
class LanceDBConnection(DBConnection):
""" """
A connection to a LanceDB database. A connection to a LanceDB database.
@@ -59,13 +203,6 @@ class LanceDBConnection:
if not isinstance(uri, Path): if not isinstance(uri, Path):
scheme = get_uri_scheme(uri) scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file" is_local = isinstance(uri, Path) or scheme == "file"
# managed lancedb remote uses schema like lancedb+[http|grpc|...]://
self._is_managed_remote = not is_local and scheme.startswith("lancedb")
if self._is_managed_remote:
if len(scheme.split("+")) != 2:
raise ValueError(
f"Invalid LanceDB URI: {uri}, expected uri to have scheme like lancedb+<flavor>://..."
)
if is_local: if is_local:
if isinstance(uri, str): if isinstance(uri, str):
uri = Path(uri) uri = Path(uri)
@@ -79,43 +216,6 @@ class LanceDBConnection:
def uri(self) -> str: def uri(self) -> str:
return self._uri return self._uri
@functools.cached_property
def is_managed_remote(self) -> bool:
return self._is_managed_remote
@functools.cached_property
def remote_flavor(self) -> str:
if not self.is_managed_remote:
raise ValueError(
"Not a managed remote LanceDB, there should be no server flavor"
)
return get_uri_scheme(self.uri).split("+")[1]
@functools.cached_property
def _client(self) -> "lancedb.remote.LanceDBClient":
if not self.is_managed_remote:
raise ValueError("Not a managed remote LanceDB, there should be no client")
# don't import unless we are really using remote
from lancedb.remote.client import RestfulLanceDBClient
if self.remote_flavor == "http":
return RestfulLanceDBClient(self._uri)
raise ValueError("Unsupported remote flavor: " + self.remote_flavor)
async def close(self):
if self._entered:
raise ValueError("Cannot re-enter the same LanceDBConnection twice")
self._entered = True
await self._client.close()
async def __aenter__(self) -> LanceDBConnection:
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
def table_names(self) -> list[str]: def table_names(self) -> list[str]:
"""Get the names of all tables in the database. """Get the names of all tables in the database.
@@ -149,9 +249,6 @@ class LanceDBConnection:
def __contains__(self, name: str) -> bool: def __contains__(self, name: str) -> bool:
return name in self.table_names() return name in self.table_names()
def __getitem__(self, name: str) -> LanceTable:
return self.open_table(name)
def create_table( def create_table(
self, self,
name: str, name: str,

View File

@@ -10,18 +10,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
import asyncio from typing import List, Literal, Optional, Union
from typing import Literal, Union
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pyarrow as pa import pyarrow as pa
from pydantic import BaseModel
from .common import VECTOR_COLUMN_NAME from .common import VECTOR_COLUMN_NAME
class Query(BaseModel):
    """Parameters of a single vector search.

    A query builder fills in one of these and hands it to the table's
    ``_execute_query``, which runs it.
    """

    # name of the vector column to search; defaults to VECTOR_COLUMN_NAME
    vector_column: str = VECTOR_COLUMN_NAME

    # vector to search for
    vector: List[float]

    # sql filter to refine the query with
    filter: Optional[str] = None

    # top k results to return
    k: int

    # distance metric used for the search
    metric: str = "L2"

    # which columns to return in the results
    columns: Optional[List[str]] = None

    # optional tuning parameter, forwarded as `nprobes` to the search
    nprobes: int = 10

    # optional tuning parameter, forwarded as `refine_factor` to the search
    refine_factor: Optional[int] = None
class LanceQueryBuilder: class LanceQueryBuilder:
""" """
A builder for nearest neighbor queries for LanceDB. A builder for nearest neighbor queries for LanceDB.
@@ -47,9 +76,9 @@ class LanceQueryBuilder:
def __init__( def __init__(
self, self,
table: "lancedb.table.LanceTable", table: "lancedb.table.Table",
query: Union[np.ndarray, str], query: Union[np.ndarray, str],
vector_column_name: str = VECTOR_COLUMN_NAME, vector_column: str = VECTOR_COLUMN_NAME,
): ):
self._metric = "L2" self._metric = "L2"
self._nprobes = 20 self._nprobes = 20
@@ -59,7 +88,7 @@ class LanceQueryBuilder:
self._limit = 10 self._limit = 10
self._columns = None self._columns = None
self._where = None self._where = None
self._vector_column_name = vector_column_name self._vector_column = vector_column
def limit(self, limit: int) -> LanceQueryBuilder: def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return. """Set the maximum number of results to return.
@@ -181,52 +210,28 @@ class LanceQueryBuilder:
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
""" """
Execute the query and return the results as a arrow Table. Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
In addition to the selected columns, LanceDB also returns a vector In addition to the selected columns, LanceDB also returns a vector
and also the "score" column which is the distance between the query and also the "score" column which is the distance between the query
vector and the returned vector. vector and the returned vectors.
""" """
if self._table._conn.is_managed_remote: vector = self._query if isinstance(self._query, list) else self._query.tolist()
try: query = Query(
loop = asyncio.get_running_loop() vector=vector,
except RuntimeError:
loop = asyncio.get_event_loop()
result = self._table._conn._client.query(
self._table.name, self.to_remote_query()
)
return loop.run_until_complete(result).to_arrow()
ds = self._table.to_lance()
return ds.to_table(
columns=self._columns,
filter=self._where,
nearest={
"column": self._vector_column_name,
"q": self._query,
"k": self._limit,
"metric": self._metric,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
},
)
def to_remote_query(self) -> "VectorQuery":
# don't import unless we are connecting to remote
from lancedb.remote.client import VectorQuery
return VectorQuery(
vector=self._query.tolist(),
filter=self._where, filter=self._where,
k=self._limit, k=self._limit,
_metric=self._metric, metric=self._metric,
columns=self._columns, columns=self._columns,
nprobes=self._nprobes, nprobes=self._nprobes,
refine_factor=self._refine_factor, refine_factor=self._refine_factor,
) )
return self._table._execute_query(query)
class LanceFtsQueryBuilder(LanceQueryBuilder): class LanceFtsQueryBuilder(LanceQueryBuilder):
def to_df(self) -> pd.DataFrame: def to_arrow(self) -> pd.Table:
try: try:
import tantivy import tantivy
except ImportError: except ImportError:
@@ -243,8 +248,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
# get the scores and doc ids # get the scores and doc ids
row_ids, scores = search_index(index, self._query, self._limit) row_ids, scores = search_index(index, self._query, self._limit)
if len(row_ids) == 0: if len(row_ids) == 0:
return pd.DataFrame() empty_schema = pa.schema([pa.field("score", pa.float32())])
return pa.Table.from_pylist([], schema=empty_schema)
scores = pa.array(scores) scores = pa.array(scores)
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores) output_tbl = output_tbl.append_column("score", scores)
return output_tbl.to_pandas() return output_tbl

View File

@@ -13,12 +13,13 @@
import functools import functools
import urllib.parse from typing import Dict
import aiohttp import aiohttp
import attr import attr
import pyarrow as pa import pyarrow as pa
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.errors import LanceDBClientError from lancedb.remote.errors import LanceDBClientError
@@ -35,29 +36,32 @@ def _check_not_closed(f):
@attr.define(slots=False) @attr.define(slots=False)
class RestfulLanceDBClient: class RestfulLanceDBClient:
url: str db_name: str
region: str
api_key: Credential
closed: bool = attr.field(default=False, init=False) closed: bool = attr.field(default=False, init=False)
@functools.cached_property @functools.cached_property
def session(self) -> aiohttp.ClientSession: def session(self) -> aiohttp.ClientSession:
parsed = urllib.parse.urlparse(self.url) url = f"https://{self.db_name}.{self.region}.api.lancedb.com"
scheme = parsed.scheme
if not scheme.startswith("lancedb"):
raise ValueError(
f"Invalid scheme: {scheme}, must be like lancedb+<flavor>://"
)
flavor = scheme.split("+")[1]
url = f"{flavor}://{parsed.hostname}:{parsed.port}"
return aiohttp.ClientSession(url) return aiohttp.ClientSession(url)
async def close(self): async def close(self):
await self.session.close() await self.session.close()
self.closed = True self.closed = True
@functools.cached_property
def headers(self) -> Dict[str, str]:
return {
"x-api-key": self.api_key,
}
@_check_not_closed @_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
async with self.session.post( async with self.session.post(
f"/table/{table_name}/", json=query.dict(exclude_none=True) f"/1/table/{table_name}/",
json=query.dict(exclude_none=True),
headers=self.headers,
) as resp: ) as resp:
resp: aiohttp.ClientResponse = resp resp: aiohttp.ClientResponse = resp
if 400 <= resp.status < 500: if 400 <= resp.status < 500:

View File

@@ -0,0 +1,71 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from urllib.parse import urlparse
import pyarrow as pa
from lancedb.common import DATA
from lancedb.db import DBConnection
from lancedb.table import Table
from .client import RestfulLanceDBClient
class RemoteDBConnection(DBConnection):
    """A connection to a remote LanceDB database.

    Only ``open_table`` is implemented so far; ``table_names`` and
    ``create_table`` raise ``NotImplementedError``.
    """

    def __init__(self, db_url: str, api_key: str, region: str):
        """Connect to a remote LanceDB database.

        Parameters
        ----------
        db_url: str
            Database URL of the form ``db://<database name>``.
        api_key: str
            API key passed on to the REST client.
        region: str
            Region passed on to the REST client.

        Raises
        ------
        ValueError
            If the URL scheme is not ``db`` or the URL carries no database name.
        """
        parsed = urlparse(db_url)
        if parsed.scheme != "db":
            raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
        # An empty netloc (e.g. "db://" or "db:///name") would otherwise
        # surface much later as a malformed host name in the REST client.
        if not parsed.netloc:
            raise ValueError(
                f"Invalid database URL: {db_url}, expected db://<database name>"
            )
        self.db_name = parsed.netloc
        self.api_key = api_key
        self._client = RestfulLanceDBClient(self.db_name, region, api_key)

    def __repr__(self) -> str:
        return f"RemoteConnect(name={self.db_name})"

    def table_names(self) -> List[str]:
        """List all table names in the database (not implemented yet)."""
        raise NotImplementedError

    def open_table(self, name: str) -> Table:
        """Open a table in the remote database.

        Parameters
        ----------
        name: str
            The name of the table.

        Returns
        -------
        A RemoteTable object representing the table.
        """
        # Imported here rather than at module level because the table module
        # imports RemoteDBConnection from this module.
        from .table import RemoteTable

        # TODO: check if table exists
        return RemoteTable(self, name)

    def create_table(
        self,
        name: str,
        data: DATA = None,
        schema: pa.Schema = None,
        mode: str = "create",
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
    ) -> Table:
        """Create a table in the remote database (not implemented yet)."""
        raise NotImplementedError

View File

@@ -0,0 +1,70 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Union
import pyarrow as pa
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceQueryBuilder, Query
from ..table import Query, Table
from .db import RemoteDBConnection
class RemoteTable(Table):
    """A table in a remote LanceDB database.

    Only ``search`` (and the ``_execute_query`` it relies on) is implemented;
    every other operation raises ``NotImplementedError``.
    """

    def __init__(self, conn: RemoteDBConnection, name: str):
        """Bind the table ``name`` to the remote connection ``conn``."""
        self._conn = conn
        self._name = name

    def __repr__(self) -> str:
        # Use ``_name``: it is the only name attribute ``__init__`` sets.
        return f"RemoteTable({self._conn.db_name}.{self._name})"

    def schema(self) -> pa.Schema:
        """Return the Arrow schema of the table (not implemented yet)."""
        raise NotImplementedError

    def to_arrow(self) -> pa.Table:
        """Return the table as a pyarrow Table (not implemented yet)."""
        raise NotImplementedError

    def create_index(
        self,
        metric="L2",
        num_partitions=256,
        num_sub_vectors=96,
        vector_column_name: str = VECTOR_COLUMN_NAME,
        replace: bool = True,
    ):
        """Create an index on the table (not implemented yet)."""
        raise NotImplementedError

    def add(
        self,
        data: DATA,
        mode: str = "append",
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
    ) -> int:
        """Add more data to the table (not implemented yet)."""
        raise NotImplementedError

    def search(
        self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
    ) -> LanceQueryBuilder:
        """Create a search query against this table.

        Parameters
        ----------
        query: list, np.ndarray
            The query vector.
        vector_column: str, default "vector"
            The name of the vector column to search.

        Returns
        -------
        LanceQueryBuilder
            A query builder bound to this table.
        """
        return LanceQueryBuilder(self, query, vector_column)

    def _execute_query(self, query: Query) -> pa.Table:
        """Send ``query`` through the connection's client and return Arrow data.

        The client call is a coroutine; it is driven to completion on an
        event loop so that callers get a synchronous result.
        """
        # NOTE(review): when a loop is already running, run_until_complete()
        # on it is expected to raise RuntimeError, so this path presumably
        # only works from synchronous code. Confirm before using in async apps.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()
        result = self._conn._client.query(self._name, query)
        return loop.run_until_complete(result).to_arrow()

View File

@@ -14,6 +14,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
from abc import ABC, abstractmethod
from functools import cached_property from functools import cached_property
from typing import List, Union from typing import List, Union
@@ -27,7 +28,7 @@ from lance import LanceDataset
from lance.vector import vec_to_table from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .query import LanceFtsQueryBuilder, LanceQueryBuilder from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query
def _sanitize_data(data, schema, on_bad_vectors, fill_value): def _sanitize_data(data, schema, on_bad_vectors, fill_value):
@@ -48,14 +49,14 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value):
return data return data
class LanceTable: class Table(ABC):
""" """
A table in a LanceDB database. A [Table](Table) is a collection of Records in a LanceDB [Database](Database).
Examples Examples
-------- --------
Create using [LanceDBConnection.create_table][lancedb.LanceDBConnection.create_table] Create using [DBConnection.create_table][lancedb.DBConnection.create_table]
(more examples in that method's documentation). (more examples in that method's documentation).
>>> import lancedb >>> import lancedb
@@ -70,12 +71,12 @@ class LanceTable:
vector: [[[1.1,1.2]]] vector: [[[1.1,1.2]]]
b: [[2]] b: [[2]]
Can append new data with [LanceTable.add][lancedb.table.LanceTable.add]. Can append new data with [Table.add()][lancedb.table.Table.add].
>>> table.add([{"vector": [0.5, 1.3], "b": 4}]) >>> table.add([{"vector": [0.5, 1.3], "b": 4}])
2 2
Can query the table with [LanceTable.search][lancedb.table.LanceTable.search]. Can query the table with [Table.search][lancedb.table.Table.search].
>>> table.search([0.4, 0.4]).select(["b"]).to_df() >>> table.search([0.4, 0.4]).select(["b"]).to_df()
b vector score b vector score
@@ -83,8 +84,128 @@ class LanceTable:
1 2 [1.1, 1.2] 1.13 1 2 [1.1, 1.2] 1.13
Search queries are much faster when an index is created. See Search queries are much faster when an index is created. See
[LanceTable.create_index][lancedb.table.LanceTable.create_index]. [Table.create_index][lancedb.table.Table.create_index].
"""
@abstractmethod
def schema(self) -> pa.Schema:
"""Return the [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of
this [Table](Table)
"""
raise NotImplementedError
def to_pandas(self) -> pd.DataFrame:
"""Return the table as a pandas DataFrame.
Returns
-------
pd.DataFrame
"""
return self.to_arrow().to_pandas()
@abstractmethod
def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table.
Returns
-------
pa.Table
"""
raise NotImplementedError
def create_index(
self,
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
):
"""Create an index on the table.
Parameters
----------
metric: str, default "L2"
The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
L2 is euclidean distance.
num_partitions: int
The number of IVF partitions to use when creating the index.
Default is 256.
num_sub_vectors: int
The number of PQ sub-vectors to use when creating the index.
Default is 96.
vector_column_name: str, default "vector"
The vector column name to create the index.
replace: bool, default True
If True, replace the existing index if it exists.
If False, raise an error if duplicate index exists.
"""
raise NotImplementedError
@abstractmethod
def add(
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
"""Add more data to the [Table](Table).
Parameters
----------
data: list-of-dict, dict, pd.DataFrame
The data to insert into the table.
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
Returns
-------
int
The number of vectors in the table.
"""
raise NotImplementedError
@abstractmethod
def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
Parameters
----------
query: list, np.ndarray
The query vector.
vector_column: str, default "vector"
The name of the vector column to search.
Returns
-------
LanceQueryBuilder
A query builder object representing the query.
Once executed, the query returns selected columns, the vector,
and also the "score" column which is the distance between the query
vector and the returned vector.
"""
raise NotImplementedError
@abstractmethod
def _execute_query(self, query: Query) -> pa.Table:
pass
class LanceTable(Table):
"""
A table in a LanceDB database.
""" """
def __init__( def __init__(
@@ -197,26 +318,7 @@ class LanceTable:
vector_column_name=VECTOR_COLUMN_NAME, vector_column_name=VECTOR_COLUMN_NAME,
replace: bool = True, replace: bool = True,
): ):
"""Create an index on the table. """Create an index on the table."""
Parameters
----------
metric: str, default "L2"
The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
L2 is euclidean distance.
num_partitions: int
The number of IVF partitions to use when creating the index.
Default is 256.
num_sub_vectors: int
The number of PQ sub-vectors to use when creating the index.
Default is 96.
vector_column_name: str, default "vector"
The vector column name to create the index.
replace: bool, default True
If True, replace the existing index if it exists.
If False, raise an error if duplicate index exists.
"""
self._dataset.create_index( self._dataset.create_index(
column=vector_column_name, column=vector_column_name,
index_type="IVF_PQ", index_type="IVF_PQ",
@@ -387,9 +489,6 @@ class LanceTable:
@classmethod @classmethod
def open(cls, db, name): def open(cls, db, name):
tbl = cls(db, name) tbl = cls(db, name)
if tbl._conn.is_managed_remote:
# Not completely sure how to check for remote table existence yet.
return tbl
if not os.path.exists(tbl._dataset_uri): if not os.path.exists(tbl._dataset_uri):
raise FileNotFoundError( raise FileNotFoundError(
f"Table {name} does not exist. Please first call db.create_table({name}, data)" f"Table {name} does not exist. Please first call db.create_table({name}, data)"
@@ -424,6 +523,21 @@ class LanceTable:
""" """
self._dataset.delete(where) self._dataset.delete(where)
def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance()
return ds.to_table(
columns=query.columns,
filter=query.filter,
nearest={
"column": query.vector_column,
"q": query.vector,
"k": query.k,
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
)
def _sanitize_schema( def _sanitize_schema(
data: pa.Table, data: pa.Table,
@@ -510,7 +624,7 @@ def _sanitize_vector_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr data.column_names.index(vector_column_name), vector_column_name, vec_arr
) )
has_nans = pc.any(vec_arr.values.is_nan()).as_py() has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py()
if has_nans: if has_nans:
data = _sanitize_nans( data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name data, fill_value, on_bad_vectors, vec_arr, vector_column_name
@@ -573,7 +687,7 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
"`fill_value` must not be None if `on_bad_vectors` is 'fill'" "`fill_value` must not be None if `on_bad_vectors` is 'fill'"
) )
fill_value = float(fill_value) fill_value = float(fill_value)
values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values) values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
ndims = len(vec_arr[0]) ndims = len(vec_arr[0])
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims) vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
data = data.set_column( data = data.set_column(

View File

@@ -20,18 +20,33 @@ import pyarrow as pa
import pytest import pytest
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
from lancedb.query import LanceQueryBuilder from lancedb.query import LanceQueryBuilder, Query
from lancedb.table import LanceTable from lancedb.table import LanceTable
class MockTable: class MockTable:
def __init__(self, tmp_path): def __init__(self, tmp_path):
self.uri = tmp_path self.uri = tmp_path
self._conn = LanceDBConnection("/tmp/lance/") self._conn = LanceDBConnection(self.uri)
def to_lance(self): def to_lance(self):
return lance.dataset(self.uri) return lance.dataset(self.uri)
def _execute_query(self, query):
ds = self.to_lance()
return ds.to_table(
columns=query.columns,
filter=query.filter,
nearest={
"column": query.vector_column,
"q": query.vector,
"k": query.k,
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
)
@pytest.fixture @pytest.fixture
def table(tmp_path) -> MockTable: def table(tmp_path) -> MockTable:
@@ -94,20 +109,17 @@ def test_query_builder_with_different_vector_column():
) )
ds = mock.Mock() ds = mock.Mock()
table.to_lance.return_value = ds table.to_lance.return_value = ds
table._conn = mock.MagicMock()
table._conn.is_managed_remote = False
builder.to_arrow() builder.to_arrow()
ds.to_table.assert_called_once_with( table._execute_query.assert_called_once_with(
columns=["b"], Query(
filter="b < 10", vector=query,
nearest={ filter="b < 10",
"column": vector_column_name, k=2,
"q": query, metric="cosine",
"k": 2, columns=["b"],
"metric": "cosine", nprobes=20,
"nprobes": 20, refine_factor=None,
"refine_factor": None, )
},
) )

View File

@@ -13,7 +13,7 @@
import pyarrow as pa import pyarrow as pa
from lancedb.db import LanceDBConnection import lancedb
from lancedb.remote.client import VectorQuery, VectorQueryResult from lancedb.remote.client import VectorQuery, VectorQueryResult
@@ -28,7 +28,7 @@ class FakeLanceDBClient:
def test_remote_db(): def test_remote_db():
conn = LanceDBConnection("lancedb+http://client-will-be-injected") conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient()) setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"] table = conn["test"]