mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-11 06:12:58 +00:00
Merge branch 'main' of https://github.com/lancedb/lancedb into hybrid_query
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.5.5
|
||||
current_version = 0.6.1
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
24
python/ASYNC_MIGRATION.md
Normal file
24
python/ASYNC_MIGRATION.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Migration from Sync to Async API
|
||||
|
||||
A new asynchronous API has been added to LanceDb. This API is built
|
||||
on top of the rust lancedb crate (instead of being built on top of
|
||||
pylance). This will help keep the various language bindings in sync.
|
||||
There are some slight changes between the synchronous and the asynchronous
|
||||
APIs. This document will help you migrate. These changes relate mostly
|
||||
to the Connection and Table classes.
|
||||
|
||||
## Almost all functions are async
|
||||
|
||||
The most important change is that almost all functions are now async.
|
||||
This means the functions now return `asyncio` coroutines. You will
|
||||
need to use `await` to call these functions.
|
||||
|
||||
## Connection
|
||||
|
||||
No changes yet.
|
||||
|
||||
## Table
|
||||
|
||||
* Previously `Table.schema` was a property. Now it is an async method.
|
||||
* The method `Table.__len__` was removed and `len(table)` will no longer
|
||||
work. Use `Table.count_rows` instead.
|
||||
30
python/Cargo.toml
Normal file
30
python/Cargo.toml
Normal file
@@ -0,0 +1,30 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.4.10"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
|
||||
|
||||
[lib]
|
||||
name = "_lancedb"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
arrow = { version = "50.0.0", features = ["pyarrow"] }
|
||||
lancedb = { path = "../rust/lancedb" }
|
||||
env_logger = "0.10"
|
||||
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
|
||||
pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
|
||||
|
||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||
lzma-sys = { version = "*", features = ["static"] }
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = { version = "0.20.3", features = [
|
||||
"extension-module",
|
||||
"abi3-py38",
|
||||
] }
|
||||
@@ -20,10 +20,10 @@ results = table.search([0.1, 0.3]).limit(20).to_list()
|
||||
print(results)
|
||||
```
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
Create a virtual environment and activate it:
|
||||
LanceDb is based on the rust crate `lancedb` and is built with maturin. In order to build with maturin
|
||||
you will either need a conda environment or a virtual environment (venv).
|
||||
|
||||
```bash
|
||||
python -m venv venv
|
||||
@@ -33,7 +33,15 @@ python -m venv venv
|
||||
Install the necessary packages:
|
||||
|
||||
```bash
|
||||
python -m pip install .
|
||||
python -m pip install .[tests,dev]
|
||||
```
|
||||
|
||||
To build the python package you can use maturin:
|
||||
|
||||
```bash
|
||||
# This will build the rust bindings and place them in the appropriate place
|
||||
# in your venv or conda environment
|
||||
matruin develop
|
||||
```
|
||||
|
||||
To run the unit tests:
|
||||
@@ -45,7 +53,7 @@ pytest
|
||||
To run the doc tests:
|
||||
|
||||
```bash
|
||||
pytest --doctest-modules lancedb
|
||||
pytest --doctest-modules python/lancedb
|
||||
```
|
||||
|
||||
To run linter and automatically fix all errors:
|
||||
@@ -61,31 +69,27 @@ If any packages are missing, install them with:
|
||||
pip install <PACKAGE_NAME>
|
||||
```
|
||||
|
||||
|
||||
___
|
||||
For **Windows** users, there may be errors when installing packages, so these commands may be helpful:
|
||||
|
||||
Activate the virtual environment:
|
||||
|
||||
```bash
|
||||
. .\venv\Scripts\activate
|
||||
```
|
||||
|
||||
You may need to run the installs separately:
|
||||
|
||||
```bash
|
||||
pip install -e .[tests]
|
||||
pip install -e .[dev]
|
||||
```
|
||||
|
||||
|
||||
`tantivy` requires `rust` to be installed, so install it with `conda`, as it doesn't support windows installation:
|
||||
|
||||
```bash
|
||||
pip install wheel
|
||||
pip install cargo
|
||||
conda install rust
|
||||
pip install tantivy
|
||||
```
|
||||
|
||||
To run the unit tests:
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
3
python/build.rs
Normal file
3
python/build.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
fn main() {
|
||||
pyo3_build_config::add_extension_module_link_args();
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from .util import safe_import_pandas
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
|
||||
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
|
||||
URI = Union[str, Path]
|
||||
VECTOR_COLUMN_NAME = "vector"
|
||||
|
||||
|
||||
class Credential(str):
|
||||
"""Credential field"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "********"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "********"
|
||||
@@ -1,9 +1,9 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.5.5"
|
||||
version = "0.6.1"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.9.16",
|
||||
"pylance==0.10.1",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.27.0",
|
||||
@@ -14,7 +14,7 @@ dependencies = [
|
||||
"pyyaml>=6.0",
|
||||
"click>=8.1.7",
|
||||
"requests>=2.31.0",
|
||||
"overrides>=0.7"
|
||||
"overrides>=0.7",
|
||||
]
|
||||
description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
@@ -26,7 +26,7 @@ keywords = [
|
||||
"data-science",
|
||||
"machine-learning",
|
||||
"arrow",
|
||||
"data-analytics"
|
||||
"data-analytics",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
@@ -48,21 +48,53 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
|
||||
tests = [
|
||||
"aiohttp",
|
||||
"pandas>=1.4",
|
||||
"pytest",
|
||||
"pytest-mock",
|
||||
"pytest-asyncio",
|
||||
"duckdb",
|
||||
"pytz",
|
||||
"polars>=0.19",
|
||||
]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
docs = [
|
||||
"mkdocs",
|
||||
"mkdocs-jupyter",
|
||||
"mkdocs-material",
|
||||
"mkdocstrings[python]",
|
||||
"mkdocs-ultralytics-plugin==0.0.44",
|
||||
]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
|
||||
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]
|
||||
embeddings = [
|
||||
"openai>=1.6.1",
|
||||
"sentence-transformers",
|
||||
"torch",
|
||||
"pillow",
|
||||
"open-clip-torch",
|
||||
"cohere",
|
||||
"huggingface_hub",
|
||||
"InstructorEmbedding",
|
||||
"google.generativeai",
|
||||
"boto3>=1.28.57",
|
||||
"awscli>=1.29.57",
|
||||
"botocore>=1.31.57",
|
||||
]
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "python"
|
||||
module-name = "lancedb._lancedb"
|
||||
|
||||
[project.scripts]
|
||||
lancedb = "lancedb.cli.cli:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = ["maturin>=1.4"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
@@ -70,5 +102,5 @@ addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
|
||||
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"asyncio"
|
||||
"asyncio",
|
||||
]
|
||||
|
||||
@@ -19,8 +19,9 @@ from typing import Optional, Union
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from .common import URI
|
||||
from .db import DBConnection, LanceDBConnection
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from .db import AsyncConnection, AsyncLanceDBConnection, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector # noqa: F401
|
||||
from .utils import sentry_log # noqa: F401
|
||||
@@ -101,3 +102,74 @@ def connect(
|
||||
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
|
||||
)
|
||||
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||
|
||||
|
||||
async def connect_async(
|
||||
uri: URI,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
region: str = "us-east-1",
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri: str or Path
|
||||
The uri of the database.
|
||||
api_key: str, optional
|
||||
If present, connect to LanceDB cloud.
|
||||
Otherwise, connect to a database on file system or cloud storage.
|
||||
Can be set via environment variable `LANCEDB_API_KEY`.
|
||||
region: str, default "us-east-1"
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
read_consistency_interval: timedelta, default None
|
||||
(For LanceDB OSS only)
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
zero seconds. Then every read will check for updates from other
|
||||
processes. As a compromise, you can set this to a non-zero timedelta
|
||||
for eventual consistency. If more than that interval has passed since
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||
If an integer, then a ThreadPoolExecutor will be created with that
|
||||
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||
executor will be used for making requests. This is for LanceDB Cloud
|
||||
only and is only used when making batch requests (i.e., passing in
|
||||
multiple queries to the search method at once).
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
For a local directory, provide a path for the database:
|
||||
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("~/.lancedb")
|
||||
|
||||
For object storage, use a URI prefix:
|
||||
|
||||
>>> db = lancedb.connect("s3://my-bucket/lancedb")
|
||||
|
||||
Connect to LancdDB cloud:
|
||||
|
||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
|
||||
|
||||
Returns
|
||||
-------
|
||||
conn : DBConnection
|
||||
A connection to a LanceDB database.
|
||||
"""
|
||||
return AsyncLanceDBConnection(
|
||||
await lancedb_connect(
|
||||
sanitize_uri(uri), api_key, region, host_override, read_consistency_interval
|
||||
)
|
||||
)
|
||||
24
python/python/lancedb/_lancedb.pyi
Normal file
24
python/python/lancedb/_lancedb.pyi
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
class Connection(object):
|
||||
async def table_names(self) -> list[str]: ...
|
||||
async def create_table(
|
||||
self, name: str, mode: str, data: pa.RecordBatchReader
|
||||
) -> Table: ...
|
||||
async def create_empty_table(
|
||||
self, name: str, mode: str, schema: pa.Schema
|
||||
) -> Table: ...
|
||||
|
||||
class Table(object):
|
||||
def name(self) -> str: ...
|
||||
async def schema(self) -> pa.Schema: ...
|
||||
|
||||
async def connect(
|
||||
uri: str,
|
||||
api_key: Optional[str],
|
||||
region: Optional[str],
|
||||
host_override: Optional[str],
|
||||
read_consistency_interval: Optional[float],
|
||||
) -> Connection: ...
|
||||
136
python/python/lancedb/common.py
Normal file
136
python/python/lancedb/common.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from .util import safe_import_pandas
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
|
||||
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
|
||||
URI = Union[str, Path]
|
||||
VECTOR_COLUMN_NAME = "vector"
|
||||
|
||||
|
||||
class Credential(str):
|
||||
"""Credential field"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "********"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "********"
|
||||
|
||||
|
||||
def sanitize_uri(uri: URI) -> str:
|
||||
return str(uri)
|
||||
|
||||
|
||||
def _casting_recordbatch_iter(
|
||||
input_iter: Iterable[pa.RecordBatch], schema: pa.Schema
|
||||
) -> Iterable[pa.RecordBatch]:
|
||||
"""
|
||||
Wrapper around an iterator of record batches. If the batches don't match the
|
||||
schema, try to cast them to the schema. If that fails, raise an error.
|
||||
|
||||
This is helpful for users who might have written the iterator with default
|
||||
data types in PyArrow, but specified more specific types in the schema. For
|
||||
example, PyArrow defaults to float64 for floating point types, but Lance
|
||||
uses float32 for vectors.
|
||||
"""
|
||||
for batch in input_iter:
|
||||
if not isinstance(batch, pa.RecordBatch):
|
||||
raise TypeError(f"Expected RecordBatch, got {type(batch)}")
|
||||
if batch.schema != schema:
|
||||
try:
|
||||
# RecordBatch doesn't have a cast method, but table does.
|
||||
batch = pa.Table.from_batches([batch]).cast(schema).to_batches()[0]
|
||||
except pa.lib.ArrowInvalid:
|
||||
raise ValueError(
|
||||
f"Input RecordBatch iterator yielded a batch with schema that "
|
||||
f"does not match the expected schema.\nExpected:\n{schema}\n"
|
||||
f"Got:\n{batch.schema}"
|
||||
)
|
||||
yield batch
|
||||
|
||||
|
||||
def data_to_reader(
|
||||
data: DATA, schema: Optional[pa.Schema] = None
|
||||
) -> pa.RecordBatchReader:
|
||||
"""Convert various types of input into a RecordBatchReader"""
|
||||
if pd is not None and isinstance(data, pd.DataFrame):
|
||||
return pa.Table.from_pandas(data, schema=schema).to_reader()
|
||||
elif isinstance(data, pa.Table):
|
||||
return data.to_reader()
|
||||
elif isinstance(data, pa.RecordBatch):
|
||||
return pa.Table.from_batches([data]).to_reader()
|
||||
# elif isinstance(data, LanceDataset):
|
||||
# return data_obj.scanner().to_reader()
|
||||
elif isinstance(data, pa.dataset.Dataset):
|
||||
return pa.dataset.Scanner.from_dataset(data).to_reader()
|
||||
elif isinstance(data, pa.dataset.Scanner):
|
||||
return data.to_reader()
|
||||
elif isinstance(data, pa.RecordBatchReader):
|
||||
return data
|
||||
elif (
|
||||
type(data).__module__.startswith("polars")
|
||||
and data.__class__.__name__ == "DataFrame"
|
||||
):
|
||||
return data.to_arrow().to_reader()
|
||||
# for other iterables, assume they are of type Iterable[RecordBatch]
|
||||
elif isinstance(data, Iterable):
|
||||
if schema is not None:
|
||||
data = _casting_recordbatch_iter(data, schema)
|
||||
return pa.RecordBatchReader.from_batches(schema, data)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Must provide schema to write dataset from RecordBatch iterable"
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unknown data type {type(data)}. "
|
||||
"Please check "
|
||||
"https://lancedb.github.io/lance/read_and_write.html "
|
||||
"to see supported types."
|
||||
)
|
||||
|
||||
|
||||
def validate_schema(schema: pa.Schema):
|
||||
"""
|
||||
Make sure the metadata is valid utf8
|
||||
"""
|
||||
if schema.metadata is not None:
|
||||
_validate_metadata(schema.metadata)
|
||||
|
||||
|
||||
def _validate_metadata(metadata: dict):
|
||||
"""
|
||||
Make sure the metadata values are valid utf8 (can be nested)
|
||||
|
||||
Raises ValueError if not valid utf8
|
||||
"""
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, bytes):
|
||||
try:
|
||||
v.decode("utf8")
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
f"Metadata key {k} is not valid utf8. "
|
||||
"Consider base64 encode for generic binary metadata."
|
||||
)
|
||||
elif isinstance(v, dict):
|
||||
_validate_metadata(v)
|
||||
@@ -13,6 +13,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
@@ -22,15 +23,20 @@ import pyarrow as pa
|
||||
from overrides import EnforceOverrides, override
|
||||
from pyarrow import fs
|
||||
|
||||
from .table import LanceTable, Table
|
||||
from lancedb.common import data_to_reader, validate_schema
|
||||
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||
from lancedb.utils.events import register_event
|
||||
|
||||
from .pydantic import LanceModel
|
||||
from .table import AsyncLanceTable, LanceTable, Table, _sanitize_data
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from ._lancedb import Connection as LanceDbConnection
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
|
||||
|
||||
class DBConnection(EnforceOverrides):
|
||||
@@ -40,14 +46,21 @@ class DBConnection(EnforceOverrides):
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""List all table in this database
|
||||
"""List all tables in this database, in sorted order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
Typically, this token is last table name from the previous page.
|
||||
Only supported by LanceDb Cloud.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
Only supported by LanceDb Cloud.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable of str
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -412,3 +425,313 @@ class LanceDBConnection(DBConnection):
|
||||
def drop_database(self):
|
||||
filesystem, path = fs_from_uri(self.uri)
|
||||
filesystem.delete_dir(path)
|
||||
|
||||
|
||||
class AsyncConnection(EnforceOverrides):
|
||||
"""An active LanceDB connection interface."""
|
||||
|
||||
@abstractmethod
|
||||
async def table_names(
|
||||
self, *, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""List all tables in this database, in sorted order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
Typically, this token is last table name from the previous page.
|
||||
Only supported by LanceDb Cloud.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
Only supported by LanceDb Cloud.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable of str
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
exist_ok: bool = False,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
data: The data to initialize the table, *optional*
|
||||
User must provide at least one of `data` or `schema`.
|
||||
Acceptable types are:
|
||||
|
||||
- dict or list-of-dict
|
||||
|
||||
- pandas.DataFrame
|
||||
|
||||
- pyarrow.Table or pyarrow.RecordBatch
|
||||
schema: The schema of the table, *optional*
|
||||
Acceptable types are:
|
||||
|
||||
- pyarrow.Schema
|
||||
|
||||
- [LanceModel][lancedb.pydantic.LanceModel]
|
||||
mode: str; default "create"
|
||||
The mode to use when creating the table.
|
||||
Can be either "create" or "overwrite".
|
||||
By default, if the table already exists, an exception is raised.
|
||||
If you want to overwrite the table, use mode="overwrite".
|
||||
exist_ok: bool, default False
|
||||
If a table by the same name already exists, then raise an exception
|
||||
if exist_ok=False. If exist_ok=True, then open the existing table;
|
||||
it will not add the provided data but will validate against any
|
||||
schema that's specified.
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceTable
|
||||
A reference to the newly created table.
|
||||
|
||||
!!! note
|
||||
|
||||
The vector index won't be created by default.
|
||||
To create the index, call the `create_index` method on the table.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Can create with list of tuples or dictionaries:
|
||||
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||
>>> db.create_table("my_table", data)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
>>> db["my_table"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: double
|
||||
long: double
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
You can also pass a pandas DataFrame:
|
||||
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({
|
||||
... "vector": [[1.1, 1.2], [0.2, 1.8]],
|
||||
... "lat": [45.5, 40.1],
|
||||
... "long": [-122.7, -74.1]
|
||||
... })
|
||||
>>> db.create_table("table2", data)
|
||||
LanceTable(connection=..., name="table2")
|
||||
>>> db["table2"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: double
|
||||
long: double
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
Data is converted to Arrow before being written to disk. For maximum
|
||||
control over how data is saved, either provide the PyArrow schema to
|
||||
convert to or else provide a [PyArrow Table](pyarrow.Table) directly.
|
||||
|
||||
>>> custom_schema = pa.schema([
|
||||
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
... pa.field("lat", pa.float32()),
|
||||
... pa.field("long", pa.float32())
|
||||
... ])
|
||||
>>> db.create_table("table3", data, schema = custom_schema)
|
||||
LanceTable(connection=..., name="table3")
|
||||
>>> db["table3"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: float
|
||||
long: float
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
|
||||
It is also possible to create an table from `[Iterable[pa.RecordBatch]]`:
|
||||
|
||||
|
||||
>>> import pyarrow as pa
|
||||
>>> def make_batches():
|
||||
... for i in range(5):
|
||||
... yield pa.RecordBatch.from_arrays(
|
||||
... [
|
||||
... pa.array([[3.1, 4.1], [5.9, 26.5]],
|
||||
... pa.list_(pa.float32(), 2)),
|
||||
... pa.array(["foo", "bar"]),
|
||||
... pa.array([10.0, 20.0]),
|
||||
... ],
|
||||
... ["vector", "item", "price"],
|
||||
... )
|
||||
>>> schema=pa.schema([
|
||||
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
... pa.field("item", pa.utf8()),
|
||||
... pa.field("price", pa.float32()),
|
||||
... ])
|
||||
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||
LanceTable(connection=..., name="table4")
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def drop_database(self):
|
||||
"""
|
||||
Drop database
|
||||
This is the same thing as dropping all the tables
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncLanceDBConnection(AsyncConnection):
|
||||
def __init__(self, connection: LanceDbConnection):
|
||||
self._inner = connection
|
||||
|
||||
async def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@override
|
||||
async def table_names(
|
||||
self,
|
||||
*,
|
||||
page_token=None,
|
||||
limit=None,
|
||||
) -> Iterable[str]:
|
||||
# TODO: hook in page_token and limit
|
||||
return await self._inner.table_names()
|
||||
|
||||
@override
|
||||
async def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
exist_ok: bool = False,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
if mode.lower() not in ["create", "overwrite"]:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
# note that it's possible this contains
|
||||
# embedding function metadata already
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
metadata = None
|
||||
if embedding_functions is not None:
|
||||
# If we passed in embedding functions explicitly
|
||||
# then we'll override any schema metadata that
|
||||
# may was implicitly specified by the LanceModel schema
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
if schema is None:
|
||||
if data is None:
|
||||
raise ValueError("Either data or schema must be provided")
|
||||
elif hasattr(data, "schema"):
|
||||
schema = data.schema
|
||||
elif isinstance(data, Iterable):
|
||||
if metadata:
|
||||
raise TypeError(
|
||||
(
|
||||
"Persistent embedding functions not yet "
|
||||
"supported for generator data input"
|
||||
)
|
||||
)
|
||||
|
||||
if metadata:
|
||||
schema = schema.with_metadata(metadata)
|
||||
validate_schema(schema)
|
||||
|
||||
if mode == "create" and exist_ok:
|
||||
mode = "exist_ok"
|
||||
|
||||
if data is None:
|
||||
new_table = await self._inner.create_empty_table(name, mode, schema)
|
||||
else:
|
||||
data = data_to_reader(data, schema)
|
||||
new_table = await self._inner.create_table(
|
||||
name,
|
||||
mode,
|
||||
data,
|
||||
)
|
||||
|
||||
register_event("create_table")
|
||||
return AsyncLanceTable(new_table)
|
||||
|
||||
@override
|
||||
async def open_table(self, name: str) -> LanceTable:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def drop_table(self, name: str, ignore_missing: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def drop_database(self):
|
||||
raise NotImplementedError
|
||||
172
python/python/lancedb/embeddings/imagebind.py
Normal file
172
python/python/lancedb/embeddings/imagebind.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import cached_property
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import AUDIO, IMAGES, TEXT
|
||||
|
||||
|
||||
@register("imagebind")
|
||||
class ImageBindEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the ImageBind API
|
||||
For generating multi-modal embeddings across
|
||||
six different modalities: images, text, audio, depth, thermal, and IMU data
|
||||
|
||||
to download package, run :
|
||||
`pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
|
||||
"""
|
||||
|
||||
name: str = "imagebind_huge"
|
||||
device: str = "cpu"
|
||||
normalize: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._ndims = 1024
|
||||
self._audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".aac")
|
||||
self._image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")
|
||||
|
||||
@cached_property
|
||||
def embedding_model(self):
|
||||
"""
|
||||
Get the embedding model. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.get_embedding_model()
|
||||
|
||||
@cached_property
|
||||
def _data(self):
|
||||
"""
|
||||
Get the data module from imagebind
|
||||
"""
|
||||
data = attempt_import_or_raise("imagebind.data", "imagebind")
|
||||
return data
|
||||
|
||||
@cached_property
|
||||
def _ModalityType(self):
|
||||
"""
|
||||
Get the ModalityType from imagebind
|
||||
"""
|
||||
imagebind = attempt_import_or_raise("imagebind", "imagebind")
|
||||
return imagebind.imagebind_model.ModalityType
|
||||
|
||||
def ndims(self):
|
||||
return self._ndims
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : Union[str]
|
||||
The query to embed. A query can be either text, image paths or audio paths.
|
||||
"""
|
||||
query = self.sanitize_input(query)
|
||||
if query[0].endswith(self._audio_extensions):
|
||||
return [self.generate_audio_embeddings(query)]
|
||||
elif query[0].endswith(self._image_extensions):
|
||||
return [self.generate_image_embeddings(query)]
|
||||
else:
|
||||
return [self.generate_text_embeddings(query)]
|
||||
|
||||
def generate_image_embeddings(self, image: IMAGES) -> np.ndarray:
|
||||
torch = attempt_import_or_raise("torch")
|
||||
inputs = {
|
||||
self._ModalityType.VISION: self._data.load_and_transform_vision_data(
|
||||
image, self.device
|
||||
)
|
||||
}
|
||||
with torch.no_grad():
|
||||
image_features = self.embedding_model(inputs)[self._ModalityType.VISION]
|
||||
if self.normalize:
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features.cpu().numpy().squeeze()
|
||||
|
||||
def generate_audio_embeddings(self, audio: AUDIO) -> np.ndarray:
|
||||
torch = attempt_import_or_raise("torch")
|
||||
inputs = {
|
||||
self._ModalityType.AUDIO: self._data.load_and_transform_audio_data(
|
||||
audio, self.device
|
||||
)
|
||||
}
|
||||
with torch.no_grad():
|
||||
audio_features = self.embedding_model(inputs)[self._ModalityType.AUDIO]
|
||||
if self.normalize:
|
||||
audio_features /= audio_features.norm(dim=-1, keepdim=True)
|
||||
return audio_features.cpu().numpy().squeeze()
|
||||
|
||||
def generate_text_embeddings(self, text: TEXT) -> np.ndarray:
|
||||
torch = attempt_import_or_raise("torch")
|
||||
inputs = {
|
||||
self._ModalityType.TEXT: self._data.load_and_transform_text(
|
||||
text, self.device
|
||||
)
|
||||
}
|
||||
with torch.no_grad():
|
||||
text_features = self.embedding_model(inputs)[self._ModalityType.TEXT]
|
||||
if self.normalize:
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
return text_features.cpu().numpy().squeeze()
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, source: Union[IMAGES, AUDIO], *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given sourcefield column in the pydantic model.
|
||||
"""
|
||||
source = self.sanitize_input(source)
|
||||
embeddings = []
|
||||
if source[0].endswith(self._audio_extensions):
|
||||
embeddings.extend(self.generate_audio_embeddings(source))
|
||||
return embeddings
|
||||
elif source[0].endswith(self._image_extensions):
|
||||
embeddings.extend(self.generate_image_embeddings(source))
|
||||
return embeddings
|
||||
else:
|
||||
embeddings.extend(self.generate_text_embeddings(source))
|
||||
return embeddings
|
||||
|
||||
def sanitize_input(
|
||||
self, input: Union[IMAGES, AUDIO]
|
||||
) -> Union[List[bytes], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(input, (str, bytes)):
|
||||
input = [input]
|
||||
elif isinstance(input, pa.Array):
|
||||
input = input.to_pylist()
|
||||
elif isinstance(input, pa.ChunkedArray):
|
||||
input = input.combine_chunks().to_pylist()
|
||||
return input
|
||||
|
||||
def get_embedding_model(self):
|
||||
"""
|
||||
fetches the imagebind embedding model
|
||||
"""
|
||||
imagebind = attempt_import_or_raise("imagebind", "imagebind")
|
||||
model = imagebind.imagebind_model.imagebind_huge(pretrained=True)
|
||||
model.eval()
|
||||
model.to(self.device)
|
||||
return model
|
||||
@@ -103,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||
|
||||
source_instruction: str = "represent the document for retrieval"
|
||||
query_instruction: (
|
||||
str
|
||||
) = "represent the document for retrieving the most similar documents"
|
||||
query_instruction: str = (
|
||||
"represent the document for retrieving the most similar documents"
|
||||
)
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def ndims(self):
|
||||
@@ -36,6 +36,7 @@ TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
IMAGES = Union[
|
||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||
]
|
||||
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
|
||||
|
||||
@deprecated
|
||||
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Full text search index using tantivy-py"""
|
||||
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
@@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import deprecation
|
||||
import numpy as np
|
||||
@@ -93,7 +93,7 @@ class Query(pydantic.BaseModel):
|
||||
metric: str = "L2"
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[List[str]] = None
|
||||
columns: Optional[Union[List[str], Dict[str, str]]] = None
|
||||
|
||||
# optional query parameters for tuning the results,
|
||||
# e.g. `{"nprobes": "10", "refine_factor": "10"}`
|
||||
@@ -332,20 +332,25 @@ class LanceQueryBuilder(ABC):
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceQueryBuilder:
|
||||
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
columns: list of str, or dict of str to str default None
|
||||
List of column names to be fetched.
|
||||
Or a dictionary of column names to SQL expressions.
|
||||
All columns are fetched if None or unspecified.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._columns = columns
|
||||
if isinstance(columns, list) or isinstance(columns, dict):
|
||||
self._columns = columns
|
||||
else:
|
||||
raise ValueError("columns must be a list or a dictionary")
|
||||
return self
|
||||
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
|
||||
@@ -403,7 +408,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
>>> (table.search([0.4, 0.4])
|
||||
... .metric("cosine")
|
||||
... .where("b < 10")
|
||||
... .select(["b"])
|
||||
... .select(["b", "vector"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
b vector _distance
|
||||
@@ -15,7 +15,7 @@ import logging
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Iterable, Optional, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from lance import json_to_schema
|
||||
@@ -66,12 +66,36 @@ class RemoteTable(Table):
|
||||
"""to_pandas() is not yet supported on LanceDB cloud."""
|
||||
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
|
||||
|
||||
def create_scalar_index(self, *args, **kwargs):
|
||||
"""Creates a scalar index"""
|
||||
return NotImplementedError(
|
||||
"create_scalar_index() is not yet supported on LanceDB cloud."
|
||||
def list_indices(self):
|
||||
"""List all the indices on the table"""
|
||||
print(self._name)
|
||||
resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
|
||||
return resp
|
||||
|
||||
def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
):
|
||||
"""Creates a scalar index
|
||||
Parameters
|
||||
----------
|
||||
column : str
|
||||
The column to be indexed. Must be a boolean, integer, float,
|
||||
or string column.
|
||||
"""
|
||||
index_type = "scalar"
|
||||
|
||||
data = {
|
||||
"column": column,
|
||||
"index_type": index_type,
|
||||
"replace": True,
|
||||
}
|
||||
resp = self._conn._client.post(
|
||||
f"/v1/table/{self._name}/create_scalar_index/", data=data
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
@@ -277,6 +301,7 @@ class RemoteTable(Table):
|
||||
f = Future()
|
||||
f.set_result(self._conn._client.query(name, q))
|
||||
return f
|
||||
|
||||
else:
|
||||
|
||||
def submit(name, q):
|
||||
@@ -473,6 +498,21 @@ class RemoteTable(Table):
|
||||
"count_rows() is not yet supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
def add_columns(self, transforms: Dict[str, str]):
|
||||
raise NotImplementedError(
|
||||
"add_columns() is not yet supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
def alter_columns(self, alterations: Iterable[Dict[str, str]]):
|
||||
raise NotImplementedError(
|
||||
"alter_columns() is not yet supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
def drop_columns(self, columns: Iterable[str]):
|
||||
raise NotImplementedError(
|
||||
"drop_columns() is not yet supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
|
||||
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
||||
return tbl.add_column(
|
||||
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Schema related utilities."""
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import pyarrow.compute as pc
|
||||
import pyarrow.fs as pa_fs
|
||||
from lance import LanceDataset
|
||||
from lance.vector import vec_to_table
|
||||
from overrides import override
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
|
||||
import PIL
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
|
||||
from ._lancedb import Table as LanceDBTable
|
||||
from .db import LanceDBConnection
|
||||
|
||||
|
||||
@@ -160,7 +162,7 @@ class Table(ABC):
|
||||
|
||||
Can query the table with [Table.search][lancedb.table.Table.search].
|
||||
|
||||
>>> table.search([0.4, 0.4]).select(["b"]).to_pandas()
|
||||
>>> table.search([0.4, 0.4]).select(["b", "vector"]).to_pandas()
|
||||
b vector _distance
|
||||
0 4 [0.5, 1.3] 0.82
|
||||
1 2 [1.1, 1.2] 1.13
|
||||
@@ -438,7 +440,7 @@ class Table(ABC):
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query)
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width"])
|
||||
... .select(["caption", "original_width", "vector"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
caption original_width vector _distance
|
||||
@@ -662,6 +664,56 @@ class Table(ABC):
|
||||
For most cases, the default should be fine.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def add_columns(self, transforms: Dict[str, str]):
|
||||
"""
|
||||
Add new columns with defined values.
|
||||
|
||||
This is not yet available in LanceDB Cloud.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transforms: Dict[str, str]
|
||||
A map of column name to a SQL expression to use to calculate the
|
||||
value of the new column. These expressions will be evaluated for
|
||||
each row in the table, and can reference existing columns.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def alter_columns(self, alterations: Iterable[Dict[str, str]]):
|
||||
"""
|
||||
Alter column names and nullability.
|
||||
|
||||
This is not yet available in LanceDB Cloud.
|
||||
|
||||
alterations : Iterable[Dict[str, Any]]
|
||||
A sequence of dictionaries, each with the following keys:
|
||||
- "path": str
|
||||
The column path to alter. For a top-level column, this is the name.
|
||||
For a nested column, this is the dot-separated path, e.g. "a.b.c".
|
||||
- "name": str, optional
|
||||
The new name of the column. If not specified, the column name is
|
||||
not changed.
|
||||
- "nullable": bool, optional
|
||||
Whether the column should be nullable. If not specified, the column
|
||||
nullability is not changed. Only non-nullable columns can be changed
|
||||
to nullable. Currently, you cannot change a nullable column to
|
||||
non-nullable.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def drop_columns(self, columns: Iterable[str]):
|
||||
"""
|
||||
Drop columns from the table.
|
||||
|
||||
This is not yet available in LanceDB Cloud.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns : Iterable[str]
|
||||
The names of the columns to drop.
|
||||
"""
|
||||
|
||||
|
||||
class _LanceDatasetRef(ABC):
|
||||
@property
|
||||
@@ -1223,7 +1275,7 @@ class LanceTable(Table):
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query)
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width"])
|
||||
... .select(["caption", "original_width", "vector"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
caption original_width vector _distance
|
||||
@@ -1550,6 +1602,22 @@ class LanceTable(Table):
|
||||
"""
|
||||
return self.to_lance().optimize.compact_files(*args, **kwargs)
|
||||
|
||||
def add_columns(self, transforms: Dict[str, str]):
|
||||
self._dataset_mut.add_columns(transforms)
|
||||
|
||||
def alter_columns(self, *alterations: Iterable[Dict[str, str]]):
|
||||
modified = []
|
||||
# I called this name in pylance, but I think I regret that now. So we
|
||||
# allow both name and rename.
|
||||
for alter in alterations:
|
||||
if "rename" in alter:
|
||||
alter["name"] = alter.pop("rename")
|
||||
modified.append(alter)
|
||||
self._dataset_mut.alter_columns(*modified)
|
||||
|
||||
def drop_columns(self, columns: Iterable[str]):
|
||||
self._dataset_mut.drop_columns(columns)
|
||||
|
||||
|
||||
def _sanitize_schema(
|
||||
data: pa.Table,
|
||||
@@ -1728,3 +1796,715 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
|
||||
data = data.filter(is_full)
|
||||
return data
|
||||
|
||||
|
||||
class AsyncTable(ABC):
|
||||
"""
|
||||
A Table is a collection of Records in a LanceDB Database.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Create using [DBConnection.create_table][lancedb.DBConnection.create_table]
|
||||
(more examples in that method's documentation).
|
||||
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}])
|
||||
>>> table.head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
b: int64
|
||||
----
|
||||
vector: [[[1.1,1.2]]]
|
||||
b: [[2]]
|
||||
|
||||
Can append new data with [Table.add()][lancedb.table.Table.add].
|
||||
|
||||
>>> table.add([{"vector": [0.5, 1.3], "b": 4}])
|
||||
|
||||
Can query the table with [Table.search][lancedb.table.Table.search].
|
||||
|
||||
>>> table.search([0.4, 0.4]).select(["b", "vector"]).to_pandas()
|
||||
b vector _distance
|
||||
0 4 [0.5, 1.3] 0.82
|
||||
1 2 [1.1, 1.2] 1.13
|
||||
|
||||
Search queries are much faster when an index is created. See
|
||||
[Table.create_index][lancedb.table.Table.create_index].
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""The name of the table."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def schema(self) -> pa.Schema:
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||
of this Table
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
Count the number of rows in the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter: str, optional
|
||||
A SQL where clause to filter the rows to count.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
"""
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
async def to_arrow(self) -> pa.Table:
|
||||
"""Return the table as a pyarrow Table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.Table
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: str, default "L2"
|
||||
The distance metric to use when creating the index.
|
||||
Valid values are "L2", "cosine", or "dot".
|
||||
L2 is euclidean distance.
|
||||
num_partitions: int, default 256
|
||||
The number of IVF partitions to use when creating the index.
|
||||
Default is 256.
|
||||
num_sub_vectors: int, default 96
|
||||
The number of PQ sub-vectors to use when creating the index.
|
||||
Default is 96.
|
||||
vector_column_name: str, default "vector"
|
||||
The vector column name to create the index.
|
||||
replace: bool, default True
|
||||
- If True, replace the existing index if it exists.
|
||||
|
||||
- If False, raise an error if duplicate index exists.
|
||||
accelerator: str, default None
|
||||
If set, use the given accelerator to create the index.
|
||||
Only support "cuda" for now.
|
||||
index_cache_size : int, optional
|
||||
The size of the index cache in number of entries. Default value is 256.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = True,
|
||||
):
|
||||
"""Create a scalar index on a column.
|
||||
|
||||
Scalar indices, like vector indices, can be used to speed up scans. A scalar
|
||||
index can speed up scans that contain filter expressions on the indexed column.
|
||||
For example, the following scan will be faster if the column ``my_col`` has
|
||||
a scalar index:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import lancedb
|
||||
|
||||
db = lancedb.connect("/data/lance")
|
||||
img_table = db.open_table("images")
|
||||
my_df = img_table.search().where("my_col = 7", prefilter=True).to_pandas()
|
||||
|
||||
Scalar indices can also speed up scans containing a vector search and a
|
||||
prefilter:
|
||||
|
||||
.. code-block::python
|
||||
|
||||
import lancedb
|
||||
|
||||
db = lancedb.connect("/data/lance")
|
||||
img_table = db.open_table("images")
|
||||
img_table.search([1, 2, 3, 4], vector_column_name="vector")
|
||||
.where("my_col != 7", prefilter=True)
|
||||
.to_pandas()
|
||||
|
||||
Scalar indices can only speed up scans for basic filters using
|
||||
equality, comparison, range (e.g. ``my_col BETWEEN 0 AND 100``), and set
|
||||
membership (e.g. `my_col IN (0, 1, 2)`)
|
||||
|
||||
Scalar indices can be used if the filter contains multiple indexed columns and
|
||||
the filter criteria are AND'd or OR'd together
|
||||
(e.g. ``my_col < 0 AND other_col> 100``)
|
||||
|
||||
Scalar indices may be used if the filter contains non-indexed columns but,
|
||||
depending on the structure of the filter, they may not be usable. For example,
|
||||
if the column ``not_indexed`` does not have a scalar index then the filter
|
||||
``my_col = 0 OR not_indexed = 1`` will not be able to use any scalar index on
|
||||
``my_col``.
|
||||
|
||||
**Experimental API**
|
||||
|
||||
Parameters
|
||||
----------
|
||||
column : str
|
||||
The column to be indexed. Must be a boolean, integer, float,
|
||||
or string column.
|
||||
replace : bool, default True
|
||||
Replace the existing index if it exists.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import lance
|
||||
|
||||
dataset = lance.dataset("./images.lance")
|
||||
dataset.create_scalar_index("category")
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""Add more data to the [Table](Table).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: DATA
|
||||
The data to insert into the table. Acceptable types are:
|
||||
|
||||
- dict or list-of-dict
|
||||
|
||||
- pandas.DataFrame
|
||||
|
||||
- pyarrow.Table or pyarrow.RecordBatch
|
||||
mode: str
|
||||
The mode to use when writing the data. Valid values are
|
||||
"append" and "overwrite".
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||
that can be used to create a "merge insert" operation
|
||||
|
||||
This operation can add rows, update rows, and remove rows all in a single
|
||||
transaction. It is a very generic tool that can be used to create
|
||||
behaviors like "insert if not exists", "update or insert (i.e. upsert)",
|
||||
or even replace a portion of existing data with new data (e.g. replace
|
||||
all data where month="january")
|
||||
|
||||
The merge insert operation works by combining new data from a
|
||||
**source table** with existing data in a **target table** by using a
|
||||
join. There are three categories of records.
|
||||
|
||||
"Matched" records are records that exist in both the source table and
|
||||
the target table. "Not matched" records exist only in the source table
|
||||
(e.g. these are new data) "Not matched by source" records exist only
|
||||
in the target table (this is old data)
|
||||
|
||||
The builder returned by this method can be used to customize what
|
||||
should happen for each category of data.
|
||||
|
||||
Please note that the data may appear to be reordered as part of this
|
||||
operation. This is because updated rows will be deleted from the
|
||||
dataset and then reinserted at the end with the new values.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
on: Union[str, Iterable[str]]
|
||||
A column (or columns) to join on. This is how records from the
|
||||
source table and target table are matched. Typically this is some
|
||||
kind of key or id column.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = pa.table({"a": [2, 1, 3], "b": ["a", "b", "c"]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
>>> # Perform a "upsert" operation
|
||||
>>> table.merge_insert("a") \\
|
||||
... .when_matched_update_all() \\
|
||||
... .when_not_matched_insert_all() \\
|
||||
... .execute(new_data)
|
||||
>>> # The order of new rows is non-deterministic since we use
|
||||
>>> # a hash-join as part of this operation and so we sort here
|
||||
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||
a b
|
||||
0 1 b
|
||||
1 2 x
|
||||
2 3 y
|
||||
3 4 z
|
||||
"""
|
||||
on = [on] if isinstance(on, str) else list(on.iter())
|
||||
|
||||
return LanceMergeInsertBuilder(self, on)
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
and [full-text search][experimental-full-text-search].
|
||||
|
||||
All query options are defined in [Query][lancedb.query.Query].
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> data = [
|
||||
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
|
||||
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
|
||||
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
|
||||
... ]
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query)
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width", "vector"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
caption original_width vector _distance
|
||||
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
|
||||
1 test 3000 [0.3, 6.2, 2.6] 23.089996
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: list/np.ndarray/str/PIL.Image.Image, default None
|
||||
The targetted vector to search for.
|
||||
|
||||
- *default None*.
|
||||
Acceptable types are: list, np.ndarray, PIL.Image.Image
|
||||
|
||||
- If None then the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str, optional
|
||||
The name of the vector column to search.
|
||||
|
||||
The vector column needs to be a pyarrow fixed size list type
|
||||
|
||||
- If not specified then the vector column is inferred from
|
||||
the table schema
|
||||
|
||||
- If the table has multiple vector columns then the *vector_column_name*
|
||||
needs to be specified. Otherwise, an error is raised.
|
||||
query_type: str
|
||||
*default "auto"*.
|
||||
Acceptable types are: "vector", "fts", "hybrid", or "auto"
|
||||
|
||||
- If "auto" then the query type is inferred from the query;
|
||||
|
||||
- If `query` is a list/np.ndarray then the query type is
|
||||
"vector";
|
||||
|
||||
- If `query` is a PIL.Image.Image then either do vector search,
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
|
||||
- If `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
A query builder object representing the query.
|
||||
Once executed, the query returns
|
||||
|
||||
- selected columns
|
||||
|
||||
- the vector
|
||||
|
||||
- and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_query(self, query: Query) -> pa.Table:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, where: str):
|
||||
"""Delete rows from the table.
|
||||
|
||||
This can be used to delete a single row, many rows, all rows, or
|
||||
sometimes no rows (if your predicate matches nothing).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The SQL where clause to use when deleting rows.
|
||||
|
||||
- For example, 'x = 2' or 'x IN (1, 2, 3)'.
|
||||
|
||||
The filter must not be empty, or it will error.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.delete("x = 2")
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 3 [5.0, 6.0]
|
||||
|
||||
If you have a list of values to delete, you can combine them into a
|
||||
stringified list and use the `IN` operator:
|
||||
|
||||
>>> to_remove = [1, 5]
|
||||
>>> to_remove = ", ".join([str(v) for v in to_remove])
|
||||
>>> to_remove
|
||||
'1, 5'
|
||||
>>> table.delete(f"x IN ({to_remove})")
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 3 [5.0, 6.0]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def update(
|
||||
self,
|
||||
where: Optional[str] = None,
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
This can be used to update zero to all rows depending on how many
|
||||
rows match the where clause. If no where clause is provided, then
|
||||
all rows will be updated.
|
||||
|
||||
Either `values` or `values_sql` must be provided. You cannot provide
|
||||
both.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str, optional
|
||||
The SQL where clause to use when updating rows. For example, 'x = 2'
|
||||
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||
values: dict, optional
|
||||
The values to update. The keys are the column names and the values
|
||||
are the values to set.
|
||||
values_sql: dict, optional
|
||||
The values to update, expressed as SQL expression strings. These can
|
||||
reference existing columns. For example, {"x": "x + 1"} will increment
|
||||
the x column by 1.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 3 [5.0, 6.0]
|
||||
2 2 [10.0, 10.0]
|
||||
>>> table.update(values_sql={"x": "x + 1"})
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 2 [1.0, 2.0]
|
||||
1 4 [5.0, 6.0]
|
||||
2 3 [10.0, 10.0]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def cleanup_old_versions(
|
||||
self,
|
||||
older_than: Optional[timedelta] = None,
|
||||
*,
|
||||
delete_unverified: bool = False,
|
||||
) -> CleanupStats:
|
||||
"""
|
||||
Clean up old versions of the table, freeing disk space.
|
||||
|
||||
Note: This function is not available in LanceDb Cloud (since LanceDb
|
||||
Cloud manages cleanup for you automatically)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
older_than: timedelta, default None
|
||||
The minimum age of the version to delete. If None, then this defaults
|
||||
to two weeks.
|
||||
delete_unverified: bool, default False
|
||||
Because they may be part of an in-progress transaction, files newer
|
||||
than 7 days old are not deleted by default. If you are sure that
|
||||
there are no in-progress transactions, then you can set this to True
|
||||
to delete all files older than `older_than`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
CleanupStats
|
||||
The stats of the cleanup operation, including how many bytes were
|
||||
freed.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def compact_files(self, *args, **kwargs):
|
||||
"""
|
||||
Run the compaction process on the table.
|
||||
|
||||
Note: This function is not available in LanceDb Cloud (since LanceDb
|
||||
Cloud manages compaction for you automatically)
|
||||
|
||||
This can be run after making several small appends to optimize the table
|
||||
for faster reads.
|
||||
|
||||
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
|
||||
For most cases, the default should be fine.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def add_columns(self, transforms: Dict[str, str]):
|
||||
"""
|
||||
Add new columns with defined values.
|
||||
|
||||
This is not yet available in LanceDB Cloud.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transforms: Dict[str, str]
|
||||
A map of column name to a SQL expression to use to calculate the
|
||||
value of the new column. These expressions will be evaluated for
|
||||
each row in the table, and can reference existing columns.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
|
||||
"""
|
||||
Alter column names and nullability.
|
||||
|
||||
This is not yet available in LanceDB Cloud.
|
||||
|
||||
alterations : Iterable[Dict[str, Any]]
|
||||
A sequence of dictionaries, each with the following keys:
|
||||
- "path": str
|
||||
The column path to alter. For a top-level column, this is the name.
|
||||
For a nested column, this is the dot-separated path, e.g. "a.b.c".
|
||||
- "name": str, optional
|
||||
The new name of the column. If not specified, the column name is
|
||||
not changed.
|
||||
- "nullable": bool, optional
|
||||
Whether the column should be nullable. If not specified, the column
|
||||
nullability is not changed. Only non-nullable columns can be changed
|
||||
to nullable. Currently, you cannot change a nullable column to
|
||||
non-nullable.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def drop_columns(self, columns: Iterable[str]):
|
||||
"""
|
||||
Drop columns from the table.
|
||||
|
||||
This is not yet available in LanceDB Cloud.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns : Iterable[str]
|
||||
The names of the columns to drop.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncLanceTable(AsyncTable):
|
||||
def __init__(self, table: LanceDBTable):
|
||||
self._inner = table
|
||||
|
||||
@property
|
||||
@override
|
||||
def name(self) -> str:
|
||||
return self._inner.name()
|
||||
|
||||
@override
|
||||
async def schema(self) -> pa.Schema:
|
||||
return await self._inner.schema()
|
||||
|
||||
@override
|
||||
async def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
async def to_pandas(self) -> "pd.DataFrame":
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@override
|
||||
async def to_arrow(self) -> pa.Table:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
replace: bool = True,
|
||||
accelerator: Optional[str] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def create_scalar_index(
|
||||
self,
|
||||
column: str,
|
||||
*,
|
||||
replace: bool = True,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
on = [on] if isinstance(on, str) else list(on.iter())
|
||||
|
||||
return LanceMergeInsertBuilder(self, on)
|
||||
|
||||
@override
|
||||
async def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def _execute_query(self, query: Query) -> pa.Table:
|
||||
pass
|
||||
|
||||
@override
|
||||
async def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
pass
|
||||
|
||||
@override
|
||||
async def delete(self, where: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def update(
|
||||
self,
|
||||
where: Optional[str] = None,
|
||||
values: Optional[dict] = None,
|
||||
*,
|
||||
values_sql: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def cleanup_old_versions(
|
||||
self,
|
||||
older_than: Optional[timedelta] = None,
|
||||
*,
|
||||
delete_unverified: bool = False,
|
||||
) -> CleanupStats:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def compact_files(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def add_columns(self, transforms: Dict[str, str]):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def drop_columns(self, columns: Iterable[str]):
|
||||
raise NotImplementedError
|
||||
@@ -1,5 +1,4 @@
|
||||
from click.testing import CliRunner
|
||||
|
||||
from lancedb.cli.cli import cli
|
||||
from lancedb.utils import CONFIG
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from lancedb.context import contextualize
|
||||
|
||||
|
||||
@@ -11,12 +11,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
@@ -166,6 +165,24 @@ def test_table_names(tmp_path):
|
||||
assert db.table_names() == ["test1", "test2", "test3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_names_async(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
db.create_table("test2", data=data)
|
||||
db.create_table("test1", data=data)
|
||||
db.create_table("test3", data=data)
|
||||
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
assert await db.table_names() == ["test1", "test2", "test3"]
|
||||
|
||||
|
||||
def test_create_mode(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
@@ -233,6 +250,78 @@ def test_create_exist_ok(tmp_path):
|
||||
db.create_table("test", schema=bad_schema, exist_ok=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mode_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
await db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await db.create_table("test", data=data)
|
||||
|
||||
new_data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["fizz", "buzz"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
_tbl = await db.create_table("test", data=new_data, mode="overwrite")
|
||||
|
||||
# MIGRATION: to_pandas() is not available in async
|
||||
# assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_exist_ok_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
tbl = await db.create_table("test", data=data)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await db.create_table("test", data=data)
|
||||
|
||||
# open the table but don't add more rows
|
||||
tbl2 = await db.create_table("test", data=data, exist_ok=True)
|
||||
assert tbl.name == tbl2.name
|
||||
assert await tbl.schema() == await tbl2.schema()
|
||||
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
pa.field("item", pa.utf8()),
|
||||
pa.field("price", pa.float64()),
|
||||
]
|
||||
)
|
||||
tbl3 = await db.create_table("test", schema=schema, exist_ok=True)
|
||||
assert await tbl3.schema() == schema
|
||||
|
||||
# Migration: When creating a table, but the table already exists, but
|
||||
# the schema is different, it should raise an error.
|
||||
# bad_schema = pa.schema(
|
||||
# [
|
||||
# pa.field("vector", pa.list_(pa.float32(), list_size=2)),
|
||||
# pa.field("item", pa.utf8()),
|
||||
# pa.field("price", pa.float64()),
|
||||
# pa.field("extra", pa.float32()),
|
||||
# ]
|
||||
# )
|
||||
# with pytest.raises(ValueError):
|
||||
# await db.create_table("test", schema=bad_schema, exist_ok=True)
|
||||
|
||||
|
||||
def test_delete_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lancedb import LanceDBConnection
|
||||
|
||||
# TODO: setup integ test mark and script
|
||||
@@ -13,11 +13,10 @@
|
||||
import sys
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.embeddings import (
|
||||
EmbeddingFunctionConfig,
|
||||
@@ -14,12 +14,11 @@ import importlib
|
||||
import io
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
import lancedb
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
@@ -28,6 +27,23 @@ from lancedb.pydantic import LanceModel, Vector
|
||||
# or connection to external api
|
||||
|
||||
|
||||
try:
|
||||
if importlib.util.find_spec("mlx.core") is not None:
|
||||
_mlx = True
|
||||
else:
|
||||
_mlx = None
|
||||
except Exception:
|
||||
_mlx = None
|
||||
|
||||
try:
|
||||
if importlib.util.find_spec("imagebind") is not None:
|
||||
_imagebind = True
|
||||
else:
|
||||
_imagebind = None
|
||||
except Exception:
|
||||
_imagebind = None
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||
def test_basic_text_embeddings(alias, tmp_path):
|
||||
@@ -158,6 +174,88 @@ def test_openclip(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_imagebind is None,
|
||||
reason="skip if imagebind not installed.",
|
||||
)
|
||||
@pytest.mark.slow
|
||||
def test_imagebind(tmp_path):
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import lancedb.embeddings.imagebind
|
||||
import pandas as pd
|
||||
import requests
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
print(f"Created temporary directory {temp_dir}")
|
||||
|
||||
def download_images(image_uris):
|
||||
downloaded_image_paths = []
|
||||
for uri in image_uris:
|
||||
try:
|
||||
response = requests.get(uri, stream=True)
|
||||
if response.status_code == 200:
|
||||
# Extract image name from URI
|
||||
image_name = os.path.basename(uri)
|
||||
image_path = os.path.join(temp_dir, image_name)
|
||||
with open(image_path, "wb") as out_file:
|
||||
shutil.copyfileobj(response.raw, out_file)
|
||||
downloaded_image_paths.append(image_path)
|
||||
except Exception as e: # noqa: PERF203
|
||||
print(f"Failed to download {uri}. Error: {e}")
|
||||
return temp_dir, downloaded_image_paths
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
func = registry.get("imagebind").create(max_retries=0)
|
||||
|
||||
class Images(LanceModel):
|
||||
label: str
|
||||
image_uri: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("images", schema=Images)
|
||||
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||
uris = [
|
||||
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||
]
|
||||
temp_dir, downloaded_images = download_images(uris)
|
||||
table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))
|
||||
# text search
|
||||
actual = (
|
||||
table.search("man's best friend", vector_column_name="vector")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == "dog"
|
||||
|
||||
# image search
|
||||
query_image_uri = [
|
||||
"https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
|
||||
]
|
||||
temp_dir, downloaded_images = download_images(query_image_uri)
|
||||
query_image_uri = downloaded_images[0]
|
||||
actual = (
|
||||
table.search(query_image_uri, vector_column_name="vector")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == "dog"
|
||||
|
||||
if os.path.isdir(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
print(f"Deleted temporary directory {temp_dir}")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
@@ -217,13 +315,6 @@ def test_gemini_embedding(tmp_path):
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
|
||||
try:
|
||||
if importlib.util.find_spec("mlx.core") is not None:
|
||||
_mlx = True
|
||||
except ImportError:
|
||||
_mlx = None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_mlx is None,
|
||||
reason="mlx tests only required for apple users.",
|
||||
@@ -14,13 +14,13 @@ import os
|
||||
import random
|
||||
from unittest import mock
|
||||
|
||||
import lancedb as ldb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import tantivy
|
||||
|
||||
import lancedb as ldb
|
||||
import lancedb.fts
|
||||
pytest.importorskip("lancedb.fts")
|
||||
tantivy = pytest.importorskip("tantivy")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -13,9 +13,8 @@
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
import pytest
|
||||
|
||||
# You need to setup AWS credentials an a base path to run this test. Example
|
||||
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
|
||||
@@ -20,9 +20,8 @@ from typing import List, Optional, Tuple
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -18,7 +18,6 @@ import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import LanceVectorQueryBuilder, Query
|
||||
@@ -88,13 +87,24 @@ def test_query_builder(table):
|
||||
rs = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.limit(1)
|
||||
.select(["id"])
|
||||
.select(["id", "vector"])
|
||||
.to_list()
|
||||
)
|
||||
assert rs[0]["id"] == 1
|
||||
assert all(np.array(rs[0]["vector"]) == [1, 2])
|
||||
|
||||
|
||||
def test_dynamic_projection(table):
|
||||
rs = (
|
||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||
.limit(1)
|
||||
.select({"id": "id", "id2": "id * 2"})
|
||||
.to_list()
|
||||
)
|
||||
assert rs[0]["id"] == 1
|
||||
assert rs[0]["id2"] == 2
|
||||
|
||||
|
||||
def test_query_builder_with_filter(table):
|
||||
rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
|
||||
assert rs[0]["id"] == 2
|
||||
@@ -17,7 +17,6 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
||||
|
||||
|
||||
@@ -11,9 +11,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
from lancedb.remote.client import VectorQuery, VectorQueryResult
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
@@ -15,6 +14,9 @@ from lancedb.rerankers import (
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
# Tests rely on FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
|
||||
def get_test_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -20,19 +20,18 @@ from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MockDB:
|
||||
@@ -804,6 +803,9 @@ def test_count_rows(db):
|
||||
|
||||
|
||||
def test_hybrid_search(db, tmp_path):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
db = MockDB(str(tmp_path))
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
@@ -898,3 +900,29 @@ def test_restore_consistency(tmp_path):
|
||||
table.add([{"id": 2}])
|
||||
assert table_fixed.version == table.version - 1
|
||||
assert table_ref_latest.version == table.version
|
||||
|
||||
|
||||
# Schema evolution
|
||||
def test_add_columns(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table({"id": [0, 1]})
|
||||
table = LanceTable.create(db, "my_table", data=data)
|
||||
table.add_columns({"new_col": "id + 2"})
|
||||
assert table.to_arrow().column_names == ["id", "new_col"]
|
||||
assert table.to_arrow()["new_col"].to_pylist() == [2, 3]
|
||||
|
||||
|
||||
def test_alter_columns(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table({"id": [0, 1]})
|
||||
table = LanceTable.create(db, "my_table", data=data)
|
||||
table.alter_columns({"path": "id", "rename": "new_id"})
|
||||
assert table.to_arrow().column_names == ["new_id"]
|
||||
|
||||
|
||||
def test_drop_columns(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table({"id": [0, 1], "category": ["a", "b"]})
|
||||
table = LanceTable.create(db, "my_table", data=data)
|
||||
table.drop_columns(["category"])
|
||||
assert table.to_arrow().column_names == ["id"]
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
import pytest
|
||||
from lancedb.utils.events import _Events
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
from lancedb.util import get_uri_scheme, join_uri
|
||||
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import setuptools
|
||||
|
||||
if __name__ == "__main__":
|
||||
setuptools.setup()
|
||||
125
python/src/connection.rs
Normal file
125
python/src/connection.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::connection::{Connection as LanceConnection, CreateTableMode};
|
||||
use pyo3::{
|
||||
exceptions::PyValueError, pyclass, pyfunction, pymethods, PyAny, PyRef, PyResult, Python,
|
||||
};
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::{error::PythonErrorExt, table::Table};
|
||||
|
||||
#[pyclass]
|
||||
pub struct Connection {
|
||||
inner: LanceConnection,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
fn parse_create_mode_str(mode: &str) -> PyResult<CreateTableMode> {
|
||||
match mode {
|
||||
"create" => Ok(CreateTableMode::Create),
|
||||
"overwrite" => Ok(CreateTableMode::Overwrite),
|
||||
"exist_ok" => Ok(CreateTableMode::exist_ok(|builder| builder)),
|
||||
_ => Err(PyValueError::new_err(format!("Invalid mode {}", mode))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Connection {
|
||||
pub fn table_names(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.table_names().await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_table<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
name: String,
|
||||
mode: &str,
|
||||
data: &PyAny,
|
||||
) -> PyResult<&'a PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
|
||||
let mode = Self::parse_create_mode_str(mode)?;
|
||||
|
||||
let batches = Box::new(ArrowArrayStreamReader::from_pyarrow(data)?);
|
||||
future_into_py(self_.py(), async move {
|
||||
let table = inner
|
||||
.create_table(name, batches)
|
||||
.mode(mode)
|
||||
.execute()
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_empty_table<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
name: String,
|
||||
mode: &str,
|
||||
schema: &PyAny,
|
||||
) -> PyResult<&'a PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
|
||||
let mode = Self::parse_create_mode_str(mode)?;
|
||||
|
||||
let schema = Schema::from_pyarrow(schema)?;
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let table = inner
|
||||
.create_empty_table(name, Arc::new(schema))
|
||||
.mode(mode)
|
||||
.execute()
|
||||
.await
|
||||
.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn connect(
|
||||
py: Python,
|
||||
uri: String,
|
||||
api_key: Option<String>,
|
||||
region: Option<String>,
|
||||
host_override: Option<String>,
|
||||
read_consistency_interval: Option<f64>,
|
||||
) -> PyResult<&PyAny> {
|
||||
future_into_py(py, async move {
|
||||
let mut builder = lancedb::connect(&uri);
|
||||
if let Some(api_key) = api_key {
|
||||
builder = builder.api_key(&api_key);
|
||||
}
|
||||
if let Some(region) = region {
|
||||
builder = builder.region(®ion);
|
||||
}
|
||||
if let Some(host_override) = host_override {
|
||||
builder = builder.host_override(&host_override);
|
||||
}
|
||||
if let Some(read_consistency_interval) = read_consistency_interval {
|
||||
let read_consistency_interval = Duration::from_secs_f64(read_consistency_interval);
|
||||
builder = builder.read_consistency_interval(read_consistency_interval);
|
||||
}
|
||||
Ok(Connection {
|
||||
inner: builder.execute().await.infer_error()?,
|
||||
})
|
||||
})
|
||||
}
|
||||
64
python/src/error.rs
Normal file
64
python/src/error.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use pyo3::{
|
||||
exceptions::{PyOSError, PyRuntimeError, PyValueError},
|
||||
PyResult,
|
||||
};
|
||||
|
||||
use lancedb::error::Error as LanceError;
|
||||
|
||||
pub trait PythonErrorExt<T> {
|
||||
/// Convert to a python error based on the Lance error type
|
||||
fn infer_error(self) -> PyResult<T>;
|
||||
/// Convert to OSError
|
||||
fn os_error(self) -> PyResult<T>;
|
||||
/// Convert to RuntimeError
|
||||
fn runtime_error(self) -> PyResult<T>;
|
||||
/// Convert to ValueError
|
||||
fn value_error(self) -> PyResult<T>;
|
||||
}
|
||||
|
||||
impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
|
||||
fn infer_error(self) -> PyResult<T> {
|
||||
match &self {
|
||||
Ok(_) => Ok(self.unwrap()),
|
||||
Err(err) => match err {
|
||||
LanceError::InvalidInput { .. } => self.value_error(),
|
||||
LanceError::InvalidTableName { .. } => self.value_error(),
|
||||
LanceError::TableNotFound { .. } => self.value_error(),
|
||||
LanceError::Schema { .. } => self.value_error(),
|
||||
LanceError::CreateDir { .. } => self.os_error(),
|
||||
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
|
||||
LanceError::Store { .. } => self.runtime_error(),
|
||||
LanceError::Lance { .. } => self.runtime_error(),
|
||||
LanceError::Runtime { .. } => self.runtime_error(),
|
||||
LanceError::Http { .. } => self.runtime_error(),
|
||||
LanceError::Arrow { .. } => self.runtime_error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn os_error(self) -> PyResult<T> {
|
||||
self.map_err(|err| PyOSError::new_err(err.to_string()))
|
||||
}
|
||||
|
||||
fn runtime_error(self) -> PyResult<T> {
|
||||
self.map_err(|err| PyRuntimeError::new_err(err.to_string()))
|
||||
}
|
||||
|
||||
fn value_error(self) -> PyResult<T> {
|
||||
self.map_err(|err| PyValueError::new_err(err.to_string()))
|
||||
}
|
||||
}
|
||||
33
python/src/lib.rs
Normal file
33
python/src/lib.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
||||
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
pub mod table;
|
||||
|
||||
#[pymodule]
|
||||
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
let env = Env::new()
|
||||
.filter_or("LANCEDB_LOG", "warn")
|
||||
.write_style("LANCEDB_LOG_STYLE");
|
||||
env_logger::init_from_env(env);
|
||||
m.add_class::<Connection>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
}
|
||||
34
python/src/table.rs
Normal file
34
python/src/table.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::pyarrow::ToPyArrow;
|
||||
use lancedb::table::Table as LanceTable;
|
||||
use pyo3::{pyclass, pymethods, PyAny, PyRef, PyResult, Python};
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
|
||||
#[pyclass]
|
||||
pub struct Table {
|
||||
inner: Arc<dyn LanceTable>,
|
||||
}
|
||||
|
||||
impl Table {
|
||||
pub(crate) fn new(inner: Arc<dyn LanceTable>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Table {
|
||||
pub fn name(&self) -> String {
|
||||
self.inner.name().to_string()
|
||||
}
|
||||
|
||||
pub fn schema(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let schema = inner.schema().await.infer_error()?;
|
||||
Python::with_gil(|py| schema.to_pyarrow(py))
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user