mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 02:42:57 +00:00
feat: add a basic async python client starting point (#1014)
This changes `lancedb` from a "pure python" setuptools project to a maturin project and adds a rust lancedb dependency. The async python client is extremely minimal (only `connect` and `Connection.table_names` are supported). The purpose of this PR is to get the infrastructure in place for building out the rest of the async client. Although this is not technically a breaking change (no APIs are changing) it is still a considerable change in the way the wheels are built because they now include the native shared library.
This commit is contained in:
26
python/Cargo.toml
Normal file
26
python/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.4.10"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
|
||||
|
||||
[lib]
|
||||
name = "_lancedb"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
lancedb = { path = "../rust/lancedb" }
|
||||
env_logger = "0.10"
|
||||
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
|
||||
pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
|
||||
|
||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||
lzma-sys = { version = "*", features = ["static"] }
|
||||
|
||||
[build-dependencies]
|
||||
pyo3-build-config = { version = "0.20.3", features = ["extension-module", "abi3-py38"] }
|
||||
@@ -20,10 +20,10 @@ results = table.search([0.1, 0.3]).limit(20).to_list()
|
||||
print(results)
|
||||
```
|
||||
|
||||
|
||||
## Development
|
||||
|
||||
Create a virtual environment and activate it:
|
||||
LanceDb is based on the rust crate `lancedb` and is built with maturin. In order to build with maturin
|
||||
you will either need a conda environment or a virtual environment (venv).
|
||||
|
||||
```bash
|
||||
python -m venv venv
|
||||
@@ -33,7 +33,15 @@ python -m venv venv
|
||||
Install the necessary packages:
|
||||
|
||||
```bash
|
||||
python -m pip install .
|
||||
python -m pip install .[tests,dev]
|
||||
```
|
||||
|
||||
To build the python package you can use maturin:
|
||||
|
||||
```bash
|
||||
# This will build the rust bindings and place them in the appropriate place
|
||||
# in your venv or conda environment
|
||||
matruin develop
|
||||
```
|
||||
|
||||
To run the unit tests:
|
||||
@@ -45,7 +53,7 @@ pytest
|
||||
To run the doc tests:
|
||||
|
||||
```bash
|
||||
pytest --doctest-modules lancedb
|
||||
pytest --doctest-modules python/lancedb
|
||||
```
|
||||
|
||||
To run linter and automatically fix all errors:
|
||||
@@ -61,31 +69,27 @@ If any packages are missing, install them with:
|
||||
pip install <PACKAGE_NAME>
|
||||
```
|
||||
|
||||
|
||||
___
|
||||
For **Windows** users, there may be errors when installing packages, so these commands may be helpful:
|
||||
|
||||
Activate the virtual environment:
|
||||
|
||||
```bash
|
||||
. .\venv\Scripts\activate
|
||||
```
|
||||
|
||||
You may need to run the installs separately:
|
||||
|
||||
```bash
|
||||
pip install -e .[tests]
|
||||
pip install -e .[dev]
|
||||
```
|
||||
|
||||
|
||||
`tantivy` requires `rust` to be installed, so install it with `conda`, as it doesn't support windows installation:
|
||||
|
||||
```bash
|
||||
pip install wheel
|
||||
pip install cargo
|
||||
conda install rust
|
||||
pip install tantivy
|
||||
```
|
||||
|
||||
To run the unit tests:
|
||||
```bash
|
||||
pytest
|
||||
```
|
||||
|
||||
3
python/build.rs
Normal file
3
python/build.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
fn main() {
|
||||
pyo3_build_config::add_extension_module_link_args();
|
||||
}
|
||||
@@ -14,7 +14,7 @@ dependencies = [
|
||||
"pyyaml>=6.0",
|
||||
"click>=8.1.7",
|
||||
"requests>=2.31.0",
|
||||
"overrides>=0.7"
|
||||
"overrides>=0.7",
|
||||
]
|
||||
description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
@@ -26,7 +26,7 @@ keywords = [
|
||||
"data-science",
|
||||
"machine-learning",
|
||||
"arrow",
|
||||
"data-analytics"
|
||||
"data-analytics",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
@@ -48,19 +48,51 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
|
||||
tests = [
|
||||
"aiohttp",
|
||||
"pandas>=1.4",
|
||||
"pytest",
|
||||
"pytest-mock",
|
||||
"pytest-asyncio",
|
||||
"duckdb",
|
||||
"pytz",
|
||||
"polars>=0.19",
|
||||
]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "mkdocs-ultralytics-plugin==0.0.44"]
|
||||
docs = [
|
||||
"mkdocs",
|
||||
"mkdocs-jupyter",
|
||||
"mkdocs-material",
|
||||
"mkdocstrings[python]",
|
||||
"mkdocs-ultralytics-plugin==0.0.44",
|
||||
]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
|
||||
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]
|
||||
embeddings = [
|
||||
"openai>=1.6.1",
|
||||
"sentence-transformers",
|
||||
"torch",
|
||||
"pillow",
|
||||
"open-clip-torch",
|
||||
"cohere",
|
||||
"huggingface_hub",
|
||||
"InstructorEmbedding",
|
||||
"google.generativeai",
|
||||
"boto3>=1.28.57",
|
||||
"awscli>=1.29.57",
|
||||
"botocore>=1.31.57",
|
||||
]
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "python"
|
||||
module-name = "lancedb._lancedb"
|
||||
|
||||
[project.scripts]
|
||||
lancedb = "lancedb.cli.cli:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = ["maturin>=1.4"]
|
||||
build-backend = "maturin"
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||
@@ -70,5 +102,5 @@ addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
|
||||
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"asyncio"
|
||||
"asyncio",
|
||||
]
|
||||
|
||||
@@ -19,8 +19,9 @@ from typing import Optional, Union
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
from .common import URI
|
||||
from .db import DBConnection, LanceDBConnection
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .common import URI, sanitize_uri
|
||||
from .db import AsyncConnection, AsyncLanceDBConnection, DBConnection, LanceDBConnection
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .schema import vector # noqa: F401
|
||||
from .utils import sentry_log # noqa: F401
|
||||
@@ -101,3 +102,74 @@ def connect(
|
||||
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
|
||||
)
|
||||
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||
|
||||
|
||||
async def connect_async(
|
||||
uri: URI,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
region: str = "us-east-1",
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri: str or Path
|
||||
The uri of the database.
|
||||
api_key: str, optional
|
||||
If present, connect to LanceDB cloud.
|
||||
Otherwise, connect to a database on file system or cloud storage.
|
||||
Can be set via environment variable `LANCEDB_API_KEY`.
|
||||
region: str, default "us-east-1"
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
read_consistency_interval: timedelta, default None
|
||||
(For LanceDB OSS only)
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
zero seconds. Then every read will check for updates from other
|
||||
processes. As a compromise, you can set this to a non-zero timedelta
|
||||
for eventual consistency. If more than that interval has passed since
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||
If an integer, then a ThreadPoolExecutor will be created with that
|
||||
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||
executor will be used for making requests. This is for LanceDB Cloud
|
||||
only and is only used when making batch requests (i.e., passing in
|
||||
multiple queries to the search method at once).
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
For a local directory, provide a path for the database:
|
||||
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("~/.lancedb")
|
||||
|
||||
For object storage, use a URI prefix:
|
||||
|
||||
>>> db = lancedb.connect("s3://my-bucket/lancedb")
|
||||
|
||||
Connect to LancdDB cloud:
|
||||
|
||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
|
||||
|
||||
Returns
|
||||
-------
|
||||
conn : DBConnection
|
||||
A connection to a LanceDB database.
|
||||
"""
|
||||
return AsyncLanceDBConnection(
|
||||
await lancedb_connect(
|
||||
sanitize_uri(uri), api_key, region, host_override, read_consistency_interval
|
||||
)
|
||||
)
|
||||
12
python/python/lancedb/_lancedb.pyi
Normal file
12
python/python/lancedb/_lancedb.pyi
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
class Connection(object):
|
||||
async def table_names(self) -> list[str]: ...
|
||||
|
||||
async def connect(
|
||||
uri: str,
|
||||
api_key: Optional[str],
|
||||
region: Optional[str],
|
||||
host_override: Optional[str],
|
||||
read_consistency_interval: Optional[float],
|
||||
) -> Connection: ...
|
||||
@@ -34,3 +34,7 @@ class Credential(str):
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "********"
|
||||
|
||||
|
||||
def sanitize_uri(uri: URI) -> str:
|
||||
return str(uri)
|
||||
@@ -28,6 +28,7 @@ from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from ._lancedb import Connection as LanceDbConnection
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
@@ -40,14 +41,21 @@ class DBConnection(EnforceOverrides):
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""List all table in this database
|
||||
"""List all tables in this database, in sorted order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
Typically, this token is last table name from the previous page.
|
||||
Only supported by LanceDb Cloud.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
Only supported by LanceDb Cloud.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable of str
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -412,3 +420,254 @@ class LanceDBConnection(DBConnection):
|
||||
def drop_database(self):
|
||||
filesystem, path = fs_from_uri(self.uri)
|
||||
filesystem.delete_dir(path)
|
||||
|
||||
|
||||
class AsyncConnection(EnforceOverrides):
|
||||
"""An active LanceDB connection interface."""
|
||||
|
||||
@abstractmethod
|
||||
async def table_names(
|
||||
self, *, page_token: Optional[str] = None, limit: int = 10
|
||||
) -> Iterable[str]:
|
||||
"""List all tables in this database, in sorted order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
Typically, this token is last table name from the previous page.
|
||||
Only supported by LanceDb Cloud.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
Only supported by LanceDb Cloud.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable of str
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
exist_ok: bool = False,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
data: The data to initialize the table, *optional*
|
||||
User must provide at least one of `data` or `schema`.
|
||||
Acceptable types are:
|
||||
|
||||
- dict or list-of-dict
|
||||
|
||||
- pandas.DataFrame
|
||||
|
||||
- pyarrow.Table or pyarrow.RecordBatch
|
||||
schema: The schema of the table, *optional*
|
||||
Acceptable types are:
|
||||
|
||||
- pyarrow.Schema
|
||||
|
||||
- [LanceModel][lancedb.pydantic.LanceModel]
|
||||
mode: str; default "create"
|
||||
The mode to use when creating the table.
|
||||
Can be either "create" or "overwrite".
|
||||
By default, if the table already exists, an exception is raised.
|
||||
If you want to overwrite the table, use mode="overwrite".
|
||||
exist_ok: bool, default False
|
||||
If a table by the same name already exists, then raise an exception
|
||||
if exist_ok=False. If exist_ok=True, then open the existing table;
|
||||
it will not add the provided data but will validate against any
|
||||
schema that's specified.
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceTable
|
||||
A reference to the newly created table.
|
||||
|
||||
!!! note
|
||||
|
||||
The vector index won't be created by default.
|
||||
To create the index, call the `create_index` method on the table.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Can create with list of tuples or dictionaries:
|
||||
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||
>>> db.create_table("my_table", data)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
>>> db["my_table"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: double
|
||||
long: double
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
You can also pass a pandas DataFrame:
|
||||
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({
|
||||
... "vector": [[1.1, 1.2], [0.2, 1.8]],
|
||||
... "lat": [45.5, 40.1],
|
||||
... "long": [-122.7, -74.1]
|
||||
... })
|
||||
>>> db.create_table("table2", data)
|
||||
LanceTable(connection=..., name="table2")
|
||||
>>> db["table2"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: double
|
||||
long: double
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
Data is converted to Arrow before being written to disk. For maximum
|
||||
control over how data is saved, either provide the PyArrow schema to
|
||||
convert to or else provide a [PyArrow Table](pyarrow.Table) directly.
|
||||
|
||||
>>> custom_schema = pa.schema([
|
||||
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
... pa.field("lat", pa.float32()),
|
||||
... pa.field("long", pa.float32())
|
||||
... ])
|
||||
>>> db.create_table("table3", data, schema = custom_schema)
|
||||
LanceTable(connection=..., name="table3")
|
||||
>>> db["table3"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
child 0, item: float
|
||||
lat: float
|
||||
long: float
|
||||
----
|
||||
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||
lat: [[45.5,40.1]]
|
||||
long: [[-122.7,-74.1]]
|
||||
|
||||
|
||||
It is also possible to create an table from `[Iterable[pa.RecordBatch]]`:
|
||||
|
||||
|
||||
>>> import pyarrow as pa
|
||||
>>> def make_batches():
|
||||
... for i in range(5):
|
||||
... yield pa.RecordBatch.from_arrays(
|
||||
... [
|
||||
... pa.array([[3.1, 4.1], [5.9, 26.5]],
|
||||
... pa.list_(pa.float32(), 2)),
|
||||
... pa.array(["foo", "bar"]),
|
||||
... pa.array([10.0, 20.0]),
|
||||
... ],
|
||||
... ["vector", "item", "price"],
|
||||
... )
|
||||
>>> schema=pa.schema([
|
||||
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||
... pa.field("item", pa.utf8()),
|
||||
... pa.field("price", pa.float32()),
|
||||
... ])
|
||||
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||
LanceTable(connection=..., name="table4")
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def drop_table(self, name: str):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def drop_database(self):
|
||||
"""
|
||||
Drop database
|
||||
This is the same thing as dropping all the tables
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AsyncLanceDBConnection(AsyncConnection):
|
||||
def __init__(self, connection: LanceDbConnection):
|
||||
self._inner = connection
|
||||
|
||||
async def __repr__(self) -> str:
|
||||
pass
|
||||
|
||||
@override
|
||||
async def table_names(
|
||||
self,
|
||||
*,
|
||||
page_token=None,
|
||||
limit=None,
|
||||
) -> Iterable[str]:
|
||||
return await self._inner.table_names()
|
||||
|
||||
@override
|
||||
async def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
exist_ok: bool = False,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> LanceTable:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def open_table(self, name: str) -> LanceTable:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def drop_table(self, name: str, ignore_missing: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
async def drop_database(self):
|
||||
raise NotImplementedError
|
||||
@@ -1,5 +1,4 @@
|
||||
from click.testing import CliRunner
|
||||
|
||||
from lancedb.cli.cli import cli
|
||||
from lancedb.utils import CONFIG
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from lancedb.context import contextualize
|
||||
|
||||
|
||||
@@ -11,12 +11,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
|
||||
@@ -166,6 +165,24 @@ def test_table_names(tmp_path):
|
||||
assert db.table_names() == ["test1", "test2", "test3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_names_async(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
db.create_table("test2", data=data)
|
||||
db.create_table("test1", data=data)
|
||||
db.create_table("test3", data=data)
|
||||
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
assert await db.table_names() == ["test1", "test2", "test3"]
|
||||
|
||||
|
||||
def test_create_mode(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lancedb import LanceDBConnection
|
||||
|
||||
# TODO: setup integ test mark and script
|
||||
@@ -13,11 +13,10 @@
|
||||
import sys
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.embeddings import (
|
||||
EmbeddingFunctionConfig,
|
||||
@@ -14,12 +14,11 @@ import importlib
|
||||
import io
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
import lancedb
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
@@ -185,10 +184,9 @@ def test_imagebind(tmp_path):
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import lancedb.embeddings.imagebind
|
||||
import pandas as pd
|
||||
import requests
|
||||
|
||||
import lancedb.embeddings.imagebind
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
@@ -14,13 +14,13 @@ import os
|
||||
import random
|
||||
from unittest import mock
|
||||
|
||||
import lancedb as ldb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import tantivy
|
||||
|
||||
import lancedb as ldb
|
||||
import lancedb.fts
|
||||
pytest.importorskip("lancedb.fts")
|
||||
tantivy = pytest.importorskip("tantivy")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -13,9 +13,8 @@
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
import pytest
|
||||
|
||||
# You need to setup AWS credentials an a base path to run this test. Example
|
||||
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
|
||||
@@ -20,9 +20,8 @@ from typing import List, Optional, Tuple
|
||||
import pyarrow as pa
|
||||
import pydantic
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION, LanceModel, Vector, pydantic_to_schema
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -18,7 +18,6 @@ import numpy as np
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.query import LanceVectorQueryBuilder, Query
|
||||
@@ -17,7 +17,6 @@ import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery
|
||||
|
||||
|
||||
@@ -11,9 +11,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
import lancedb
|
||||
import pyarrow as pa
|
||||
from lancedb.remote.client import VectorQuery, VectorQueryResult
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
@@ -15,6 +14,9 @@ from lancedb.rerankers import (
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
# Tests rely on FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
|
||||
def get_test_table(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -20,19 +20,18 @@ from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import lance
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.table import LanceTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MockDB:
|
||||
@@ -804,6 +803,9 @@ def test_count_rows(db):
|
||||
|
||||
|
||||
def test_hybrid_search(db, tmp_path):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
db = MockDB(str(tmp_path))
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
import lancedb
|
||||
import pytest
|
||||
from lancedb.utils.events import _Events
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
from lancedb.util import get_uri_scheme, join_uri
|
||||
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import setuptools
|
||||
|
||||
if __name__ == "__main__":
|
||||
setuptools.setup()
|
||||
66
python/src/connection.rs
Normal file
66
python/src/connection.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use lancedb::connection::Connection as LanceConnection;
|
||||
use pyo3::{pyclass, pyfunction, pymethods, PyAny, PyRef, PyResult, Python};
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
|
||||
#[pyclass]
|
||||
pub struct Connection {
|
||||
inner: LanceConnection,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Connection {
|
||||
pub fn table_names(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.table_names().await.infer_error()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn connect(
|
||||
py: Python,
|
||||
uri: String,
|
||||
api_key: Option<String>,
|
||||
region: Option<String>,
|
||||
host_override: Option<String>,
|
||||
read_consistency_interval: Option<f64>,
|
||||
) -> PyResult<&PyAny> {
|
||||
future_into_py(py, async move {
|
||||
let mut builder = lancedb::connect(&uri);
|
||||
if let Some(api_key) = api_key {
|
||||
builder = builder.api_key(&api_key);
|
||||
}
|
||||
if let Some(region) = region {
|
||||
builder = builder.region(®ion);
|
||||
}
|
||||
if let Some(host_override) = host_override {
|
||||
builder = builder.host_override(&host_override);
|
||||
}
|
||||
if let Some(read_consistency_interval) = read_consistency_interval {
|
||||
let read_consistency_interval = Duration::from_secs_f64(read_consistency_interval);
|
||||
builder = builder.read_consistency_interval(read_consistency_interval);
|
||||
}
|
||||
Ok(Connection {
|
||||
inner: builder.execute().await.infer_error()?,
|
||||
})
|
||||
})
|
||||
}
|
||||
61
python/src/error.rs
Normal file
61
python/src/error.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use pyo3::{
|
||||
exceptions::{PyOSError, PyRuntimeError, PyValueError},
|
||||
PyResult,
|
||||
};
|
||||
|
||||
use lancedb::error::Error as LanceError;
|
||||
|
||||
pub trait PythonErrorExt<T> {
|
||||
/// Convert to a python error based on the Lance error type
|
||||
fn infer_error(self) -> PyResult<T>;
|
||||
/// Convert to OSError
|
||||
fn os_error(self) -> PyResult<T>;
|
||||
/// Convert to RuntimeError
|
||||
fn runtime_error(self) -> PyResult<T>;
|
||||
/// Convert to ValueError
|
||||
fn value_error(self) -> PyResult<T>;
|
||||
}
|
||||
|
||||
impl<T> PythonErrorExt<T> for std::result::Result<T, LanceError> {
|
||||
fn infer_error(self) -> PyResult<T> {
|
||||
match &self {
|
||||
Ok(_) => Ok(self.unwrap()),
|
||||
Err(err) => match err {
|
||||
LanceError::InvalidTableName { .. } => self.value_error(),
|
||||
LanceError::TableNotFound { .. } => self.value_error(),
|
||||
LanceError::TableAlreadyExists { .. } => self.runtime_error(),
|
||||
LanceError::CreateDir { .. } => self.os_error(),
|
||||
LanceError::Store { .. } => self.runtime_error(),
|
||||
LanceError::Lance { .. } => self.runtime_error(),
|
||||
LanceError::Schema { .. } => self.value_error(),
|
||||
LanceError::Runtime { .. } => self.runtime_error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn os_error(self) -> PyResult<T> {
|
||||
self.map_err(|err| PyOSError::new_err(err.to_string()))
|
||||
}
|
||||
|
||||
fn runtime_error(self) -> PyResult<T> {
|
||||
self.map_err(|err| PyRuntimeError::new_err(err.to_string()))
|
||||
}
|
||||
|
||||
fn value_error(self) -> PyResult<T> {
|
||||
self.map_err(|err| PyValueError::new_err(err.to_string()))
|
||||
}
|
||||
}
|
||||
32
python/src/lib.rs
Normal file
32
python/src/lib.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
||||
|
||||
pub mod connection;
|
||||
pub(crate) mod error;
|
||||
|
||||
#[pymodule]
|
||||
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
let env = Env::new()
|
||||
.filter_or("LANCEDB_LOG", "warn")
|
||||
.write_style("LANCEDB_LOG_STYLE");
|
||||
env_logger::init_from_env(env);
|
||||
m.add_class::<Connection>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user