From e6c6da6104c1e85d25b017fecbfeacf283498a6f Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 7 Jul 2023 15:41:15 -0700 Subject: [PATCH] [Python] Initial support of cloud API (#260) Support connect with remote database, and implement Search API --- docs/src/python/python.md | 6 +- python/lancedb/__init__.py | 24 +++- python/lancedb/common.py | 10 ++ python/lancedb/db.py | 195 ++++++++++++++++++++++++-------- python/lancedb/query.py | 88 +++++++------- python/lancedb/remote/client.py | 26 +++-- python/lancedb/remote/db.py | 71 ++++++++++++ python/lancedb/remote/table.py | 70 ++++++++++++ python/lancedb/table.py | 178 +++++++++++++++++++++++------ python/tests/test_query.py | 42 ++++--- python/tests/test_remote_db.py | 4 +- 11 files changed, 558 insertions(+), 156 deletions(-) create mode 100644 python/lancedb/remote/db.py create mode 100644 python/lancedb/remote/table.py diff --git a/docs/src/python/python.md b/docs/src/python/python.md index 08e228d1..c0a3aa7d 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -10,14 +10,16 @@ pip install lancedb ::: lancedb.connect -::: lancedb.LanceDBConnection +::: lancedb.db.DBConnection ## Table -::: lancedb.table.LanceTable +::: lancedb.table.Table ## Querying +::: lancedb.query.Query + ::: lancedb.query.LanceQueryBuilder ::: lancedb.query.LanceFtsQueryBuilder diff --git a/python/lancedb/__init__.py b/python/lancedb/__init__.py index a3b8504d..7c5d3590 100644 --- a/python/lancedb/__init__.py +++ b/python/lancedb/__init__.py @@ -11,16 +11,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .db import URI, LanceDBConnection +from typing import Optional + +from .db import URI, DBConnection, LanceDBConnection +from .remote.db import RemoteDBConnection -def connect(uri: URI) -> LanceDBConnection: - """Connect to a LanceDB instance at the given URI +def connect( + uri: URI, *, api_key: Optional[str] = None, region: str = "us-west-2" +) -> DBConnection: + """Connect to a LanceDB database. Parameters ---------- uri: str or Path The uri of the database. + api_key: str, optional + If provided, connect to LanceDB cloud. + Otherwise, connect to a database on file system or cloud storage. Examples -------- @@ -34,9 +42,17 @@ def connect(uri: URI) -> LanceDBConnection: >>> db = lancedb.connect("s3://my-bucket/lancedb") + Connect to LanceDB cloud: + + >>> db = lancedb.connect("db://my_database", api_key="ldb_...") + Returns ------- - conn : LanceDBConnection + conn : DBConnection A connection to a LanceDB database. """ + if isinstance(uri, str) and uri.startswith("db://"): + if api_key is None: + raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}") + return RemoteDBConnection(uri, api_key, region) return LanceDBConnection(uri) diff --git a/python/lancedb/common.py b/python/lancedb/common.py index b39b5568..47d0bc43 100644 --- a/python/lancedb/common.py +++ b/python/lancedb/common.py @@ -23,3 +23,13 @@ URI = Union[str, Path] # TODO support generator DATA = Union[List[dict], dict, pd.DataFrame] VECTOR_COLUMN_NAME = "vector" + + +class Credential(str): + """Credential field""" + + def __repr__(self) -> str: + return "********" + + def __str__(self) -> str: + return "********" diff --git a/python/lancedb/db.py b/python/lancedb/db.py index fceb2e26..3ae286bb 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -15,17 +15,161 @@ from __future__ import annotations import functools import os +from abc import ABC, abstractmethod from pathlib import Path import pyarrow as pa from pyarrow import fs from .common import DATA, URI -from 
.table import LanceTable +from .table import LanceTable, Table from .util import get_uri_location, get_uri_scheme -class LanceDBConnection: +class DBConnection(ABC): + """An active LanceDB connection interface.""" + + @abstractmethod + def table_names(self) -> list[str]: + """List all table names in the database.""" + pass + + @abstractmethod + def create_table( + self, + name: str, + data: DATA = None, + schema: pa.Schema = None, + mode: str = "create", + on_bad_vectors: str = "error", + fill_value: float = 0.0, + ) -> Table: + """Create a [Table][lancedb.table.Table] in the database. + + Parameters + ---------- + name: str + The name of the table. + data: list, tuple, dict, pd.DataFrame; optional + The data to insert into the table. + schema: pyarrow.Schema; optional + The schema of the table. + mode: str; default "create" + The mode to use when creating the table. Can be either "create" or "overwrite". + By default, if the table already exists, an exception is raised. + If you want to overwrite the table, use mode="overwrite". + on_bad_vectors: str, default "error" + What to do if any of the vectors are not the same size or contains NaNs. + One of "error", "drop", "fill". + fill_value: float + The value to use when filling vectors. Only used if on_bad_vectors="fill". + + Note + ---- + The vector index won't be created by default. + To create the index, call the `create_index` method on the table. + + Returns + ------- + LanceTable + A reference to the newly created table. + + Examples + -------- + + Can create with list of tuples or dictionaries: + + >>> import lancedb + >>> db = lancedb.connect("./.lancedb") + >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7}, + ... 
{"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}] + >>> db.create_table("my_table", data) + LanceTable(my_table) + >>> db["my_table"].head() + pyarrow.Table + vector: fixed_size_list[2] + child 0, item: float + lat: double + long: double + ---- + vector: [[[1.1,1.2],[0.2,1.8]]] + lat: [[45.5,40.1]] + long: [[-122.7,-74.1]] + + You can also pass a pandas DataFrame: + + >>> import pandas as pd + >>> data = pd.DataFrame({ + ... "vector": [[1.1, 1.2], [0.2, 1.8]], + ... "lat": [45.5, 40.1], + ... "long": [-122.7, -74.1] + ... }) + >>> db.create_table("table2", data) + LanceTable(table2) + >>> db["table2"].head() + pyarrow.Table + vector: fixed_size_list[2] + child 0, item: float + lat: double + long: double + ---- + vector: [[[1.1,1.2],[0.2,1.8]]] + lat: [[45.5,40.1]] + long: [[-122.7,-74.1]] + + Data is converted to Arrow before being written to disk. For maximum + control over how data is saved, either provide the PyArrow schema to + convert to or else provide a PyArrow table directly. + + >>> custom_schema = pa.schema([ + ... pa.field("vector", pa.list_(pa.float32(), 2)), + ... pa.field("lat", pa.float32()), + ... pa.field("long", pa.float32()) + ... ]) + >>> db.create_table("table3", data, schema = custom_schema) + LanceTable(table3) + >>> db["table3"].head() + pyarrow.Table + vector: fixed_size_list[2] + child 0, item: float + lat: float + long: float + ---- + vector: [[[1.1,1.2],[0.2,1.8]]] + lat: [[45.5,40.1]] + long: [[-122.7,-74.1]] + """ + raise NotImplementedError + + def __getitem__(self, name: str) -> LanceTable: + return self.open_table(name) + + def open_table(self, name: str) -> Table: + """Open a Lance Table in the database. + + Parameters + ---------- + name: str + The name of the table. + + Returns + ------- + A LanceTable object representing the table. + """ + raise NotImplementedError + + def drop_table(self, name: str): + """Drop a table from the database. + + Parameters + ---------- + name: str + The name of the table. 
+ """ + raise NotImplementedError + + +class LanceDBConnection(DBConnection): """ A connection to a LanceDB database. @@ -59,13 +203,6 @@ class LanceDBConnection: if not isinstance(uri, Path): scheme = get_uri_scheme(uri) is_local = isinstance(uri, Path) or scheme == "file" - # managed lancedb remote uses schema like lancedb+[http|grpc|...]:// - self._is_managed_remote = not is_local and scheme.startswith("lancedb") - if self._is_managed_remote: - if len(scheme.split("+")) != 2: - raise ValueError( - f"Invalid LanceDB URI: {uri}, expected uri to have scheme like lancedb+://..." - ) if is_local: if isinstance(uri, str): uri = Path(uri) @@ -79,43 +216,6 @@ class LanceDBConnection: def uri(self) -> str: return self._uri - @functools.cached_property - def is_managed_remote(self) -> bool: - return self._is_managed_remote - - @functools.cached_property - def remote_flavor(self) -> str: - if not self.is_managed_remote: - raise ValueError( - "Not a managed remote LanceDB, there should be no server flavor" - ) - return get_uri_scheme(self.uri).split("+")[1] - - @functools.cached_property - def _client(self) -> "lancedb.remote.LanceDBClient": - if not self.is_managed_remote: - raise ValueError("Not a managed remote LanceDB, there should be no client") - - # don't import unless we are really using remote - from lancedb.remote.client import RestfulLanceDBClient - - if self.remote_flavor == "http": - return RestfulLanceDBClient(self._uri) - - raise ValueError("Unsupported remote flavor: " + self.remote_flavor) - - async def close(self): - if self._entered: - raise ValueError("Cannot re-enter the same LanceDBConnection twice") - self._entered = True - await self._client.close() - - async def __aenter__(self) -> LanceDBConnection: - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - await self.close() - def table_names(self) -> list[str]: """Get the names of all tables in the database. 
@@ -149,9 +249,6 @@ class LanceDBConnection: def __contains__(self, name: str) -> bool: return name in self.table_names() - def __getitem__(self, name: str) -> LanceTable: - return self.open_table(name) - def create_table( self, name: str, diff --git a/python/lancedb/query.py b/python/lancedb/query.py index ae321b21..9262143b 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -10,18 +10,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from __future__ import annotations -import asyncio -from typing import Literal, Union +from typing import List, Literal, Optional, Union import numpy as np import pandas as pd import pyarrow as pa +from pydantic import BaseModel from .common import VECTOR_COLUMN_NAME +class Query(BaseModel): + """A Query""" + + vector_column: str = VECTOR_COLUMN_NAME + + # vector to search for + vector: List[float] + + # sql filter to refine the query with + filter: Optional[str] = None + + # top k results to return + k: int + + # distance metric + metric: str = "L2" + + # which columns to return in the results + columns: Optional[List[str]] = None + + # optional query parameters for tuning the results, + # e.g. `{"nprobes": "10", "refine_factor": "10"}` + nprobes: int = 10 + + # Refine factor. + refine_factor: Optional[int] = None + + class LanceQueryBuilder: """ A builder for nearest neighbor queries for LanceDB. 
@@ -47,9 +76,9 @@ class LanceQueryBuilder: def __init__( self, - table: "lancedb.table.LanceTable", + table: "lancedb.table.Table", query: Union[np.ndarray, str], - vector_column_name: str = VECTOR_COLUMN_NAME, + vector_column: str = VECTOR_COLUMN_NAME, ): self._metric = "L2" self._nprobes = 20 @@ -59,7 +88,7 @@ class LanceQueryBuilder: self._limit = 10 self._columns = None self._where = None - self._vector_column_name = vector_column_name + self._vector_column = vector_column def limit(self, limit: int) -> LanceQueryBuilder: """Set the maximum number of results to return. @@ -181,52 +210,28 @@ class LanceQueryBuilder: def to_arrow(self) -> pa.Table: """ - Execute the query and return the results as a arrow Table. + Execute the query and return the results as an + [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table). + In addition to the selected columns, LanceDB also returns a vector and also the "score" column which is the distance between the query - vector and the returned vector. + vector and the returned vectors. 
""" - if self._table._conn.is_managed_remote: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.get_event_loop() - result = self._table._conn._client.query( - self._table.name, self.to_remote_query() - ) - return loop.run_until_complete(result).to_arrow() - - ds = self._table.to_lance() - return ds.to_table( - columns=self._columns, - filter=self._where, - nearest={ - "column": self._vector_column_name, - "q": self._query, - "k": self._limit, - "metric": self._metric, - "nprobes": self._nprobes, - "refine_factor": self._refine_factor, - }, - ) - - def to_remote_query(self) -> "VectorQuery": - # don't import unless we are connecting to remote - from lancedb.remote.client import VectorQuery - - return VectorQuery( - vector=self._query.tolist(), + vector = self._query if isinstance(self._query, list) else self._query.tolist() + query = Query( + vector=vector, filter=self._where, k=self._limit, - _metric=self._metric, + metric=self._metric, columns=self._columns, nprobes=self._nprobes, refine_factor=self._refine_factor, ) + return self._table._execute_query(query) class LanceFtsQueryBuilder(LanceQueryBuilder): - def to_df(self) -> pd.DataFrame: + def to_arrow(self) -> pd.Table: try: import tantivy except ImportError: @@ -243,8 +248,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): # get the scores and doc ids row_ids, scores = search_index(index, self._query, self._limit) if len(row_ids) == 0: - return pd.DataFrame() + empty_schema = pa.schema([pa.field("score", pa.float32())]) + return pa.Table.from_pylist([], schema=empty_schema) scores = pa.array(scores) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = output_tbl.append_column("score", scores) - return output_tbl.to_pandas() + return output_tbl diff --git a/python/lancedb/remote/client.py b/python/lancedb/remote/client.py index 707603ad..3cccba40 100644 --- a/python/lancedb/remote/client.py +++ b/python/lancedb/remote/client.py @@ -13,12 +13,13 
@@ import functools -import urllib.parse +from typing import Dict import aiohttp import attr import pyarrow as pa +from lancedb.common import Credential from lancedb.remote import VectorQuery, VectorQueryResult from lancedb.remote.errors import LanceDBClientError @@ -35,29 +36,32 @@ def _check_not_closed(f): @attr.define(slots=False) class RestfulLanceDBClient: - url: str + db_name: str + region: str + api_key: Credential closed: bool = attr.field(default=False, init=False) @functools.cached_property def session(self) -> aiohttp.ClientSession: - parsed = urllib.parse.urlparse(self.url) - scheme = parsed.scheme - if not scheme.startswith("lancedb"): - raise ValueError( - f"Invalid scheme: {scheme}, must be like lancedb+://" - ) - flavor = scheme.split("+")[1] - url = f"{flavor}://{parsed.hostname}:{parsed.port}" + url = f"https://{self.db_name}.{self.region}.api.lancedb.com" return aiohttp.ClientSession(url) async def close(self): await self.session.close() self.closed = True + @functools.cached_property + def headers(self) -> Dict[str, str]: + return { + "x-api-key": self.api_key, + } + @_check_not_closed async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: async with self.session.post( - f"/table/{table_name}/", json=query.dict(exclude_none=True) + f"/1/table/{table_name}/", + json=query.dict(exclude_none=True), + headers=self.headers, ) as resp: resp: aiohttp.ClientResponse = resp if 400 <= resp.status < 500: diff --git a/python/lancedb/remote/db.py b/python/lancedb/remote/db.py new file mode 100644 index 00000000..7b721662 --- /dev/null +++ b/python/lancedb/remote/db.py @@ -0,0 +1,71 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +from urllib.parse import urlparse + +import pyarrow as pa + +from lancedb.common import DATA +from lancedb.db import DBConnection +from lancedb.table import Table + +from .client import RestfulLanceDBClient + + +class RemoteDBConnection(DBConnection): + """A connection to a remote LanceDB database.""" + + def __init__(self, db_url: str, api_key: str, region: str): + """Connect to a remote LanceDB database.""" + parsed = urlparse(db_url) + if parsed.scheme != "db": + raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") + self.db_name = parsed.netloc + self.api_key = api_key + self._client = RestfulLanceDBClient(self.db_name, region, api_key) + + def __repr__(self) -> str: + return f"RemoveConnect(name={self.db_name})" + + def table_names(self) -> List[str]: + raise NotImplementedError + + def open_table(self, name: str) -> Table: + """Open a Lance Table in the database. + + Parameters + ---------- + name: str + The name of the table. + + Returns + ------- + A LanceTable object representing the table. 
+ """ + from .table import RemoteTable + + # TODO: check if table exists + + return RemoteTable(self, name) + + def create_table( + self, + name: str, + data: DATA = None, + schema: pa.Schema = None, + mode: str = "create", + on_bad_vectors: str = "error", + fill_value: float = 0.0, + ) -> Table: + raise NotImplementedError diff --git a/python/lancedb/remote/table.py b/python/lancedb/remote/table.py new file mode 100644 index 00000000..e4152e6b --- /dev/null +++ b/python/lancedb/remote/table.py @@ -0,0 +1,70 @@ +# Copyright 2023 LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +from typing import Union + +import pyarrow as pa + +from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME + +from ..query import LanceQueryBuilder, Query +from ..table import Query, Table +from .db import RemoteDBConnection + + +class RemoteTable(Table): + def __init__(self, conn: RemoteDBConnection, name: str): + self._conn = conn + self._name = name + + def __repr__(self) -> str: + return f"RemoteTable({self._conn.db_name}.{self.name})" + + def schema(self) -> pa.Schema: + raise NotImplementedError + + def to_arrow(self) -> pa.Table: + raise NotImplementedError + + def create_index( + self, + metric="L2", + num_partitions=256, + num_sub_vectors=96, + vector_column_name: str = VECTOR_COLUMN_NAME, + replace: bool = True, + ): + raise NotImplementedError + + def add( + self, + data: DATA, + mode: str = "append", + on_bad_vectors: str = "error", + fill_value: float = 0.0, + ) -> int: + raise NotImplementedError + + def search( + self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME + ) -> LanceQueryBuilder: + return LanceQueryBuilder(self, query, vector_column) + + def _execute_query(self, query: Query) -> pa.Table: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.get_event_loop() + result = self._conn._client.query(self._name, query) + return loop.run_until_complete(result).to_arrow() diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 82f97730..23d4a20e 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -14,6 +14,7 @@ from __future__ import annotations import os +from abc import ABC, abstractmethod from functools import cached_property from typing import List, Union @@ -27,7 +28,7 @@ from lance import LanceDataset from lance.vector import vec_to_table from .common import DATA, VEC, VECTOR_COLUMN_NAME -from .query import LanceFtsQueryBuilder, LanceQueryBuilder +from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query def _sanitize_data(data, schema, 
on_bad_vectors, fill_value): @@ -48,14 +49,14 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value): return data -class LanceTable: +class Table(ABC): """ - A table in a LanceDB database. + A [Table](Table) is a collection of Records in a LanceDB [Database](Database). Examples -------- - Create using [LanceDBConnection.create_table][lancedb.LanceDBConnection.create_table] + Create using [DBConnection.create_table][lancedb.DBConnection.create_table] (more examples in that method's documentation). >>> import lancedb @@ -70,12 +71,12 @@ class LanceTable: vector: [[[1.1,1.2]]] b: [[2]] - Can append new data with [LanceTable.add][lancedb.table.LanceTable.add]. + Can append new data with [Table.add()][lancedb.table.Table.add]. >>> table.add([{"vector": [0.5, 1.3], "b": 4}]) 2 - Can query the table with [LanceTable.search][lancedb.table.LanceTable.search]. + Can query the table with [Table.search][lancedb.table.Table.search]. >>> table.search([0.4, 0.4]).select(["b"]).to_df() b vector score @@ -83,8 +84,128 @@ class LanceTable: 1 2 [1.1, 1.2] 1.13 Search queries are much faster when an index is created. See - [LanceTable.create_index][lancedb.table.LanceTable.create_index]. + [Table.create_index][lancedb.table.Table.create_index]. + """ + @abstractmethod + def schema(self) -> pa.Schema: + """Return the [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of + this [Table](Table) + + """ + raise NotImplementedError + + def to_pandas(self) -> pd.DataFrame: + """Return the table as a pandas DataFrame. + + Returns + ------- + pd.DataFrame + """ + return self.to_arrow().to_pandas() + + @abstractmethod + def to_arrow(self) -> pa.Table: + """Return the table as a pyarrow Table. + + Returns + ------- + pa.Table + """ + raise NotImplementedError + + def create_index( + self, + metric="L2", + num_partitions=256, + num_sub_vectors=96, + vector_column_name: str = VECTOR_COLUMN_NAME, + replace: bool = True, + ): + """Create an index on the table. 
+ + Parameters + ---------- + metric: str, default "L2" + The distance metric to use when creating the index. + Valid values are "L2", "cosine", or "dot". + L2 is euclidean distance. + num_partitions: int + The number of IVF partitions to use when creating the index. + Default is 256. + num_sub_vectors: int + The number of PQ sub-vectors to use when creating the index. + Default is 96. + vector_column_name: str, default "vector" + The vector column name to create the index. + replace: bool, default True + If True, replace the existing index if it exists. + If False, raise an error if duplicate index exists. + """ + raise NotImplementedError + + @abstractmethod + def add( + self, + data: DATA, + mode: str = "append", + on_bad_vectors: str = "error", + fill_value: float = 0.0, + ) -> int: + """Add more data to the [Table](Table). + + Parameters + ---------- + data: list-of-dict, dict, pd.DataFrame + The data to insert into the table. + mode: str + The mode to use when writing the data. Valid values are + "append" and "overwrite". + on_bad_vectors: str, default "error" + What to do if any of the vectors are not the same size or contains NaNs. + One of "error", "drop", "fill". + fill_value: float, default 0. + The value to use when filling vectors. Only used if on_bad_vectors="fill". + + Returns + ------- + int + The number of vectors in the table. + """ + raise NotImplementedError + + @abstractmethod + def search( + self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME + ) -> LanceQueryBuilder: + """Create a search query to find the nearest neighbors + of the given query vector. + + Parameters + ---------- + query: list, np.ndarray + The query vector. + vector_column: str, default "vector" + The name of the vector column to search. + + Returns + ------- + LanceQueryBuilder + A query builder object representing the query. 
+ Once executed, the query returns selected columns, the vector, + and also the "score" column which is the distance between the query + vector and the returned vector. + """ + raise NotImplementedError + + @abstractmethod + def _execute_query(self, query: Query) -> pa.Table: + pass + + +class LanceTable(Table): + """ + A table in a LanceDB database. """ def __init__( @@ -197,26 +318,7 @@ class LanceTable: vector_column_name=VECTOR_COLUMN_NAME, replace: bool = True, ): - """Create an index on the table. - - Parameters - ---------- - metric: str, default "L2" - The distance metric to use when creating the index. - Valid values are "L2", "cosine", or "dot". - L2 is euclidean distance. - num_partitions: int - The number of IVF partitions to use when creating the index. - Default is 256. - num_sub_vectors: int - The number of PQ sub-vectors to use when creating the index. - Default is 96. - vector_column_name: str, default "vector" - The vector column name to create the index. - replace: bool, default True - If True, replace the existing index if it exists. - If False, raise an error if duplicate index exists. - """ + """Create an index on the table.""" self._dataset.create_index( column=vector_column_name, index_type="IVF_PQ", @@ -387,9 +489,6 @@ class LanceTable: @classmethod def open(cls, db, name): tbl = cls(db, name) - if tbl._conn.is_managed_remote: - # Not completely sure how to check for remote table existence yet. - return tbl if not os.path.exists(tbl._dataset_uri): raise FileNotFoundError( f"Table {name} does not exist. 
Please first call db.create_table({name}, data)" @@ -424,6 +523,21 @@ class LanceTable: """ self._dataset.delete(where) + def _execute_query(self, query: Query) -> pa.Table: + ds = self.to_lance() + return ds.to_table( + columns=query.columns, + filter=query.filter, + nearest={ + "column": query.vector_column, + "q": query.vector, + "k": query.k, + "metric": query.metric, + "nprobes": query.nprobes, + "refine_factor": query.refine_factor, + }, + ) + def _sanitize_schema( data: pa.Table, @@ -510,7 +624,7 @@ def _sanitize_vector_column( data.column_names.index(vector_column_name), vector_column_name, vec_arr ) - has_nans = pc.any(vec_arr.values.is_nan()).as_py() + has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py() if has_nans: data = _sanitize_nans( data, fill_value, on_bad_vectors, vec_arr, vector_column_name @@ -573,7 +687,7 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name "`fill_value` must not be None if `on_bad_vectors` is 'fill'" ) fill_value = float(fill_value) - values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values) + values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values) ndims = len(vec_arr[0]) vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims) data = data.set_column( diff --git a/python/tests/test_query.py b/python/tests/test_query.py index 69eda338..f1f15b26 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -20,18 +20,33 @@ import pyarrow as pa import pytest from lancedb.db import LanceDBConnection -from lancedb.query import LanceQueryBuilder +from lancedb.query import LanceQueryBuilder, Query from lancedb.table import LanceTable class MockTable: def __init__(self, tmp_path): self.uri = tmp_path - self._conn = LanceDBConnection("/tmp/lance/") + self._conn = LanceDBConnection(self.uri) def to_lance(self): return lance.dataset(self.uri) + def _execute_query(self, query): + ds = self.to_lance() + return ds.to_table( + columns=query.columns, + 
filter=query.filter, + nearest={ + "column": query.vector_column, + "q": query.vector, + "k": query.k, + "metric": query.metric, + "nprobes": query.nprobes, + "refine_factor": query.refine_factor, + }, + ) + @pytest.fixture def table(tmp_path) -> MockTable: @@ -94,20 +109,17 @@ def test_query_builder_with_different_vector_column(): ) ds = mock.Mock() table.to_lance.return_value = ds - table._conn = mock.MagicMock() - table._conn.is_managed_remote = False builder.to_arrow() - ds.to_table.assert_called_once_with( - columns=["b"], - filter="b < 10", - nearest={ - "column": vector_column_name, - "q": query, - "k": 2, - "metric": "cosine", - "nprobes": 20, - "refine_factor": None, - }, + table._execute_query.assert_called_once_with( + Query( + vector=query, + filter="b < 10", + k=2, + metric="cosine", + columns=["b"], + nprobes=20, + refine_factor=None, + ) ) diff --git a/python/tests/test_remote_db.py b/python/tests/test_remote_db.py index 7971340d..aa480870 100644 --- a/python/tests/test_remote_db.py +++ b/python/tests/test_remote_db.py @@ -13,7 +13,7 @@ import pyarrow as pa -from lancedb.db import LanceDBConnection +import lancedb from lancedb.remote.client import VectorQuery, VectorQueryResult @@ -28,7 +28,7 @@ class FakeLanceDBClient: def test_remote_db(): - conn = LanceDBConnection("lancedb+http://client-will-be-injected") + conn = lancedb.connect("db://client-will-be-injected", api_key="fake") setattr(conn, "_client", FakeLanceDBClient()) table = conn["test"]