mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 22:02:58 +00:00
Merge branch 'main' into add-data-to-index-stats
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.6.7
|
||||
current_version = 0.6.10
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -41,7 +41,7 @@ To build the python package you can use maturin:
|
||||
```bash
|
||||
# This will build the rust bindings and place them in the appropriate place
|
||||
# in your venv or conda environment
|
||||
matruin develop
|
||||
maturin develop
|
||||
```
|
||||
|
||||
To run the unit tests:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.6.7"
|
||||
version = "0.6.10"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.10.9",
|
||||
"pylance==0.10.12",
|
||||
"ratelimiter~=1.0",
|
||||
"requests>=2.31.0",
|
||||
"retry>=0.9.2",
|
||||
@@ -49,6 +49,7 @@ repository = "https://github.com/lancedb/lancedb"
|
||||
[project.optional-dependencies]
|
||||
tests = [
|
||||
"aiohttp",
|
||||
"boto3",
|
||||
"pandas>=1.4",
|
||||
"pytest",
|
||||
"pytest-mock",
|
||||
@@ -56,6 +57,7 @@ tests = [
|
||||
"duckdb",
|
||||
"pytz",
|
||||
"polars>=0.19",
|
||||
"tantivy"
|
||||
]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = [
|
||||
@@ -63,7 +65,6 @@ docs = [
|
||||
"mkdocs-jupyter",
|
||||
"mkdocs-material",
|
||||
"mkdocstrings[python]",
|
||||
"mkdocs-ultralytics-plugin==0.0.44",
|
||||
]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = [
|
||||
@@ -98,4 +99,5 @@ addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"asyncio",
|
||||
"s3_test"
|
||||
]
|
||||
|
||||
@@ -15,7 +15,7 @@ import importlib.metadata
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
@@ -83,7 +83,7 @@ def connect(
|
||||
|
||||
>>> db = lancedb.connect("s3://my-bucket/lancedb")
|
||||
|
||||
Connect to LancdDB cloud:
|
||||
Connect to LanceDB cloud:
|
||||
|
||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
|
||||
|
||||
@@ -118,6 +118,7 @@ async def connect_async(
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
) -> AsyncConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -144,6 +145,9 @@ async def connect_async(
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -172,6 +176,7 @@ async def connect_async(
|
||||
region,
|
||||
host_override,
|
||||
read_consistency_interval_secs,
|
||||
storage_options,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -19,10 +19,18 @@ class Connection(object):
|
||||
self, start_after: Optional[str], limit: Optional[int]
|
||||
) -> list[str]: ...
|
||||
async def create_table(
|
||||
self, name: str, mode: str, data: pa.RecordBatchReader
|
||||
self,
|
||||
name: str,
|
||||
mode: str,
|
||||
data: pa.RecordBatchReader,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
) -> Table: ...
|
||||
async def create_empty_table(
|
||||
self, name: str, mode: str, schema: pa.Schema
|
||||
self,
|
||||
name: str,
|
||||
mode: str,
|
||||
schema: pa.Schema,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
) -> Table: ...
|
||||
|
||||
class Table:
|
||||
|
||||
@@ -18,7 +18,7 @@ import inspect
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from overrides import EnforceOverrides, override
|
||||
@@ -533,6 +533,7 @@ class AsyncConnection(object):
|
||||
exist_ok: Optional[bool] = None,
|
||||
on_bad_vectors: Optional[str] = None,
|
||||
fill_value: Optional[float] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
) -> AsyncTable:
|
||||
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
|
||||
|
||||
@@ -570,6 +571,12 @@ class AsyncConnection(object):
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. Options already set on the
|
||||
connection will be inherited by the table, but can be overridden here.
|
||||
See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -729,30 +736,40 @@ class AsyncConnection(object):
|
||||
mode = "exist_ok"
|
||||
|
||||
if data is None:
|
||||
new_table = await self._inner.create_empty_table(name, mode, schema)
|
||||
new_table = await self._inner.create_empty_table(
|
||||
name, mode, schema, storage_options=storage_options
|
||||
)
|
||||
else:
|
||||
data = data_to_reader(data, schema)
|
||||
new_table = await self._inner.create_table(
|
||||
name,
|
||||
mode,
|
||||
data,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
|
||||
return AsyncTable(new_table)
|
||||
|
||||
async def open_table(self, name: str) -> Table:
|
||||
async def open_table(
|
||||
self, name: str, storage_options: Optional[Dict[str, str]] = None
|
||||
) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. Options already set on the
|
||||
connection will be inherited by the table, but can be overridden here.
|
||||
See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
|
||||
Returns
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
table = await self._inner.open_table(name)
|
||||
table = await self._inner.open_table(name, storage_options)
|
||||
return AsyncTable(table)
|
||||
|
||||
async def drop_table(self, name: str):
|
||||
|
||||
@@ -78,6 +78,9 @@ class BedRockText(TextEmbeddingFunction):
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
else:
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
def ndims(self):
|
||||
# return len(self._generate_embedding("test"))
|
||||
|
||||
@@ -94,6 +94,9 @@ class GeminiText(TextEmbeddingFunction):
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
else:
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
def ndims(self):
|
||||
# TODO: fix hardcoding
|
||||
|
||||
@@ -22,6 +22,8 @@ from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import AUDIO, IMAGES, TEXT
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
|
||||
@register("imagebind")
|
||||
class ImageBindEmbeddings(EmbeddingFunction):
|
||||
@@ -38,8 +40,13 @@ class ImageBindEmbeddings(EmbeddingFunction):
|
||||
device: str = "cpu"
|
||||
normalize: bool = False
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
else:
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -17,6 +17,7 @@ from typing import List, Any
|
||||
import numpy as np
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
@@ -53,8 +54,13 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
|
||||
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
|
||||
self._model = transformers.AutoModel.from_pretrained(self.name)
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
|
||||
|
||||
class Config:
|
||||
keep_untouched = (cached_property,)
|
||||
else:
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
def ndims(self):
|
||||
self._ndims = self._model.config.hidden_size
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import semver
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
@@ -42,6 +43,14 @@ class CohereReranker(Reranker):
|
||||
@cached_property
|
||||
def _client(self):
|
||||
cohere = attempt_import_or_raise("cohere")
|
||||
# ensure version is at least 0.5.0
|
||||
if (
|
||||
hasattr(cohere, "__version__")
|
||||
and semver.compare(cohere.__version__, "5.0.0") < 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"cohere version must be at least 0.5.0, found {cohere.__version__}"
|
||||
)
|
||||
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"COHERE_API_KEY not set. Either set it in your environment or \
|
||||
@@ -51,11 +60,14 @@ class CohereReranker(Reranker):
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
results = self._client.rerank(
|
||||
response = self._client.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
top_n=self.top_n,
|
||||
model=self.model_name,
|
||||
)
|
||||
results = (
|
||||
response.results
|
||||
) # returns list (text, idx, relevance) attributes sorted descending by score
|
||||
indices, scores = list(
|
||||
zip(*[(result.index, result.relevance_score) for result in results])
|
||||
|
||||
@@ -95,6 +95,9 @@ def _sanitize_data(
|
||||
data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
|
||||
)
|
||||
|
||||
if isinstance(data, LanceModel):
|
||||
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
|
||||
|
||||
if isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
@@ -1403,7 +1406,14 @@ class LanceTable(Table):
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None and query is not None:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
try:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
except Exception as e:
|
||||
if query_type == "fts":
|
||||
vector_column_name = ""
|
||||
else:
|
||||
raise e
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
query,
|
||||
|
||||
@@ -28,13 +28,25 @@ def test_basic(tmp_path):
|
||||
assert db.uri == str(tmp_path)
|
||||
assert db.table_names() == []
|
||||
|
||||
class SimpleModel(LanceModel):
|
||||
item: str
|
||||
price: float
|
||||
vector: Vector(2)
|
||||
|
||||
table = db.create_table(
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||
],
|
||||
schema=SimpleModel,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot add a single LanceModel to a table. Use a list."
|
||||
):
|
||||
table.add(SimpleModel(item="baz", price=30.0, vector=[1.0, 2.0]))
|
||||
|
||||
rs = table.search([100, 100]).limit(1).to_pandas()
|
||||
assert len(rs) == 1
|
||||
assert rs["item"].iloc[0] == "bar"
|
||||
@@ -43,6 +55,11 @@ def test_basic(tmp_path):
|
||||
assert len(rs) == 1
|
||||
assert rs["item"].iloc[0] == "foo"
|
||||
|
||||
table.create_fts_index(["item"])
|
||||
rs = table.search("bar", query_type="fts").to_pandas()
|
||||
assert len(rs) == 1
|
||||
assert rs["item"].iloc[0] == "bar"
|
||||
|
||||
assert db.table_names() == ["test"]
|
||||
assert "test" in db
|
||||
assert len(db) == 1
|
||||
|
||||
158
python/python/tests/test_s3.py
Normal file
158
python/python/tests/test_s3.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright 2024 Lance Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import pyarrow as pa
|
||||
import lancedb
|
||||
|
||||
|
||||
# These are all keys that are accepted by storage_options
|
||||
CONFIG = {
|
||||
"allow_http": "true",
|
||||
"aws_access_key_id": "ACCESSKEY",
|
||||
"aws_secret_access_key": "SECRETKEY",
|
||||
"aws_endpoint": "http://localhost:4566",
|
||||
"aws_region": "us-east-1",
|
||||
}
|
||||
|
||||
|
||||
def get_boto3_client(*args, **kwargs):
|
||||
import boto3
|
||||
|
||||
return boto3.client(
|
||||
*args,
|
||||
region_name=CONFIG["aws_region"],
|
||||
aws_access_key_id=CONFIG["aws_access_key_id"],
|
||||
aws_secret_access_key=CONFIG["aws_secret_access_key"],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def s3_bucket():
|
||||
s3 = get_boto3_client("s3", endpoint_url=CONFIG["aws_endpoint"])
|
||||
bucket_name = "lance-integtest"
|
||||
# if bucket exists, delete it
|
||||
try:
|
||||
delete_bucket(s3, bucket_name)
|
||||
except s3.exceptions.NoSuchBucket:
|
||||
pass
|
||||
s3.create_bucket(Bucket=bucket_name)
|
||||
yield bucket_name
|
||||
|
||||
delete_bucket(s3, bucket_name)
|
||||
|
||||
|
||||
def delete_bucket(s3, bucket_name):
|
||||
# Delete all objects first
|
||||
for obj in s3.list_objects(Bucket=bucket_name).get("Contents", []):
|
||||
s3.delete_object(Bucket=bucket_name, Key=obj["Key"])
|
||||
s3.delete_bucket(Bucket=bucket_name)
|
||||
|
||||
|
||||
@pytest.mark.s3_test
|
||||
def test_s3_lifecycle(s3_bucket: str):
|
||||
storage_options = copy.copy(CONFIG)
|
||||
|
||||
uri = f"s3://{s3_bucket}/test_lifecycle"
|
||||
data = pa.table({"x": [1, 2, 3]})
|
||||
|
||||
async def test():
|
||||
db = await lancedb.connect_async(uri, storage_options=storage_options)
|
||||
|
||||
table = await db.create_table("test", schema=data.schema)
|
||||
assert await table.count_rows() == 0
|
||||
|
||||
table = await db.create_table("test", data, mode="overwrite")
|
||||
assert await table.count_rows() == 3
|
||||
|
||||
await table.add(data, mode="append")
|
||||
assert await table.count_rows() == 6
|
||||
|
||||
table = await db.open_table("test")
|
||||
assert await table.count_rows() == 6
|
||||
|
||||
await db.drop_table("test")
|
||||
|
||||
await db.drop_database()
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def kms_key():
|
||||
kms = get_boto3_client("kms", endpoint_url=CONFIG["aws_endpoint"])
|
||||
key_id = kms.create_key()["KeyMetadata"]["KeyId"]
|
||||
yield key_id
|
||||
kms.schedule_key_deletion(KeyId=key_id, PendingWindowInDays=7)
|
||||
|
||||
|
||||
def validate_objects_encrypted(bucket: str, path: str, kms_key: str):
|
||||
s3 = get_boto3_client("s3", endpoint_url=CONFIG["aws_endpoint"])
|
||||
objects = s3.list_objects_v2(Bucket=bucket, Prefix=path)["Contents"]
|
||||
for obj in objects:
|
||||
info = s3.head_object(Bucket=bucket, Key=obj["Key"])
|
||||
assert info["ServerSideEncryption"] == "aws:kms", (
|
||||
"object %s not encrypted" % obj["Key"]
|
||||
)
|
||||
assert info["SSEKMSKeyId"].endswith(kms_key), (
|
||||
"object %s not encrypted with correct key" % obj["Key"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.s3_test
|
||||
def test_s3_sse(s3_bucket: str, kms_key: str):
|
||||
storage_options = copy.copy(CONFIG)
|
||||
|
||||
uri = f"s3://{s3_bucket}/test_lifecycle"
|
||||
data = pa.table({"x": [1, 2, 3]})
|
||||
|
||||
async def test():
|
||||
# Create a table with SSE
|
||||
db = await lancedb.connect_async(uri, storage_options=storage_options)
|
||||
|
||||
table = await db.create_table(
|
||||
"table1",
|
||||
schema=data.schema,
|
||||
storage_options={
|
||||
"aws_server_side_encryption": "aws:kms",
|
||||
"aws_sse_kms_key_id": kms_key,
|
||||
},
|
||||
)
|
||||
await table.add(data)
|
||||
await table.update({"x": "1"})
|
||||
|
||||
path = "test_lifecycle/table1.lance"
|
||||
validate_objects_encrypted(s3_bucket, path, kms_key)
|
||||
|
||||
# Test we can set encryption at connection level too.
|
||||
db = await lancedb.connect_async(
|
||||
uri,
|
||||
storage_options=dict(
|
||||
aws_server_side_encryption="aws:kms",
|
||||
aws_sse_kms_key_id=kms_key,
|
||||
**storage_options,
|
||||
),
|
||||
)
|
||||
|
||||
table = await db.create_table("table2", schema=data.schema)
|
||||
await table.add(data)
|
||||
await table.update({"x": "1"})
|
||||
|
||||
path = "test_lifecycle/table2.lance"
|
||||
validate_objects_encrypted(s3_bucket, path, kms_key)
|
||||
|
||||
asyncio.run(test())
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
|
||||
use lancedb::connection::{Connection as LanceConnection, CreateTableMode};
|
||||
@@ -90,19 +90,21 @@ impl Connection {
|
||||
name: String,
|
||||
mode: &str,
|
||||
data: &PyAny,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<&'a PyAny> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
|
||||
let mode = Self::parse_create_mode_str(mode)?;
|
||||
|
||||
let batches = ArrowArrayStreamReader::from_pyarrow(data)?;
|
||||
let mut builder = inner.create_table(name, batches).mode(mode);
|
||||
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let table = inner
|
||||
.create_table(name, batches)
|
||||
.mode(mode)
|
||||
.execute()
|
||||
.await
|
||||
.infer_error()?;
|
||||
let table = builder.execute().await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
@@ -112,6 +114,7 @@ impl Connection {
|
||||
name: String,
|
||||
mode: &str,
|
||||
schema: &PyAny,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<&'a PyAny> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
|
||||
@@ -119,21 +122,31 @@ impl Connection {
|
||||
|
||||
let schema = Schema::from_pyarrow(schema)?;
|
||||
|
||||
let mut builder = inner.create_empty_table(name, Arc::new(schema)).mode(mode);
|
||||
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let table = inner
|
||||
.create_empty_table(name, Arc::new(schema))
|
||||
.mode(mode)
|
||||
.execute()
|
||||
.await
|
||||
.infer_error()?;
|
||||
let table = builder.execute().await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_table(self_: PyRef<'_, Self>, name: String) -> PyResult<&PyAny> {
|
||||
#[pyo3(signature = (name, storage_options = None))]
|
||||
pub fn open_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<&PyAny> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
let mut builder = inner.open_table(name);
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
future_into_py(self_.py(), async move {
|
||||
let table = inner.open_table(&name).execute().await.infer_error()?;
|
||||
let table = builder.execute().await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
@@ -162,6 +175,7 @@ pub fn connect(
|
||||
region: Option<String>,
|
||||
host_override: Option<String>,
|
||||
read_consistency_interval: Option<f64>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<&PyAny> {
|
||||
future_into_py(py, async move {
|
||||
let mut builder = lancedb::connect(&uri);
|
||||
@@ -178,6 +192,9 @@ pub fn connect(
|
||||
let read_consistency_interval = Duration::from_secs_f64(read_consistency_interval);
|
||||
builder = builder.read_consistency_interval(read_consistency_interval);
|
||||
}
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
Ok(Connection::new(builder.execute().await.infer_error()?))
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user