mirror of https://github.com/lancedb/lancedb.git (synced 2026-01-08 12:52:58 +00:00)

Compare commits: v0.1.10 ... v0.1.10-py (5 commits)

Commits in this range (author and date not shown in this view):
  97364a2514
  e6c6da6104
  a5eb665b7d
  e2325c634b
  507eeae9c8

.github/workflows/docs.yml (vendored), 22 lines changed
@@ -39,6 +39,28 @@ jobs:
       run: |
         python -m pip install -e .
         python -m pip install -r ../docs/requirements.txt
+    - name: Set up node
+      uses: actions/setup-node@v3
+      with:
+        node-version: ${{ matrix.node-version }}
+        cache: 'npm'
+        cache-dependency-path: node/package-lock.json
+    - uses: Swatinem/rust-cache@v2
+    - name: Install node dependencies
+      working-directory: node
+      run: |
+        sudo apt update
+        sudo apt install -y protobuf-compiler libssl-dev
+    - name: Build node
+      working-directory: node
+      run: |
+        npm ci
+        npm run build
+        npm run tsc
+    - name: Create markdown files
+      working-directory: node
+      run: |
+        npx typedoc --plugin typedoc-plugin-markdown --out ../docs/src/javascript src/index.ts
     - name: Build docs
       run: |
         PYTHONPATH=. mkdocs build -f docs/mkdocs.yml

.github/workflows/python.yml (vendored), 4 lines changed
@@ -61,6 +61,8 @@ jobs:
       run: |
         pip install -e .
         pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
-        pip install pytest pytest-mock
+        pip install pytest pytest-mock black
+    - name: Black
+      run: black --check --diff --no-color --quiet .
     - name: Run tests
       run: pytest -x -v --durations=30 tests

@@ -10,14 +10,16 @@ pip install lancedb

 ::: lancedb.connect

-::: lancedb.LanceDBConnection
+::: lancedb.db.DBConnection

 ## Table

-::: lancedb.table.LanceTable
+::: lancedb.table.Table

 ## Querying

+::: lancedb.query.Query
+
 ::: lancedb.query.LanceQueryBuilder

 ::: lancedb.query.LanceFtsQueryBuilder

@@ -11,16 +11,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .db import URI, LanceDBConnection
+from typing import Optional
+
+from .db import URI, DBConnection, LanceDBConnection
+from .remote.db import RemoteDBConnection


-def connect(uri: URI) -> LanceDBConnection:
-    """Connect to a LanceDB instance at the given URI
+def connect(
+    uri: URI, *, api_key: Optional[str] = None, region: str = "us-west-2"
+) -> DBConnection:
+    """Connect to a LanceDB database.

     Parameters
     ----------
     uri: str or Path
         The uri of the database.
+    api_token: str, optional
+        If presented, connect to LanceDB cloud.
+        Otherwise, connect to a database on file system or cloud storage.

     Examples
     --------
@@ -34,9 +42,17 @@ def connect(uri: URI) -> LanceDBConnection:

     >>> db = lancedb.connect("s3://my-bucket/lancedb")

+    Connect to LancdDB cloud:
+
+    >>> db = lancedb.connect("db://my_database", api_key="ldb_...")
+
     Returns
     -------
     conn : DBConnection
         A connection to a LanceDB database.
     """
+    if isinstance(uri, str) and uri.startswith("db://"):
+        if api_key is None:
+            raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
+        return RemoteDBConnection(uri, api_key, region)
     return LanceDBConnection(uri)

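Editor's note: as a usage sketch (not part of the diff; the database name and API key below are placeholders), the reworked connect() routes on the URI scheme. Local and object-store URIs still return a LanceDBConnection, while "db://" URIs require an api_key and return a RemoteDBConnection:

import lancedb

# Local path (or s3:// etc.) -> LanceDBConnection, unchanged behavior.
db = lancedb.connect("./.lancedb")

# "db://" without an api_key now fails fast instead of being treated as a path.
try:
    lancedb.connect("db://my_database")
except ValueError as err:
    print(err)

# With a key, a RemoteDBConnection is returned; no network call happens yet.
remote = lancedb.connect("db://my_database", api_key="ldb_placeholder", region="us-west-2")
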
@@ -23,3 +23,13 @@ URI = Union[str, Path]
 # TODO support generator
 DATA = Union[List[dict], dict, pd.DataFrame]
 VECTOR_COLUMN_NAME = "vector"
+
+
+class Credential(str):
+    """Credential field"""
+
+    def __repr__(self) -> str:
+        return "********"
+
+    def __str__(self) -> str:
+        return "********"

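Editor's note: the point of the new Credential type is that the API key still behaves as its raw string value while rendering masked. A standalone sketch (the class body mirrors the one added above; the key value is a placeholder):

class Credential(str):
    """Mirror of the Credential class added above."""

    def __repr__(self) -> str:
        return "********"

    def __str__(self) -> str:
        return "********"


key = Credential("ldb_placeholder_key")
print(repr(key))                      # ******** (masked in logs and debuggers)
print(key == "ldb_placeholder_key")   # True (comparisons use the real value)
print("x-api-key: " + key)            # concatenation still yields the real value
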
@@ -1,10 +1,8 @@
-import builtins
 import os

 import pytest

 # import lancedb so we don't have to in every example
-import lancedb


 @pytest.fixture(autouse=True)

@@ -15,17 +15,161 @@ from __future__ import annotations

 import functools
 import os
+from abc import ABC, abstractmethod
 from pathlib import Path

 import pyarrow as pa
 from pyarrow import fs

 from .common import DATA, URI
-from .table import LanceTable
+from .table import LanceTable, Table
 from .util import get_uri_location, get_uri_scheme


-class LanceDBConnection:
+class DBConnection(ABC):
+    """An active LanceDB connection interface."""
+
+    @abstractmethod
+    def table_names(self) -> list[str]:
+        """List all table names in the database."""
+        pass
+
+    @abstractmethod
+    def create_table(
+        self,
+        name: str,
+        data: DATA = None,
+        schema: pa.Schema = None,
+        mode: str = "create",
+        on_bad_vectors: str = "error",
+        fill_value: float = 0.0,
+    ) -> Table:
+        """Create a [Table][lancedb.table.Table] in the database.
+
+        Parameters
+        ----------
+        name: str
+            The name of the table.
+        data: list, tuple, dict, pd.DataFrame; optional
+            The data to insert into the table.
+        schema: pyarrow.Schema; optional
+            The schema of the table.
+        mode: str; default "create"
+            The mode to use when creating the table. Can be either "create" or "overwrite".
+            By default, if the table already exists, an exception is raised.
+            If you want to overwrite the table, use mode="overwrite".
+        on_bad_vectors: str, default "error"
+            What to do if any of the vectors are not the same size or contains NaNs.
+            One of "error", "drop", "fill".
+        fill_value: float
+            The value to use when filling vectors. Only used if on_bad_vectors="fill".
+
+        Note
+        ----
+        The vector index won't be created by default.
+        To create the index, call the `create_index` method on the table.
+
+        Returns
+        -------
+        LanceTable
+            A reference to the newly created table.
+
+        Examples
+        --------
+
+        Can create with list of tuples or dictionaries:
+
+        >>> import lancedb
+        >>> db = lancedb.connect("./.lancedb")
+        >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
+        ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
+        >>> db.create_table("my_table", data)
+        LanceTable(my_table)
+        >>> db["my_table"].head()
+        pyarrow.Table
+        vector: fixed_size_list<item: float>[2]
+          child 0, item: float
+        lat: double
+        long: double
+        ----
+        vector: [[[1.1,1.2],[0.2,1.8]]]
+        lat: [[45.5,40.1]]
+        long: [[-122.7,-74.1]]
+
+        You can also pass a pandas DataFrame:
+
+        >>> import pandas as pd
+        >>> data = pd.DataFrame({
+        ...    "vector": [[1.1, 1.2], [0.2, 1.8]],
+        ...    "lat": [45.5, 40.1],
+        ...    "long": [-122.7, -74.1]
+        ... })
+        >>> db.create_table("table2", data)
+        LanceTable(table2)
+        >>> db["table2"].head()
+        pyarrow.Table
+        vector: fixed_size_list<item: float>[2]
+          child 0, item: float
+        lat: double
+        long: double
+        ----
+        vector: [[[1.1,1.2],[0.2,1.8]]]
+        lat: [[45.5,40.1]]
+        long: [[-122.7,-74.1]]
+
+        Data is converted to Arrow before being written to disk. For maximum
+        control over how data is saved, either provide the PyArrow schema to
+        convert to or else provide a PyArrow table directly.
+
+        >>> custom_schema = pa.schema([
+        ...   pa.field("vector", pa.list_(pa.float32(), 2)),
+        ...   pa.field("lat", pa.float32()),
+        ...   pa.field("long", pa.float32())
+        ... ])
+        >>> db.create_table("table3", data, schema = custom_schema)
+        LanceTable(table3)
+        >>> db["table3"].head()
+        pyarrow.Table
+        vector: fixed_size_list<item: float>[2]
+          child 0, item: float
+        lat: float
+        long: float
+        ----
+        vector: [[[1.1,1.2],[0.2,1.8]]]
+        lat: [[45.5,40.1]]
+        long: [[-122.7,-74.1]]
+        """
+        raise NotImplementedError
+
+    def __getitem__(self, name: str) -> LanceTable:
+        return self.open_table(name)
+
+    def open_table(self, name: str) -> Table:
+        """Open a Lance Table in the database.
+
+        Parameters
+        ----------
+        name: str
+            The name of the table.
+
+        Returns
+        -------
+        A LanceTable object representing the table.
+        """
+        raise NotImplementedError
+
+    def drop_table(self, name: str):
+        """Drop a table from the database.
+
+        Parameters
+        ----------
+        name: str
+            The name of the table.
+        """
+        raise NotImplementedError
+
+
+class LanceDBConnection(DBConnection):
     """
     A connection to a LanceDB database.

@@ -59,13 +203,6 @@ class LanceDBConnection:
         if not isinstance(uri, Path):
             scheme = get_uri_scheme(uri)
         is_local = isinstance(uri, Path) or scheme == "file"
-        # managed lancedb remote uses schema like lancedb+[http|grpc|...]://
-        self._is_managed_remote = not is_local and scheme.startswith("lancedb")
-        if self._is_managed_remote:
-            if len(scheme.split("+")) != 2:
-                raise ValueError(
-                    f"Invalid LanceDB URI: {uri}, expected uri to have scheme like lancedb+<flavor>://..."
-                )
         if is_local:
             if isinstance(uri, str):
                 uri = Path(uri)
@@ -79,43 +216,6 @@ class LanceDBConnection:
     def uri(self) -> str:
         return self._uri

-    @functools.cached_property
-    def is_managed_remote(self) -> bool:
-        return self._is_managed_remote
-
-    @functools.cached_property
-    def remote_flavor(self) -> str:
-        if not self.is_managed_remote:
-            raise ValueError(
-                "Not a managed remote LanceDB, there should be no server flavor"
-            )
-        return get_uri_scheme(self.uri).split("+")[1]
-
-    @functools.cached_property
-    def _client(self) -> "lancedb.remote.LanceDBClient":
-        if not self.is_managed_remote:
-            raise ValueError("Not a managed remote LanceDB, there should be no client")
-
-        # don't import unless we are really using remote
-        from lancedb.remote.client import RestfulLanceDBClient
-
-        if self.remote_flavor == "http":
-            return RestfulLanceDBClient(self._uri)
-
-        raise ValueError("Unsupported remote flavor: " + self.remote_flavor)
-
-    async def close(self):
-        if self._entered:
-            raise ValueError("Cannot re-enter the same LanceDBConnection twice")
-        self._entered = True
-        await self._client.close()
-
-    async def __aenter__(self) -> LanceDBConnection:
-        return self
-
-    async def __aexit__(self, exc_type, exc_value, traceback):
-        await self.close()
-
     def table_names(self) -> list[str]:
         """Get the names of all tables in the database.

@@ -149,16 +249,13 @@ class LanceDBConnection:
     def __contains__(self, name: str) -> bool:
         return name in self.table_names()

-    def __getitem__(self, name: str) -> LanceTable:
-        return self.open_table(name)
-
     def create_table(
         self,
         name: str,
         data: DATA = None,
         schema: pa.Schema = None,
         mode: str = "create",
-        on_bad_vectors: str = "drop",
+        on_bad_vectors: str = "error",
         fill_value: float = 0.0,
     ) -> LanceTable:
         """Create a table in the database.
@@ -175,9 +272,9 @@ class LanceDBConnection:
             The mode to use when creating the table. Can be either "create" or "overwrite".
             By default, if the table already exists, an exception is raised.
             If you want to overwrite the table, use mode="overwrite".
-        on_bad_vectors: str
+        on_bad_vectors: str, default "error"
             What to do if any of the vectors are not the same size or contains NaNs.
-            One of "raise", "drop", "fill".
+            One of "error", "drop", "fill".
         fill_value: float
             The value to use when filling vectors. Only used if on_bad_vectors="fill".

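Editor's note: a sketch of what the new abstraction buys downstream code (not part of the commit): callers can be typed against DBConnection and work with either backend. The helper name is made up for illustration, and in this commit RemoteDBConnection still raises NotImplementedError for table_names()/create_table(), so only the local path is fully usable:

from lancedb.db import DBConnection
from lancedb.table import Table


def ensure_table(db: DBConnection, name: str, data) -> Table:
    # Both LanceDBConnection and RemoteDBConnection expose this interface.
    if name in db.table_names():
        return db.open_table(name)
    return db.create_table(name, data)
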
@@ -10,18 +10,47 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 from __future__ import annotations

-import asyncio
-from typing import Awaitable, Literal
+from typing import List, Literal, Optional, Union

 import numpy as np
 import pandas as pd
 import pyarrow as pa
+from pydantic import BaseModel

 from .common import VECTOR_COLUMN_NAME


+class Query(BaseModel):
+    """A Query"""
+
+    vector_column: str = VECTOR_COLUMN_NAME
+
+    # vector to search for
+    vector: List[float]
+
+    # sql filter to refine the query with
+    filter: Optional[str] = None
+
+    # top k results to return
+    k: int
+
+    # # metrics
+    metric: str = "L2"
+
+    # which columns to return in the results
+    columns: Optional[List[str]] = None
+
+    # optional query parameters for tuning the results,
+    # e.g. `{"nprobes": "10", "refine_factor": "10"}`
+    nprobes: int = 10
+
+    # Refine factor.
+    refine_factor: Optional[int] = None
+
+
 class LanceQueryBuilder:
     """
     A builder for nearest neighbor queries for LanceDB.
@@ -47,9 +76,9 @@ class LanceQueryBuilder:

     def __init__(
         self,
-        table: "lancedb.table.LanceTable",
-        query: np.ndarray,
-        vector_column_name: str = VECTOR_COLUMN_NAME,
+        table: "lancedb.table.Table",
+        query: Union[np.ndarray, str],
+        vector_column: str = VECTOR_COLUMN_NAME,
     ):
         self._metric = "L2"
         self._nprobes = 20
@@ -59,7 +88,7 @@ class LanceQueryBuilder:
         self._limit = 10
         self._columns = None
         self._where = None
-        self._vector_column_name = vector_column_name
+        self._vector_column = vector_column

     def limit(self, limit: int) -> LanceQueryBuilder:
         """Set the maximum number of results to return.
@@ -181,52 +210,28 @@ class LanceQueryBuilder:

     def to_arrow(self) -> pa.Table:
         """
-        Execute the query and return the results as a arrow Table.
+        Execute the query and return the results as an
+        [Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).

         In addition to the selected columns, LanceDB also returns a vector
         and also the "score" column which is the distance between the query
-        vector and the returned vector.
+        vector and the returned vectors.
         """
-        if self._table._conn.is_managed_remote:
-            try:
-                loop = asyncio.get_running_loop()
-            except RuntimeError:
-                loop = asyncio.get_event_loop()
-            result = self._table._conn._client.query(
-                self._table.name, self.to_remote_query()
-            )
-            return loop.run_until_complete(result).to_arrow()
-
-        ds = self._table.to_lance()
-        return ds.to_table(
-            columns=self._columns,
-            filter=self._where,
-            nearest={
-                "column": self._vector_column_name,
-                "q": self._query,
-                "k": self._limit,
-                "metric": self._metric,
-                "nprobes": self._nprobes,
-                "refine_factor": self._refine_factor,
-            },
-        )
-
-    def to_remote_query(self) -> "VectorQuery":
-        # don't import unless we are connecting to remote
-        from lancedb.remote.client import VectorQuery
-
-        return VectorQuery(
-            vector=self._query.tolist(),
+        vector = self._query if isinstance(self._query, list) else self._query.tolist()
+        query = Query(
+            vector=vector,
             filter=self._where,
             k=self._limit,
-            _metric=self._metric,
+            metric=self._metric,
            columns=self._columns,
            nprobes=self._nprobes,
            refine_factor=self._refine_factor,
        )
+        return self._table._execute_query(query)


 class LanceFtsQueryBuilder(LanceQueryBuilder):
-    def to_df(self) -> pd.DataFrame:
+    def to_arrow(self) -> pd.Table:
         try:
             import tantivy
         except ImportError:
@@ -243,8 +248,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
         # get the scores and doc ids
         row_ids, scores = search_index(index, self._query, self._limit)
         if len(row_ids) == 0:
-            return pd.DataFrame()
+            empty_schema = pa.schema([pa.field("score", pa.float32())])
+            return pa.Table.from_pylist([], schema=empty_schema)
         scores = pa.array(scores)
         output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
         output_tbl = output_tbl.append_column("score", scores)
-        return output_tbl.to_pandas()
+        return output_tbl

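Editor's note: a small sketch of the new Query model on its own (pydantic v1-style API, matching the .dict() call the remote client makes; the field values are illustrative). The query builder now just packs its state into this model and hands it to the table's _execute_query:

from lancedb.query import Query

q = Query(vector=[0.4, 0.4], k=10, filter="b < 10", columns=["b"])
print(q.metric, q.nprobes)         # model defaults: "L2" 10
print(q.dict(exclude_none=True))   # the payload shape sent to the remote API
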
@@ -15,7 +15,6 @@ import abc
 from typing import List, Optional

 import attr
-import pandas as pd
 import pyarrow as pa
 from pydantic import BaseModel

@@ -13,12 +13,13 @@


 import functools
-import urllib.parse
+from typing import Dict

 import aiohttp
 import attr
 import pyarrow as pa

+from lancedb.common import Credential
 from lancedb.remote import VectorQuery, VectorQueryResult
 from lancedb.remote.errors import LanceDBClientError

@@ -35,29 +36,32 @@ def _check_not_closed(f):

 @attr.define(slots=False)
 class RestfulLanceDBClient:
-    url: str
+    db_name: str
+    region: str
+    api_key: Credential
     closed: bool = attr.field(default=False, init=False)

     @functools.cached_property
     def session(self) -> aiohttp.ClientSession:
-        parsed = urllib.parse.urlparse(self.url)
-        scheme = parsed.scheme
-        if not scheme.startswith("lancedb"):
-            raise ValueError(
-                f"Invalid scheme: {scheme}, must be like lancedb+<flavor>://"
-            )
-        flavor = scheme.split("+")[1]
-        url = f"{flavor}://{parsed.hostname}:{parsed.port}"
+        url = f"https://{self.db_name}.{self.region}.api.lancedb.com"
         return aiohttp.ClientSession(url)

     async def close(self):
         await self.session.close()
         self.closed = True

+    @functools.cached_property
+    def headers(self) -> Dict[str, str]:
+        return {
+            "x-api-key": self.api_key,
+        }
+
     @_check_not_closed
     async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
         async with self.session.post(
-            f"/table/{table_name}/", json=query.dict(exclude_none=True)
+            f"/1/table/{table_name}/",
+            json=query.dict(exclude_none=True),
+            headers=self.headers,
         ) as resp:
             resp: aiohttp.ClientResponse = resp
             if 400 <= resp.status < 500:

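Editor's note: a sketch of what the reworked client wiring implies (not from the commit; the database name, region and key are placeholders). The session URL is now derived from db_name and region rather than parsed from a lancedb+http:// URL, and the masked Credential ends up in an x-api-key header:

from lancedb.common import Credential
from lancedb.remote.client import RestfulLanceDBClient

client = RestfulLanceDBClient("my_database", "us-west-2", Credential("ldb_placeholder"))
print(client.headers)   # prints with the key masked: {'x-api-key': ********}
# Requests are posted against a base URL of the form
#   https://my_database.us-west-2.api.lancedb.com
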
python/lancedb/remote/db.py (new file, 71 lines)

# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from urllib.parse import urlparse

import pyarrow as pa

from lancedb.common import DATA
from lancedb.db import DBConnection
from lancedb.table import Table

from .client import RestfulLanceDBClient


class RemoteDBConnection(DBConnection):
    """A connection to a remote LanceDB database."""

    def __init__(self, db_url: str, api_key: str, region: str):
        """Connect to a remote LanceDB database."""
        parsed = urlparse(db_url)
        if parsed.scheme != "db":
            raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
        self.db_name = parsed.netloc
        self.api_key = api_key
        self._client = RestfulLanceDBClient(self.db_name, region, api_key)

    def __repr__(self) -> str:
        return f"RemoveConnect(name={self.db_name})"

    def table_names(self) -> List[str]:
        raise NotImplementedError

    def open_table(self, name: str) -> Table:
        """Open a Lance Table in the database.

        Parameters
        ----------
        name: str
            The name of the table.

        Returns
        -------
        A LanceTable object representing the table.
        """
        from .table import RemoteTable

        # TODO: check if table exists

        return RemoteTable(self, name)

    def create_table(
        self,
        name: str,
        data: DATA = None,
        schema: pa.Schema = None,
        mode: str = "create",
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
    ) -> Table:
        raise NotImplementedError

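Editor's note: a usage sketch of the remote path as it stands in this commit (the database name and key are placeholders; most table operations on the remote side are still NotImplemented here, and no network call happens at connect time):

import lancedb

db = lancedb.connect("db://my_database", api_key="ldb_placeholder", region="us-west-2")
print(db)                        # RemoveConnect(name=my_database), the repr typo is in this commit
tbl = db.open_table("my_table")  # returns a RemoteTable handle; existence is not checked yet
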
python/lancedb/remote/table.py (new file, 70 lines)

# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from typing import Union

import pyarrow as pa

from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME

from ..query import LanceQueryBuilder, Query
from ..table import Query, Table
from .db import RemoteDBConnection


class RemoteTable(Table):
    def __init__(self, conn: RemoteDBConnection, name: str):
        self._conn = conn
        self._name = name

    def __repr__(self) -> str:
        return f"RemoteTable({self._conn.db_name}.{self.name})"

    def schema(self) -> pa.Schema:
        raise NotImplementedError

    def to_arrow(self) -> pa.Table:
        raise NotImplementedError

    def create_index(
        self,
        metric="L2",
        num_partitions=256,
        num_sub_vectors=96,
        vector_column_name: str = VECTOR_COLUMN_NAME,
        replace: bool = True,
    ):
        raise NotImplementedError

    def add(
        self,
        data: DATA,
        mode: str = "append",
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
    ) -> int:
        raise NotImplementedError

    def search(
        self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
    ) -> LanceQueryBuilder:
        return LanceQueryBuilder(self, query, vector_column)

    def _execute_query(self, query: Query) -> pa.Table:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()
        result = self._conn._client.query(self._name, query)
        return loop.run_until_complete(result).to_arrow()

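Editor's note: RemoteTable._execute_query bridges the async REST client into the synchronous query-builder API. A standalone sketch of that bridge (no lancedb required; the coroutine stands in for RestfulLanceDBClient.query, and the fallback uses new_event_loop() instead of the commit's get_event_loop() to avoid the deprecation warning). Note that run_until_complete cannot be used on a loop that is already running, so this pattern only works from synchronous callers:

import asyncio


async def fake_remote_query() -> str:
    # Stand-in for the client's async query() coroutine.
    await asyncio.sleep(0)
    return "result"


def execute_query_blocking() -> str:
    # Same shape as RemoteTable._execute_query: find a loop, then block on it.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
    return loop.run_until_complete(fake_remote_query())


print(execute_query_blocking())  # "result"
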
@@ -14,19 +14,21 @@
 from __future__ import annotations

 import os
+from abc import ABC, abstractmethod
 from functools import cached_property
-from typing import Any, List, Union
+from typing import List, Union

 import lance
 import numpy as np
 import pandas as pd
 import pyarrow as pa
 import pyarrow.compute as pc
+import pyarrow.fs
 from lance import LanceDataset
 from lance.vector import vec_to_table

 from .common import DATA, VEC, VECTOR_COLUMN_NAME
-from .query import LanceFtsQueryBuilder, LanceQueryBuilder
+from .query import LanceFtsQueryBuilder, LanceQueryBuilder, Query


 def _sanitize_data(data, schema, on_bad_vectors, fill_value):
@@ -47,14 +49,14 @@ def _sanitize_data(data, schema, on_bad_vectors, fill_value):
     return data


-class LanceTable:
+class Table(ABC):
     """
-    A table in a LanceDB database.
+    A [Table](Table) is a collection of Records in a LanceDB [Database](Database).

     Examples
     --------

-    Create using [LanceDBConnection.create_table][lancedb.LanceDBConnection.create_table]
+    Create using [DBConnection.create_table][lancedb.DBConnection.create_table]
     (more examples in that method's documentation).

     >>> import lancedb
@@ -69,12 +71,12 @@ class LanceTable:
     vector: [[[1.1,1.2]]]
     b: [[2]]

-    Can append new data with [LanceTable.add][lancedb.table.LanceTable.add].
+    Can append new data with [Table.add()][lancedb.table.Table.add].

    >>> table.add([{"vector": [0.5, 1.3], "b": 4}])
    2

-    Can query the table with [LanceTable.search][lancedb.table.LanceTable.search].
+    Can query the table with [Table.search][lancedb.table.Table.search].

    >>> table.search([0.4, 0.4]).select(["b"]).to_df()
       b      vector  score
@@ -82,8 +84,128 @@ class LanceTable:
    1  2  [1.1, 1.2]   1.13

     Search queries are much faster when an index is created. See
-    [LanceTable.create_index][lancedb.table.LanceTable.create_index].
+    [Table.create_index][lancedb.table.Table.create_index].
+    """
+
+    @abstractmethod
+    def schema(self) -> pa.Schema:
+        """Return the [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of
+        this [Table](Table)
+
+        """
+        raise NotImplementedError
+
+    def to_pandas(self) -> pd.DataFrame:
+        """Return the table as a pandas DataFrame.
+
+        Returns
+        -------
+        pd.DataFrame
+        """
+        return self.to_arrow().to_pandas()
+
+    @abstractmethod
+    def to_arrow(self) -> pa.Table:
+        """Return the table as a pyarrow Table.
+
+        Returns
+        -------
+        pa.Table
+        """
+        raise NotImplementedError
+
+    def create_index(
+        self,
+        metric="L2",
+        num_partitions=256,
+        num_sub_vectors=96,
+        vector_column_name: str = VECTOR_COLUMN_NAME,
+        replace: bool = True,
+    ):
+        """Create an index on the table.
+
+        Parameters
+        ----------
+        metric: str, default "L2"
+            The distance metric to use when creating the index.
+            Valid values are "L2", "cosine", or "dot".
+            L2 is euclidean distance.
+        num_partitions: int
+            The number of IVF partitions to use when creating the index.
+            Default is 256.
+        num_sub_vectors: int
+            The number of PQ sub-vectors to use when creating the index.
+            Default is 96.
+        vector_column_name: str, default "vector"
+            The vector column name to create the index.
+        replace: bool, default True
+            If True, replace the existing index if it exists.
+            If False, raise an error if duplicate index exists.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def add(
+        self,
+        data: DATA,
+        mode: str = "append",
+        on_bad_vectors: str = "error",
+        fill_value: float = 0.0,
+    ) -> int:
+        """Add more data to the [Table](Table).
+
+        Parameters
+        ----------
+        data: list-of-dict, dict, pd.DataFrame
+            The data to insert into the table.
+        mode: str
+            The mode to use when writing the data. Valid values are
+            "append" and "overwrite".
+        on_bad_vectors: str, default "error"
+            What to do if any of the vectors are not the same size or contains NaNs.
+            One of "error", "drop", "fill".
+        fill_value: float, default 0.
+            The value to use when filling vectors. Only used if on_bad_vectors="fill".
+
+        Returns
+        -------
+        int
+            The number of vectors in the table.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def search(
+        self, query: Union[VEC, str], vector_column: str = VECTOR_COLUMN_NAME
+    ) -> LanceQueryBuilder:
+        """Create a search query to find the nearest neighbors
+        of the given query vector.
+
+        Parameters
+        ----------
+        query: list, np.ndarray
+            The query vector.
+        vector_column: str, default "vector"
+            The name of the vector column to search.
+
+        Returns
+        -------
+        LanceQueryBuilder
+            A query builder object representing the query.
+            Once executed, the query returns selected columns, the vector,
+            and also the "score" column which is the distance between the query
+            vector and the returned vector.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _execute_query(self, query: Query) -> pa.Table:
+        pass
+
+
+class LanceTable(Table):
+    """
+    A table in a LanceDB database.
     """

     def __init__(

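Editor's note: one consequence of the new base class worth spelling out: to_pandas() has a concrete default built on to_arrow(), so an implementation only has to supply an Arrow view. A standalone sketch of that pattern (plain pyarrow, class names made up for illustration, not from the commit):

from abc import ABC, abstractmethod

import pandas as pd
import pyarrow as pa


class TableLike(ABC):
    # Mirrors the Table ABC above: to_pandas() delegates to to_arrow().
    @abstractmethod
    def to_arrow(self) -> pa.Table: ...

    def to_pandas(self) -> pd.DataFrame:
        return self.to_arrow().to_pandas()


class InMemoryTable(TableLike):
    def __init__(self, table: pa.Table):
        self._table = table

    def to_arrow(self) -> pa.Table:
        return self._table


t = InMemoryTable(pa.table({"vector": [[1.1, 1.2]], "b": [2]}))
print(t.to_pandas())
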
@@ -95,6 +217,7 @@ class LanceTable:

     def _reset_dataset(self):
         try:
+            if "_dataset" in self.__dict__:
                 del self.__dict__["_dataset"]
         except AttributeError:
             pass
@@ -195,26 +318,7 @@ class LanceTable:
         vector_column_name=VECTOR_COLUMN_NAME,
         replace: bool = True,
     ):
-        """Create an index on the table.
-
-        Parameters
-        ----------
-        metric: str, default "L2"
-            The distance metric to use when creating the index.
-            Valid values are "L2", "cosine", or "dot".
-            L2 is euclidean distance.
-        num_partitions: int
-            The number of IVF partitions to use when creating the index.
-            Default is 256.
-        num_sub_vectors: int
-            The number of PQ sub-vectors to use when creating the index.
-            Default is 96.
-        vector_column_name: str, default "vector"
-            The vector column name to create the index.
-        replace: bool, default True
-            If True, replace the existing index if it exists.
-            If False, raise an error if duplicate index exists.
-        """
+        """Create an index on the table."""
         self._dataset.create_index(
             column=vector_column_name,
             index_type="IVF_PQ",
@@ -258,7 +362,7 @@ class LanceTable:
         self,
         data: DATA,
         mode: str = "append",
-        on_bad_vectors: str = "drop",
+        on_bad_vectors: str = "error",
         fill_value: float = 0.0,
     ) -> int:
         """Add data to the table.
@@ -270,9 +374,9 @@ class LanceTable:
         mode: str
             The mode to use when writing the data. Valid values are
             "append" and "overwrite".
-        on_bad_vectors: str
+        on_bad_vectors: str, default "error"
             What to do if any of the vectors are not the same size or contains NaNs.
-            One of "raise", "drop", "fill".
+            One of "error", "drop", "fill".
         fill_value: float, default 0.
             The value to use when filling vectors. Only used if on_bad_vectors="fill".

@@ -281,6 +385,7 @@ class LanceTable:
         int
             The number of vectors in the table.
         """
+        # TODO: manage table listing and metadata separately
         data = _sanitize_data(
             data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
         )
@@ -326,10 +431,10 @@ class LanceTable:
         cls,
         db,
         name,
-        data,
+        data=None,
         schema=None,
         mode="create",
-        on_bad_vectors: str = "drop",
+        on_bad_vectors: str = "error",
         fill_value: float = 0.0,
     ):
         """
@@ -354,37 +459,40 @@ class LanceTable:
             The LanceDB instance to create the table in.
         name: str
             The name of the table to create.
-        data: list-of-dict, dict, pd.DataFrame
+        data: list-of-dict, dict, pd.DataFrame, default None
             The data to insert into the table.
+            At least one of `data` or `schema` must be provided.
         schema: dict, optional
             The schema of the table. If not provided, the schema is inferred from the data.
+            At least one of `data` or `schema` must be provided.
         mode: str, default "create"
             The mode to use when writing the data. Valid values are
             "create", "overwrite", and "append".
-        on_bad_vectors: str
+        on_bad_vectors: str, default "error"
             What to do if any of the vectors are not the same size or contains NaNs.
-            One of "raise", "drop", "fill".
+            One of "error", "drop", "fill".
         fill_value: float, default 0.
             The value to use when filling vectors. Only used if on_bad_vectors="fill".
         """
         tbl = LanceTable(db, name)
+        if data is not None:
             data = _sanitize_data(
                 data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
             )
+        else:
+            if schema is None:
+                raise ValueError("Either data or schema must be provided")
+            data = pa.Table.from_pylist([], schema=schema)
         lance.write_dataset(data, tbl._dataset_uri, mode=mode)
-        return tbl
+        return LanceTable(db, name)

     @classmethod
     def open(cls, db, name):
         tbl = cls(db, name)
-        if tbl._conn.is_managed_remote:
-            # Not completely sure how to check for remote table existence yet.
-            return tbl
         if not os.path.exists(tbl._dataset_uri):
             raise FileNotFoundError(
                 f"Table {name} does not exist. Please first call db.create_table({name}, data)"
             )

         return tbl

     def delete(self, where: str):
@@ -415,11 +523,26 @@ class LanceTable:
         """
         self._dataset.delete(where)

+    def _execute_query(self, query: Query) -> pa.Table:
+        ds = self.to_lance()
+        return ds.to_table(
+            columns=query.columns,
+            filter=query.filter,
+            nearest={
+                "column": query.vector_column,
+                "q": query.vector,
+                "k": query.k,
+                "metric": query.metric,
+                "nprobes": query.nprobes,
+                "refine_factor": query.refine_factor,
+            },
+        )
+

 def _sanitize_schema(
     data: pa.Table,
     schema: pa.Schema = None,
-    on_bad_vectors: str = "drop",
+    on_bad_vectors: str = "error",
     fill_value: float = 0.0,
 ) -> pa.Table:
     """Ensure that the table has the expected schema.
@@ -431,10 +554,10 @@ def _sanitize_schema(
     schema: pa.Schema; optional
         The expected schema. If not provided, this just converts the
         vector column to fixed_size_list(float32) if necessary.
-    on_bad_vectors: str
+    on_bad_vectors: str, default "error"
         What to do if any of the vectors are not the same size or contains NaNs.
-        One of "raise", "drop", "fill".
+        One of "error", "drop", "fill".
-    fill_value: float
+    fill_value: float, default 0.
         The value to use when filling vectors. Only used if on_bad_vectors="fill".
     """
     if schema is not None:
@@ -463,7 +586,7 @@ def _sanitize_schema(
 def _sanitize_vector_column(
     data: pa.Table,
     vector_column_name: str,
-    on_bad_vectors: str = "drop",
+    on_bad_vectors: str = "error",
     fill_value: float = 0.0,
 ) -> pa.Table:
     """
@@ -475,10 +598,10 @@ def _sanitize_vector_column(
         The table to sanitize.
     vector_column_name: str
         The name of the vector column.
-    on_bad_vectors: str
+    on_bad_vectors: str, default "error"
         What to do if any of the vectors are not the same size or contains NaNs.
-        One of "raise", "drop", "fill".
+        One of "error", "drop", "fill".
-    fill_value: float
+    fill_value: float, default 0.0
         The value to use when filling vectors. Only used if on_bad_vectors="fill".
     """
     if vector_column_name not in data.column_names:
@@ -501,7 +624,7 @@ def _sanitize_vector_column(
             data.column_names.index(vector_column_name), vector_column_name, vec_arr
         )

-    has_nans = pc.any(vec_arr.values.is_nan()).as_py()
+    has_nans = pc.any(pc.is_nan(vec_arr.values)).as_py()
     if has_nans:
         data = _sanitize_nans(
             data, fill_value, on_bad_vectors, vec_arr, vector_column_name
@@ -524,7 +647,7 @@ def ensure_fixed_size_list_of_f32(vec_arr):

 def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
     """Sanitize jagged vectors."""
-    if on_bad_vectors == "raise":
+    if on_bad_vectors == "error":
         raise ValueError(
             f"Vector column {vector_column_name} has variable length vectors "
             "Set on_bad_vectors='drop' to remove them, or "
@@ -538,7 +661,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
     if on_bad_vectors == "fill":
         if fill_value is None:
             raise ValueError(
-                f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
+                "`fill_value` must not be None if `on_bad_vectors` is 'fill'"
             )
         fill_arr = pa.scalar([float(fill_value)] * ndims)
         vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
@@ -552,7 +675,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):

 def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
     """Sanitize NaNs in vectors"""
-    if on_bad_vectors == "raise":
+    if on_bad_vectors == "error":
         raise ValueError(
             f"Vector column {vector_column_name} has NaNs. "
             "Set on_bad_vectors='drop' to remove them, or "
@@ -561,10 +684,10 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
     elif on_bad_vectors == "fill":
         if fill_value is None:
             raise ValueError(
-                f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
+                "`fill_value` must not be None if `on_bad_vectors` is 'fill'"
             )
         fill_value = float(fill_value)
-        values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values)
+        values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
         ndims = len(vec_arr[0])
         vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
         data = data.set_column(

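Editor's note: LanceTable._execute_query above is a thin translation of the Query fields onto LanceDataset.to_table(nearest=...). A sketch of that underlying call against a throwaway dataset (assumes pylance ~0.5 as pinned in this repo; the path and values are placeholders, not from the commit):

import numpy as np
import pyarrow as pa
import lance
from lance.vector import vec_to_table

# Two 2-d vectors become a fixed_size_list<float32>[2] "vector" column.
mat = np.array([[1.1, 1.2], [0.2, 1.8]], dtype=np.float32)
tbl = vec_to_table(mat).append_column("b", pa.array([2, 4]))

lance.write_dataset(tbl, "/tmp/query_demo.lance", mode="overwrite")
ds = lance.dataset("/tmp/query_demo.lance")

# The same call shape LanceTable._execute_query builds from a Query object:
result = ds.to_table(
    columns=["b"],
    filter="b < 10",
    nearest={
        "column": "vector",
        "q": [0.4, 0.4],
        "k": 2,
        "metric": "L2",
        "nprobes": 10,
        "refine_factor": None,
    },
)
print(result.to_pandas())
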
@@ -11,9 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from urllib.parse import ParseResult, urlparse
+from urllib.parse import urlparse

-from pyarrow import fs
-

 def get_uri_scheme(uri: str) -> str:

@@ -1,6 +1,6 @@
 [project]
 name = "lancedb"
-version = "0.1.9"
+version = "0.1.10"
 dependencies = ["pylance~=0.5.0", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic", "attr"]
 description = "lancedb"
 authors = [

@@ -20,18 +20,33 @@ import pyarrow as pa
 import pytest

 from lancedb.db import LanceDBConnection
-from lancedb.query import LanceQueryBuilder
+from lancedb.query import LanceQueryBuilder, Query
 from lancedb.table import LanceTable


 class MockTable:
     def __init__(self, tmp_path):
         self.uri = tmp_path
-        self._conn = LanceDBConnection("/tmp/lance/")
+        self._conn = LanceDBConnection(self.uri)

     def to_lance(self):
         return lance.dataset(self.uri)

+    def _execute_query(self, query):
+        ds = self.to_lance()
+        return ds.to_table(
+            columns=query.columns,
+            filter=query.filter,
+            nearest={
+                "column": query.vector_column,
+                "q": query.vector,
+                "k": query.k,
+                "metric": query.metric,
+                "nprobes": query.nprobes,
+                "refine_factor": query.refine_factor,
+            },
+        )
+

 @pytest.fixture
 def table(tmp_path) -> MockTable:
@@ -94,20 +109,17 @@ def test_query_builder_with_different_vector_column():
     )
     ds = mock.Mock()
     table.to_lance.return_value = ds
-    table._conn = mock.MagicMock()
-    table._conn.is_managed_remote = False
     builder.to_arrow()
-    ds.to_table.assert_called_once_with(
-        columns=["b"],
-        filter="b < 10",
-        nearest={
-            "column": vector_column_name,
-            "q": query,
-            "k": 2,
-            "metric": "cosine",
-            "nprobes": 20,
-            "refine_factor": None,
-        },
+    table._execute_query.assert_called_once_with(
+        Query(
+            vector=query,
+            filter="b < 10",
+            k=2,
+            metric="cosine",
+            columns=["b"],
+            nprobes=20,
+            refine_factor=None,
+        )
     )

@@ -13,7 +13,7 @@

 import pyarrow as pa

-from lancedb.db import LanceDBConnection
+import lancedb
 from lancedb.remote.client import VectorQuery, VectorQueryResult

@@ -28,7 +28,7 @@ class FakeLanceDBClient:


 def test_remote_db():
-    conn = LanceDBConnection("lancedb+http://client-will-be-injected")
+    conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
     setattr(conn, "_client", FakeLanceDBClient())

     table = conn["test"]

@@ -19,6 +19,7 @@ import numpy as np
 import pandas as pd
 import pyarrow as pa
 import pytest
+from lance.vector import vec_to_table

 from lancedb.db import LanceDBConnection
 from lancedb.table import LanceTable
@@ -89,7 +90,31 @@ def test_create_table(db):
     assert expected == tbl


+def test_empty_table(db):
+    schema = pa.schema(
+        [
+            pa.field("vector", pa.list_(pa.float32(), 2)),
+            pa.field("item", pa.string()),
+            pa.field("price", pa.float32()),
+        ]
+    )
+    tbl = LanceTable.create(db, "test", schema=schema)
+    data = [
+        {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+        {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
+    ]
+    tbl.add(data=data)
+
+
 def test_add(db):
+    schema = pa.schema(
+        [
+            pa.field("vector", pa.list_(pa.float32(), 2)),
+            pa.field("item", pa.string()),
+            pa.field("price", pa.float64()),
+        ]
+    )
+
     table = LanceTable.create(
         db,
         "test",
@@ -98,7 +123,19 @@ def test_add(db):
             {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
         ],
     )
+    _add(table, schema)
+
+    table = LanceTable.create(db, "test2", schema=schema)
+    table.add(
+        data=[
+            {"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
+            {"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
+        ],
+    )
+    _add(table, schema)
+
+
+def _add(table, schema):
     # table = LanceTable(db, "test")
     assert len(table) == 2

@@ -113,13 +150,7 @@ def test_add(db):
             pa.array(["foo", "bar", "new"]),
             pa.array([10.0, 20.0, 30.0]),
         ],
-        schema=pa.schema(
-            [
-                pa.field("vector", pa.list_(pa.float32(), 2)),
-                pa.field("item", pa.string()),
-                pa.field("price", pa.float64()),
-            ]
-        ),
+        schema=schema,
     )
     assert expected == table.to_arrow()

@@ -181,7 +212,21 @@ def test_create_index_method():


 def test_add_with_nans(db):
-    # By default we drop bad input vectors
+    # by default we raise an error on bad input vectors
+    bad_data = [
+        {"vector": [np.nan], "item": "bar", "price": 20.0},
+        {"vector": [5], "item": "bar", "price": 20.0},
+        {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
+        {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
+    ]
+    for row in bad_data:
+        with pytest.raises(ValueError):
+            LanceTable.create(
+                db,
+                "error_test",
+                data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
+            )
+
     table = LanceTable.create(
         db,
         "drop_test",
@@ -191,6 +236,7 @@ def test_add_with_nans(db):
             {"vector": [5], "item": "bar", "price": 20.0},
             {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
         ],
+        on_bad_vectors="drop",
     )
     assert len(table) == 1

@@ -210,18 +256,3 @@ def test_add_with_nans(db):
     arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
     v = arrow_tbl["vector"].to_pylist()[0]
     assert np.allclose(v, np.array([0.0, 0.0]))
-
-    bad_data = [
-        {"vector": [np.nan], "item": "bar", "price": 20.0},
-        {"vector": [5], "item": "bar", "price": 20.0},
-        {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
-        {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
-    ]
-    for row in bad_data:
-        with pytest.raises(ValueError):
-            LanceTable.create(
-                db,
-                "raise_test",
-                data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
-                on_bad_vectors="raise",
-            )

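Editor's note: the behavioral change these tests pin down, as a short usage sketch (assumes this version of lancedb; table names are illustrative, and mode="overwrite" is used only to keep the sketch re-runnable):

import numpy as np
import lancedb

db = lancedb.connect("./.lancedb")
good = {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}
bad = {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}

# New default: bad vectors are an error rather than being silently dropped.
try:
    db.create_table("nan_error", [good, bad], mode="overwrite")
except ValueError as err:
    print(err)

# Opting back into the old behavior:
tbl = db.create_table("nan_drop", [good, bad], mode="overwrite", on_bad_vectors="drop")
print(len(tbl))  # 1, the NaN row was dropped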