Compare commits

...

5 Commits

Author SHA1 Message Date
Lei Xu
97364a2514 Bump to v0.1.10-python 2023-07-09 21:52:11 -07:00
Lei Xu
e6c6da6104 [Python] Initial support of cloud API (#260)
Support connect with remote database, and implement Search API
2023-07-07 15:41:15 -07:00
Leon Yee
a5eb665b7d [docs] dynamic docs generation and deployment (#253)
Solves #245 , edited docs.yml to run the generation of docs before
deployment. Tested on a test repository
2023-07-06 21:10:36 -07:00
Chang She
e2325c634b Allow creation of an empty table (#254)
It's inconvenient to always require data at table creation time.
Here we enable you to create an empty table and add data and set schema
later.

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-07-06 20:44:58 -07:00
Chang She
507eeae9c8 Set default to error instead of drop (#259)
when encountering bad input data, we can default to principle of least
surprise and raise an exception.

Co-authored-by: Chang She <chang@lancedb.com>
2023-07-05 22:44:18 -07:00
18 changed files with 681 additions and 220 deletions

View File

@@ -39,6 +39,28 @@ jobs:
run: |
python -m pip install -e .
python -m pip install -r ../docs/requirements.txt
- name: Set up node
uses: actions/setup-node@v3
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'
cache-dependency-path: node/package-lock.json
- uses: Swatinem/rust-cache@v2
- name: Install node dependencies
working-directory: node
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Build node
working-directory: node
run: |
npm ci
npm run build
npm run tsc
- name: Create markdown files
working-directory: node
run: |
npx typedoc --plugin typedoc-plugin-markdown --out ../docs/src/javascript src/index.ts
- name: Build docs
run: |
PYTHONPATH=. mkdocs build -f docs/mkdocs.yml
@@ -50,4 +72,4 @@ jobs:
path: "docs/site"
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v1
uses: actions/deploy-pages@v1

View File

@@ -61,6 +61,8 @@ jobs:
run: |
pip install -e .
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock
pip install pytest pytest-mock black
- name: Black
run: black --check --diff --no-color --quiet .
- name: Run tests
run: pytest -x -v --durations=30 tests
run: pytest -x -v --durations=30 tests

View File

@@ -10,14 +10,16 @@ pip install lancedb
::: lancedb.connect
::: lancedb.LanceDBConnection
::: lancedb.db.DBConnection
## Table
::: lancedb.table.LanceTable
::: lancedb.table.Table
## Querying
::: lancedb.query.Query
::: lancedb.query.LanceQueryBuilder
::: lancedb.query.LanceFtsQueryBuilder

View File

@@ -11,16 +11,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .db import URI, LanceDBConnection
from typing import Optional
from .db import URI, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
def connect(uri: URI) -> LanceDBConnection:
"""Connect to a LanceDB instance at the given URI
def connect(
uri: URI, *, api_key: Optional[str] = None, region: str = "us-west-2"
) -> DBConnection:
"""Connect to a LanceDB database.
Parameters
----------
uri: str or Path
The uri of the database.
api_token: str, optional
If presented, connect to LanceDB cloud.
Otherwise, connect to a database on file system or cloud storage.
Examples
--------
@@ -34,9 +42,17 @@ def connect(uri: URI) -> LanceDBConnection:
>>> db = lancedb.connect("s3://my-bucket/lancedb")
Connect to LancdDB cloud:
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
Returns
-------
conn : LanceDBConnection
conn : DBConnection
A connection to a LanceDB database.
"""
if isinstance(uri, str) and uri.startswith("db://"):
if api_key is None:
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
return RemoteDBConnection(uri, api_key, region)
return LanceDBConnection(uri)

View File

@@ -23,3 +23,13 @@ URI = Union[str, Path]
# TODO support generator
DATA = Union[List[dict], dict, pd.DataFrame]
VECTOR_COLUMN_NAME = "vector"
class Credential(str):
"""Credential field"""
def __repr__(self) -> str:
return "********"
def __str__(self) -> str:
return "********"

View File

@@ -1,10 +1,8 @@
import builtins
import os
import pytest
# import lancedb so we don't have to in every example
import lancedb
@pytest.fixture(autouse=True)

View File

@@ -15,17 +15,161 @@ from __future__ import annotations
import functools
import os
from abc import ABC, abstractmethod
from pathlib import Path
import pyarrow as pa
from pyarrow import fs
from .common import DATA, URI
from .table import LanceTable
from .table import LanceTable, Table
from .util import get_uri_location, get_uri_scheme
class LanceDBConnection:
class DBConnection(ABC):
"""An active LanceDB connection interface."""
@abstractmethod
def table_names(self) -> list[str]:
"""List all table names in the database."""
pass
@abstractmethod
def create_table(
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
Parameters
----------
name: str
The name of the table.
data: list, tuple, dict, pd.DataFrame; optional
The data to insert into the table.
schema: pyarrow.Schema; optional
The schema of the table.
mode: str; default "create"
The mode to use when creating the table. Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
Note
----
The vector index won't be created by default.
To create the index, call the `create_index` method on the table.
Returns
-------
LanceTable
A reference to the newly created table.
Examples
--------
Can create with list of tuples or dictionaries:
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data)
LanceTable(my_table)
>>> db["my_table"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
child 0, item: float
lat: double
long: double
----
vector: [[[1.1,1.2],[0.2,1.8]]]
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]
You can also pass a pandas DataFrame:
>>> import pandas as pd
>>> data = pd.DataFrame({
... "vector": [[1.1, 1.2], [0.2, 1.8]],
... "lat": [45.5, 40.1],
... "long": [-122.7, -74.1]
... })
>>> db.create_table("table2", data)
LanceTable(table2)
>>> db["table2"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
child 0, item: float
lat: double
long: double
----
vector: [[[1.1,1.2],[0.2,1.8]]]
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]
Data is converted to Arrow before being written to disk. For maximum
control over how data is saved, either provide the PyArrow schema to
convert to or else provide a PyArrow table directly.
>>> custom_schema = pa.schema([
... pa.field("vector", pa.list_(pa.float32(), 2)),
... pa.field("lat", pa.float32()),
... pa.field("long", pa.float32())
... ])
>>> db.create_table("table3", data, schema = custom_schema)
LanceTable(table3)
>>> db["table3"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
child 0, item: float
lat: float
long: float
----
vector: [[[1.1,1.2],[0.2,1.8]]]
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]
"""
raise NotImplementedError
def __getitem__(self, name: str) -> LanceTable:
return self.open_table(name)
def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
Returns
-------
A LanceTable object representing the table.
"""
raise NotImplementedError
def drop_table(self, name: str):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
"""
raise NotImplementedError
class LanceDBConnection(DBConnection):
"""
A connection to a LanceDB database.
@@ -59,13 +203,6 @@ class LanceDBConnection:
if not isinstance(uri, Path):
scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file"
# managed lancedb remote uses schema like lancedb+[http|grpc|...]://
self._is_managed_remote = not is_local and scheme.startswith("lancedb")
if self._is_managed_remote:
if len(scheme.split("+")) != 2:
raise ValueError(
f"Invalid LanceDB URI: {uri}, expected uri to have scheme like lancedb+<flavor>://..."
)
if is_local:
if isinstance(uri, str):
uri = Path(uri)
@@ -79,43 +216,6 @@ class LanceDBConnection:
def uri(self) -> str:
return self._uri
@functools.cached_property
def is_managed_remote(self) -> bool:
return self._is_managed_remote
@functools.cached_property
def remote_flavor(self) -> str:
if not self.is_managed_remote:
raise ValueError(
"Not a managed remote LanceDB, there should be no server flavor"
)
return get_uri_scheme(self.uri).split("+")[1]
@functools.cached_property
def _client(self) -> "lancedb.remote.LanceDBClient":
if not self.is_managed_remote:
raise ValueError("Not a managed remote LanceDB, there should be no client")
# don't import unless we are really using remote
from lancedb.remote.client import RestfulLanceDBClient
if self.remote_flavor == "http":
return RestfulLanceDBClient(self._uri)
raise ValueError("Unsupported remote flavor: " + self.remote_flavor)
async def close(self):
if self._entered:
raise ValueError("Cannot re-enter the same LanceDBConnection twice")
self._entered = True
await self._client.close()
async def __aenter__(self) -> LanceDBConnection:
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()
def table_names(self) -> list[str]:
"""Get the names of all tables in the database.
@@ -149,16 +249,13 @@ class LanceDBConnection:
def __contains__(self, name: str) -> bool:
return name in self.table_names()
def __getitem__(self, name: str) -> LanceTable:
return self.open_table(name)
def create_table(
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> LanceTable:
"""Create a table in the database.
@@ -175,9 +272,9 @@ class LanceDBConnection:
The mode to use when creating the table. Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
One of "error", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".

View File

@@ -10,18 +10,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
from typing import Awaitable, Literal
from typing import List, Literal, Optional, Union
import numpy as np
import pandas as pd
import pyarrow as pa
from pydantic import BaseModel
from .common import VECTOR_COLUMN_NAME
class Query(BaseModel):
"""A Query"""
vector_column: str = VECTOR_COLUMN_NAME
# vector to search for
vector: List[float]
# sql filter to refine the query with
filter: Optional[str] = None
# top k results to return
k: int
# # metrics
metric: str = "L2"
# which columns to return in the results
columns: Optional[List[str]] = None
# optional query parameters for tuning the results,
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
nprobes: int = 10
# Refine factor.
refine_factor: Optional[int] = None
class LanceQueryBuilder:
"""
A builder for nearest neighbor queries for LanceDB.
@@ -47,9 +76,9 @@ class LanceQueryBuilder:
def __init__(
self,
table: "lancedb.table.LanceTable",
query: np.ndarray,
vector_column_name: str = VECTOR_COLUMN_NAME,
table: "lancedb.table.Table",
query: Union[np.ndarray, str],
vector_column: str = VECTOR_COLUMN_NAME,
):
self._metric = "L2"
self._nprobes = 20
@@ -59,7 +88,7 @@ class LanceQueryBuilder:
self._limit = 10
self._columns = None
self._where = None
self._vector_column_name = vector_column_name
self._vector_column = vector_column
def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
@@ -181,52 +210,28 @@ class LanceQueryBuilder:
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as a arrow Table.
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
In addition to the selected columns, LanceDB also returns a vector
and also the "score" column which is the distance between the query
vector and the returned vector.
vector and the returned vectors.
"""
if self._table._conn.is_managed_remote:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()
result = self._table._conn._client.query(
self._table.name, self.to_remote_query()
)
return loop.run_until_complete(result).to_arrow()
ds = self._table.to_lance()
return ds.to_table(
columns=self._columns,
filter=self._where,
nearest={
"column": self._vector_column_name,
"q": self._query,
"k": self._limit,
"metric": self._metric,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
},
)
def to_remote_query(self) -> "VectorQuery":
# don't import unless we are connecting to remote
from lancedb.remote.client import VectorQuery
return VectorQuery(
vector=self._query.tolist(),
vector = self._query if isinstance(self._query, list) else self._query.tolist()
query = Query(
vector=vector,
filter=self._where,
k=self._limit,
_metric=self._metric,
metric=self._metric,
columns=self._columns,
nprobes=self._nprobes,
refine_factor=self._refine_factor,
)
return self._table._execute_query(query)
class LanceFtsQueryBuilder(LanceQueryBuilder):
def to_df(self) -> pd.DataFrame:
def to_arrow(self) -> pd.Table:
try:
import tantivy
except ImportError:
@@ -243,8 +248,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
# get the scores and doc ids
row_ids, scores = search_index(index, self._query, self._limit)
if len(row_ids) == 0:
return pd.DataFrame()
empty_schema = pa.schema([pa.field("score", pa.float32())])
return pa.Table.from_pylist([], schema=empty_schema)
scores = pa.array(scores)
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores)
return output_tbl.to_pandas()
return output_tbl

View File

@@ -15,7 +15,6 @@ import abc
from typing import List, Optional
import attr
import pandas as pd
import pyarrow as pa
from pydantic import BaseModel

View File

@@ -13,12 +13,13 @@
import functools
import urllib.parse
from typing import Dict
import aiohttp
import attr
import pyarrow as pa
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
from lancedb.remote.errors import LanceDBClientError
@@ -35,29 +36,32 @@ def _check_not_closed(f):
@attr.define(slots=False)
class RestfulLanceDBClient:
url: str
db_name: str
region: str
api_key: Credential
closed: bool = attr.field(default=False, init=False)
@functools.cached_property
def session(self) -> aiohttp.ClientSession:
parsed = urllib.parse.urlparse(self.url)
scheme = parsed.scheme
if not scheme.startswith("lancedb"):
raise ValueError(
f"Invalid scheme: {scheme}, must be like lancedb+<flavor>://"
)
flavor = scheme.split("+")[1]
url = f"{flavor}://{parsed.hostname}:{parsed.port}"
url = f"https://{self.db_name}.{self.region}.api.lancedb.com"
return aiohttp.ClientSession(url)
async def close(self):
await self.session.close()
self.closed = True
@functools.cached_property
def headers(self) -> Dict[str, str]:
return {
"x-api-key": self.api_key,
}
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
async with self.session.post(
f"/table/{table_name}/", json=query.dict(exclude_none=True)
f"/1/table/{table_name}/",
json=query.dict(exclude_none=True),
headers=self.headers,
) as resp:
resp: aiohttp.ClientResponse = resp
if 400 <= resp.status < 500:

View File

@@ -0,0 +1,71 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from urllib.parse import urlparse
import pyarrow as pa
from lancedb.common import DATA
from lancedb.db import DBConnection
from lancedb.table import Table
from .client import RestfulLanceDBClient
class RemoteDBConnection(DBConnection):
"""A connection to a remote LanceDB database."""
def __init__(self, db_url: str, api_key: str, region: str):
"""Connect to a remote LanceDB database."""
parsed = urlparse(db_url)
if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(self.db_name, region, api_key)
def __repr__(self) -> str:
return f"RemoveConnect(name={self.db_name})"
def table_names(self) -> List[str]:
raise NotImplementedError
def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
Returns
-------
A LanceTable object representing the table.
"""
from .table import RemoteTable
# TODO: check if table exists
return RemoteTable(self, name)
def create_table(
self,
name: str,
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> Table:
raise NotImplementedError

View File

@@ -0,0 +1,70 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Union
import pyarrow as pa
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceQueryBuilder, Query
from ..table import Query, Table
from .db import RemoteDBConnection
class RemoteTable(Table):
def __init__(self, conn: RemoteDBConnection, name: str):
self._conn = conn
self._name = name
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self.name})"
def schema(self) -> pa.Schema:
raise NotImplementedError
def to_arrow(self) -> pa.Table:
raise NotImplementedError
def create_index(
self,
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
):
raise NotImplementedError
def add(
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
raise NotImplementedError
def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
) -> LanceQueryBuilder:
return LanceQueryBuilder(self, query, vector_column)
def _execute_query(self, query: Query) -> pa.Table:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()
result = self._conn._client.query(self._name, query)
return loop.run_until_complete(result).to_arrow()

View File

@@ -14,19 +14,21 @@
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, List, Union
from typing import List, Union
import lance
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs
from lance import LanceDataset
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .query import LanceFtsQueryBuilder, LanceQueryBuilder
from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query
def _sanitize_data(data, schema, on_bad_vectors, fill_value):
@@ -47,14 +49,14 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value):
return data
class LanceTable:
class Table(ABC):
"""
A table in a LanceDB database.
A [Table](Table) is a collection of Records in a LanceDB [Database](Database).
Examples
--------
Create using [LanceDBConnection.create_table][lancedb.LanceDBConnection.create_table]
Create using [DBConnection.create_table][lancedb.DBConnection.create_table]
(more examples in that method's documentation).
>>> import lancedb
@@ -69,12 +71,12 @@ class LanceTable:
vector: [[[1.1,1.2]]]
b: [[2]]
Can append new data with [LanceTable.add][lancedb.table.LanceTable.add].
Can append new data with [Table.add()][lancedb.table.Table.add].
>>> table.add([{"vector": [0.5, 1.3], "b": 4}])
2
Can query the table with [LanceTable.search][lancedb.table.LanceTable.search].
Can query the table with [Table.search][lancedb.table.Table.search].
>>> table.search([0.4, 0.4]).select(["b"]).to_df()
b vector score
@@ -82,8 +84,128 @@ class LanceTable:
1 2 [1.1, 1.2] 1.13
Search queries are much faster when an index is created. See
[LanceTable.create_index][lancedb.table.LanceTable.create_index].
[Table.create_index][lancedb.table.Table.create_index].
"""
@abstractmethod
def schema(self) -> pa.Schema:
"""Return the [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of
this [Table](Table)
"""
raise NotImplementedError
def to_pandas(self) -> pd.DataFrame:
"""Return the table as a pandas DataFrame.
Returns
-------
pd.DataFrame
"""
return self.to_arrow().to_pandas()
@abstractmethod
def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table.
Returns
-------
pa.Table
"""
raise NotImplementedError
def create_index(
self,
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
):
"""Create an index on the table.
Parameters
----------
metric: str, default "L2"
The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
L2 is euclidean distance.
num_partitions: int
The number of IVF partitions to use when creating the index.
Default is 256.
num_sub_vectors: int
The number of PQ sub-vectors to use when creating the index.
Default is 96.
vector_column_name: str, default "vector"
The vector column name to create the index.
replace: bool, default True
If True, replace the existing index if it exists.
If False, raise an error if duplicate index exists.
"""
raise NotImplementedError
@abstractmethod
def add(
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
"""Add more data to the [Table](Table).
Parameters
----------
data: list-of-dict, dict, pd.DataFrame
The data to insert into the table.
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
Returns
-------
int
The number of vectors in the table.
"""
raise NotImplementedError
@abstractmethod
def search(
self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
Parameters
----------
query: list, np.ndarray
The query vector.
vector_column: str, default "vector"
The name of the vector column to search.
Returns
-------
LanceQueryBuilder
A query builder object representing the query.
Once executed, the query returns selected columns, the vector,
and also the "score" column which is the distance between the query
vector and the returned vector.
"""
raise NotImplementedError
@abstractmethod
def _execute_query(self, query: Query) -> pa.Table:
pass
class LanceTable(Table):
"""
A table in a LanceDB database.
"""
def __init__(
@@ -95,7 +217,8 @@ class LanceTable:
def _reset_dataset(self):
try:
del self.__dict__["_dataset"]
if "_dataset" in self.__dict__:
del self.__dict__["_dataset"]
except AttributeError:
pass
@@ -195,26 +318,7 @@ class LanceTable:
vector_column_name=VECTOR_COLUMN_NAME,
replace: bool = True,
):
"""Create an index on the table.
Parameters
----------
metric: str, default "L2"
The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
L2 is euclidean distance.
num_partitions: int
The number of IVF partitions to use when creating the index.
Default is 256.
num_sub_vectors: int
The number of PQ sub-vectors to use when creating the index.
Default is 96.
vector_column_name: str, default "vector"
The vector column name to create the index.
replace: bool, default True
If True, replace the existing index if it exists.
If False, raise an error if duplicate index exists.
"""
"""Create an index on the table."""
self._dataset.create_index(
column=vector_column_name,
index_type="IVF_PQ",
@@ -258,7 +362,7 @@ class LanceTable:
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
"""Add data to the table.
@@ -270,9 +374,9 @@ class LanceTable:
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
@@ -281,6 +385,7 @@ class LanceTable:
int
The number of vectors in the table.
"""
# TODO: manage table listing and metadata separately
data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
@@ -326,10 +431,10 @@ class LanceTable:
cls,
db,
name,
data,
data=None,
schema=None,
mode="create",
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
"""
@@ -354,37 +459,40 @@ class LanceTable:
The LanceDB instance to create the table in.
name: str
The name of the table to create.
data: list-of-dict, dict, pd.DataFrame
data: list-of-dict, dict, pd.DataFrame, default None
The data to insert into the table.
At least one of `data` or `schema` must be provided.
schema: dict, optional
The schema of the table. If not provided, the schema is inferred from the data.
At least one of `data` or `schema` must be provided.
mode: str, default "create"
The mode to use when writing the data. Valid values are
"create", "overwrite", and "append".
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
tbl = LanceTable(db, name)
data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
if data is not None:
data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
else:
if schema is None:
raise ValueError("Either data or schema must be provided")
data = pa.Table.from_pylist([], schema=schema)
lance.write_dataset(data, tbl._dataset_uri, mode=mode)
return tbl
return LanceTable(db, name)
@classmethod
def open(cls, db, name):
tbl = cls(db, name)
if tbl._conn.is_managed_remote:
# Not completely sure how to check for remote table existence yet.
return tbl
if not os.path.exists(tbl._dataset_uri):
raise FileNotFoundError(
f"Table {name} does not exist. Please first call db.create_table({name}, data)"
)
return tbl
def delete(self, where: str):
@@ -415,11 +523,26 @@ class LanceTable:
"""
self._dataset.delete(where)
def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance()
return ds.to_table(
columns=query.columns,
filter=query.filter,
nearest={
"column": query.vector_column,
"q": query.vector,
"k": query.k,
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
)
def _sanitize_schema(
data: pa.Table,
schema: pa.Schema = None,
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> pa.Table:
"""Ensure that the table has the expected schema.
@@ -431,10 +554,10 @@ def _sanitize_schema(
schema: pa.Schema; optional
The expected schema. If not provided, this just converts the
vector column to fixed_size_list(float32) if necessary.
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if schema is not None:
@@ -463,7 +586,7 @@ def _sanitize_schema(
def _sanitize_vector_column(
data: pa.Table,
vector_column_name: str,
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> pa.Table:
"""
@@ -475,10 +598,10 @@ def _sanitize_vector_column(
The table to sanitize.
vector_column_name: str
The name of the vector column.
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float
One of "error", "drop", "fill".
fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if vector_column_name not in data.column_names:
@@ -501,7 +624,7 @@ def _sanitize_vector_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
has_nans = pc.any(vec_arr.values.is_nan()).as_py()
has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py()
if has_nans:
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
@@ -524,7 +647,7 @@ def ensure_fixed_size_list_of_f32(vec_arr):
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
"""Sanitize jagged vectors."""
if on_bad_vectors == "raise":
if on_bad_vectors == "error":
raise ValueError(
f"Vector column {vector_column_name} has variable length vectors "
"Set on_bad_vectors='drop' to remove them, or "
@@ -538,7 +661,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
if on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
fill_arr = pa.scalar([float(fill_value)] * ndims)
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
@@ -552,7 +675,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
"""Sanitize NaNs in vectors"""
if on_bad_vectors == "raise":
if on_bad_vectors == "error":
raise ValueError(
f"Vector column {vector_column_name} has NaNs. "
"Set on_bad_vectors='drop' to remove them, or "
@@ -561,10 +684,10 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
elif on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
fill_value = float(fill_value)
values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values)
values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
ndims = len(vec_arr[0])
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
data = data.set_column(

View File

@@ -11,9 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from urllib.parse import ParseResult, urlparse
from pyarrow import fs
from urllib.parse import urlparse
def get_uri_scheme(uri: str) -> str:

View File

@@ -1,6 +1,6 @@
[project]
name = "lancedb"
version = "0.1.9"
version = "0.1.10"
dependencies = ["pylance~=0.5.0", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic", "attr"]
description = "lancedb"
authors = [

View File

@@ -20,18 +20,33 @@ import pyarrow as pa
import pytest
from lancedb.db import LanceDBConnection
from lancedb.query import LanceQueryBuilder
from lancedb.query import LanceQueryBuilder, Query
from lancedb.table import LanceTable
class MockTable:
def __init__(self, tmp_path):
self.uri = tmp_path
self._conn = LanceDBConnection("/tmp/lance/")
self._conn = LanceDBConnection(self.uri)
def to_lance(self):
return lance.dataset(self.uri)
def _execute_query(self, query):
ds = self.to_lance()
return ds.to_table(
columns=query.columns,
filter=query.filter,
nearest={
"column": query.vector_column,
"q": query.vector,
"k": query.k,
"metric": query.metric,
"nprobes": query.nprobes,
"refine_factor": query.refine_factor,
},
)
@pytest.fixture
def table(tmp_path) -> MockTable:
@@ -94,20 +109,17 @@ def test_query_builder_with_different_vector_column():
)
ds = mock.Mock()
table.to_lance.return_value = ds
table._conn = mock.MagicMock()
table._conn.is_managed_remote = False
builder.to_arrow()
ds.to_table.assert_called_once_with(
columns=["b"],
filter="b < 10",
nearest={
"column": vector_column_name,
"q": query,
"k": 2,
"metric": "cosine",
"nprobes": 20,
"refine_factor": None,
},
table._execute_query.assert_called_once_with(
Query(
vector=query,
filter="b < 10",
k=2,
metric="cosine",
columns=["b"],
nprobes=20,
refine_factor=None,
)
)

View File

@@ -13,7 +13,7 @@
import pyarrow as pa
from lancedb.db import LanceDBConnection
import lancedb
from lancedb.remote.client import VectorQuery, VectorQueryResult
@@ -28,7 +28,7 @@ class FakeLanceDBClient:
def test_remote_db():
conn = LanceDBConnection("lancedb+http://client-will-be-injected")
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]

View File

@@ -19,6 +19,7 @@ import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lance.vector import vec_to_table
from lancedb.db import LanceDBConnection
from lancedb.table import LanceTable
@@ -89,7 +90,31 @@ def test_create_table(db):
assert expected == tbl
def test_empty_table(db):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float32()),
]
)
tbl = LanceTable.create(db, "test", schema=schema)
data = [
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
]
tbl.add(data=data)
def test_add(db):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float64()),
]
)
table = LanceTable.create(
db,
"test",
@@ -98,7 +123,19 @@ def test_add(db):
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
_add(table, schema)
table = LanceTable.create(db, "test2", schema=schema)
table.add(
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
_add(table, schema)
def _add(table, schema):
# table = LanceTable(db, "test")
assert len(table) == 2
@@ -113,13 +150,7 @@ def test_add(db):
pa.array(["foo", "bar", "new"]),
pa.array([10.0, 20.0, 30.0]),
],
schema=pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 2)),
pa.field("item", pa.string()),
pa.field("price", pa.float64()),
]
),
schema=schema,
)
assert expected == table.to_arrow()
@@ -181,7 +212,21 @@ def test_create_index_method():
def test_add_with_nans(db):
# By default we drop bad input vectors
# by default we raise an error on bad input vectors
bad_data = [
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
]
for row in bad_data:
with pytest.raises(ValueError):
LanceTable.create(
db,
"error_test",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
)
table = LanceTable.create(
db,
"drop_test",
@@ -191,6 +236,7 @@ def test_add_with_nans(db):
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
],
on_bad_vectors="drop",
)
assert len(table) == 1
@@ -210,18 +256,3 @@ def test_add_with_nans(db):
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
bad_data = [
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
]
for row in bad_data:
with pytest.raises(ValueError):
LanceTable.create(
db,
"raise_test",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
on_bad_vectors="raise",
)