Merge branch 'main' into add-data-to-index-stats

This commit is contained in:
qzhu
2024-04-22 13:54:02 -07:00
63 changed files with 3870 additions and 391 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.7
current_version = 0.6.10
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -41,7 +41,7 @@ To build the python package you can use maturin:
```bash
# This will build the rust bindings and place them in the appropriate place
# in your venv or conda environment
matruin develop
maturin develop
```
To run the unit tests:

View File

@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.6.7"
version = "0.6.10"
dependencies = [
"deprecation",
"pylance==0.10.9",
"pylance==0.10.12",
"ratelimiter~=1.0",
"requests>=2.31.0",
"retry>=0.9.2",
@@ -49,6 +49,7 @@ repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = [
"aiohttp",
"boto3",
"pandas>=1.4",
"pytest",
"pytest-mock",
@@ -56,6 +57,7 @@ tests = [
"duckdb",
"pytz",
"polars>=0.19",
"tantivy"
]
dev = ["ruff", "pre-commit"]
docs = [
@@ -63,7 +65,6 @@ docs = [
"mkdocs-jupyter",
"mkdocs-material",
"mkdocstrings[python]",
"mkdocs-ultralytics-plugin==0.0.44",
]
clip = ["torch", "pillow", "open-clip"]
embeddings = [
@@ -98,4 +99,5 @@ addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio",
"s3_test"
]

View File

@@ -15,7 +15,7 @@ import importlib.metadata
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Optional, Union
from typing import Dict, Optional, Union
__version__ = importlib.metadata.version("lancedb")
@@ -83,7 +83,7 @@ def connect(
>>> db = lancedb.connect("s3://my-bucket/lancedb")
Connect to LancdDB cloud:
Connect to LanceDB cloud:
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
@@ -118,6 +118,7 @@ async def connect_async(
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
storage_options: Optional[Dict[str, str]] = None,
) -> AsyncConnection:
"""Connect to a LanceDB database.
@@ -144,6 +145,9 @@ async def connect_async(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
storage_options: dict, optional
Additional options for the storage backend. See available options at
https://lancedb.github.io/lancedb/guides/storage/
Examples
--------
@@ -172,6 +176,7 @@ async def connect_async(
region,
host_override,
read_consistency_interval_secs,
storage_options,
)
)

View File

@@ -19,10 +19,18 @@ class Connection(object):
self, start_after: Optional[str], limit: Optional[int]
) -> list[str]: ...
async def create_table(
self, name: str, mode: str, data: pa.RecordBatchReader
self,
name: str,
mode: str,
data: pa.RecordBatchReader,
storage_options: Optional[Dict[str, str]] = None,
) -> Table: ...
async def create_empty_table(
self, name: str, mode: str, schema: pa.Schema
self,
name: str,
mode: str,
schema: pa.Schema,
storage_options: Optional[Dict[str, str]] = None,
) -> Table: ...
class Table:

View File

@@ -18,7 +18,7 @@ import inspect
import os
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
import pyarrow as pa
from overrides import EnforceOverrides, override
@@ -533,6 +533,7 @@ class AsyncConnection(object):
exist_ok: Optional[bool] = None,
on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
) -> AsyncTable:
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
@@ -570,6 +571,12 @@ class AsyncConnection(object):
One of "error", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
storage_options: dict, optional
Additional options for the storage backend. Options already set on the
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
Returns
-------
@@ -729,30 +736,40 @@ class AsyncConnection(object):
mode = "exist_ok"
if data is None:
new_table = await self._inner.create_empty_table(name, mode, schema)
new_table = await self._inner.create_empty_table(
name, mode, schema, storage_options=storage_options
)
else:
data = data_to_reader(data, schema)
new_table = await self._inner.create_table(
name,
mode,
data,
storage_options=storage_options,
)
return AsyncTable(new_table)
async def open_table(self, name: str) -> Table:
async def open_table(
self, name: str, storage_options: Optional[Dict[str, str]] = None
) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
storage_options: dict, optional
Additional options for the storage backend. Options already set on the
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
Returns
-------
A LanceTable object representing the table.
"""
table = await self._inner.open_table(name)
table = await self._inner.open_table(name, storage_options)
return AsyncTable(table)
async def drop_table(self, name: str):

View File

@@ -78,6 +78,9 @@ class BedRockText(TextEmbeddingFunction):
class Config:
keep_untouched = (cached_property,)
else:
model_config = dict()
model_config["ignored_types"] = (cached_property,)
def ndims(self):
# return len(self._generate_embedding("test"))

View File

@@ -94,6 +94,9 @@ class GeminiText(TextEmbeddingFunction):
class Config:
keep_untouched = (cached_property,)
else:
model_config = dict()
model_config["ignored_types"] = (cached_property,)
def ndims(self):
# TODO: fix hardcoding

View File

@@ -22,6 +22,8 @@ from .base import EmbeddingFunction
from .registry import register
from .utils import AUDIO, IMAGES, TEXT
from lancedb.pydantic import PYDANTIC_VERSION
@register("imagebind")
class ImageBindEmbeddings(EmbeddingFunction):
@@ -38,8 +40,13 @@ class ImageBindEmbeddings(EmbeddingFunction):
device: str = "cpu"
normalize: bool = False
class Config:
keep_untouched = (cached_property,)
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
class Config:
keep_untouched = (cached_property,)
else:
model_config = dict()
model_config["ignored_types"] = (cached_property,)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View File

@@ -17,6 +17,7 @@ from typing import List, Any
import numpy as np
from pydantic import PrivateAttr
from lancedb.pydantic import PYDANTIC_VERSION
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
@@ -53,8 +54,13 @@ class TransformersEmbeddingFunction(EmbeddingFunction):
self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name)
self._model = transformers.AutoModel.from_pretrained(self.name)
class Config:
keep_untouched = (cached_property,)
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
class Config:
keep_untouched = (cached_property,)
else:
model_config = dict()
model_config["ignored_types"] = (cached_property,)
def ndims(self):
self._ndims = self._model.config.hidden_size

View File

@@ -1,4 +1,5 @@
import os
import semver
from functools import cached_property
from typing import Union
@@ -42,6 +43,14 @@ class CohereReranker(Reranker):
@cached_property
def _client(self):
cohere = attempt_import_or_raise("cohere")
# ensure version is at least 0.5.0
if (
hasattr(cohere, "__version__")
and semver.compare(cohere.__version__, "5.0.0") < 0
):
raise ValueError(
f"cohere version must be at least 0.5.0, found {cohere.__version__}"
)
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
raise ValueError(
"COHERE_API_KEY not set. Either set it in your environment or \
@@ -51,11 +60,14 @@ class CohereReranker(Reranker):
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
results = self._client.rerank(
response = self._client.rerank(
query=query,
documents=docs,
top_n=self.top_n,
model=self.model_name,
)
results = (
response.results
) # returns list (text, idx, relevance) attributes sorted descending by score
indices, scores = list(
zip(*[(result.index, result.relevance_score) for result in results])

View File

@@ -95,6 +95,9 @@ def _sanitize_data(
data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
)
if isinstance(data, LanceModel):
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
@@ -1403,7 +1406,14 @@ class LanceTable(Table):
vector and the returned vector.
"""
if vector_column_name is None and query is not None:
vector_column_name = inf_vector_column_query(self.schema)
try:
vector_column_name = inf_vector_column_query(self.schema)
except Exception as e:
if query_type == "fts":
vector_column_name = ""
else:
raise e
return LanceQueryBuilder.create(
self,
query,

View File

@@ -28,13 +28,25 @@ def test_basic(tmp_path):
assert db.uri == str(tmp_path)
assert db.table_names() == []
class SimpleModel(LanceModel):
item: str
price: float
vector: Vector(2)
table = db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
schema=SimpleModel,
)
with pytest.raises(
ValueError, match="Cannot add a single LanceModel to a table. Use a list."
):
table.add(SimpleModel(item="baz", price=30.0, vector=[1.0, 2.0]))
rs = table.search([100, 100]).limit(1).to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
@@ -43,6 +55,11 @@ def test_basic(tmp_path):
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
table.create_fts_index(["item"])
rs = table.search("bar", query_type="fts").to_pandas()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
assert db.table_names() == ["test"]
assert "test" in db
assert len(db) == 1

View File

@@ -0,0 +1,158 @@
# Copyright 2024 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import pytest
import pyarrow as pa
import lancedb
# These are all keys that are accepted by storage_options
CONFIG = {
"allow_http": "true",
"aws_access_key_id": "ACCESSKEY",
"aws_secret_access_key": "SECRETKEY",
"aws_endpoint": "http://localhost:4566",
"aws_region": "us-east-1",
}
def get_boto3_client(*args, **kwargs):
import boto3
return boto3.client(
*args,
region_name=CONFIG["aws_region"],
aws_access_key_id=CONFIG["aws_access_key_id"],
aws_secret_access_key=CONFIG["aws_secret_access_key"],
**kwargs,
)
@pytest.fixture(scope="module")
def s3_bucket():
s3 = get_boto3_client("s3", endpoint_url=CONFIG["aws_endpoint"])
bucket_name = "lance-integtest"
# if bucket exists, delete it
try:
delete_bucket(s3, bucket_name)
except s3.exceptions.NoSuchBucket:
pass
s3.create_bucket(Bucket=bucket_name)
yield bucket_name
delete_bucket(s3, bucket_name)
def delete_bucket(s3, bucket_name):
# Delete all objects first
for obj in s3.list_objects(Bucket=bucket_name).get("Contents", []):
s3.delete_object(Bucket=bucket_name, Key=obj["Key"])
s3.delete_bucket(Bucket=bucket_name)
@pytest.mark.s3_test
def test_s3_lifecycle(s3_bucket: str):
storage_options = copy.copy(CONFIG)
uri = f"s3://{s3_bucket}/test_lifecycle"
data = pa.table({"x": [1, 2, 3]})
async def test():
db = await lancedb.connect_async(uri, storage_options=storage_options)
table = await db.create_table("test", schema=data.schema)
assert await table.count_rows() == 0
table = await db.create_table("test", data, mode="overwrite")
assert await table.count_rows() == 3
await table.add(data, mode="append")
assert await table.count_rows() == 6
table = await db.open_table("test")
assert await table.count_rows() == 6
await db.drop_table("test")
await db.drop_database()
asyncio.run(test())
@pytest.fixture()
def kms_key():
kms = get_boto3_client("kms", endpoint_url=CONFIG["aws_endpoint"])
key_id = kms.create_key()["KeyMetadata"]["KeyId"]
yield key_id
kms.schedule_key_deletion(KeyId=key_id, PendingWindowInDays=7)
def validate_objects_encrypted(bucket: str, path: str, kms_key: str):
s3 = get_boto3_client("s3", endpoint_url=CONFIG["aws_endpoint"])
objects = s3.list_objects_v2(Bucket=bucket, Prefix=path)["Contents"]
for obj in objects:
info = s3.head_object(Bucket=bucket, Key=obj["Key"])
assert info["ServerSideEncryption"] == "aws:kms", (
"object %s not encrypted" % obj["Key"]
)
assert info["SSEKMSKeyId"].endswith(kms_key), (
"object %s not encrypted with correct key" % obj["Key"]
)
@pytest.mark.s3_test
def test_s3_sse(s3_bucket: str, kms_key: str):
storage_options = copy.copy(CONFIG)
uri = f"s3://{s3_bucket}/test_lifecycle"
data = pa.table({"x": [1, 2, 3]})
async def test():
# Create a table with SSE
db = await lancedb.connect_async(uri, storage_options=storage_options)
table = await db.create_table(
"table1",
schema=data.schema,
storage_options={
"aws_server_side_encryption": "aws:kms",
"aws_sse_kms_key_id": kms_key,
},
)
await table.add(data)
await table.update({"x": "1"})
path = "test_lifecycle/table1.lance"
validate_objects_encrypted(s3_bucket, path, kms_key)
# Test we can set encryption at connection level too.
db = await lancedb.connect_async(
uri,
storage_options=dict(
aws_server_side_encryption="aws:kms",
aws_sse_kms_key_id=kms_key,
**storage_options,
),
)
table = await db.create_table("table2", schema=data.schema)
await table.add(data)
await table.update({"x": "1"})
path = "test_lifecycle/table2.lance"
validate_objects_encrypted(s3_bucket, path, kms_key)
asyncio.run(test())

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{sync::Arc, time::Duration};
use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::connection::{Connection as LanceConnection, CreateTableMode};
@@ -90,19 +90,21 @@ impl Connection {
name: String,
mode: &str,
data: &PyAny,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone();
let mode = Self::parse_create_mode_str(mode)?;
let batches = ArrowArrayStreamReader::from_pyarrow(data)?;
let mut builder = inner.create_table(name, batches).mode(mode);
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
future_into_py(self_.py(), async move {
let table = inner
.create_table(name, batches)
.mode(mode)
.execute()
.await
.infer_error()?;
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))
})
}
@@ -112,6 +114,7 @@ impl Connection {
name: String,
mode: &str,
schema: &PyAny,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone();
@@ -119,21 +122,31 @@ impl Connection {
let schema = Schema::from_pyarrow(schema)?;
let mut builder = inner.create_empty_table(name, Arc::new(schema)).mode(mode);
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
future_into_py(self_.py(), async move {
let table = inner
.create_empty_table(name, Arc::new(schema))
.mode(mode)
.execute()
.await
.infer_error()?;
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))
})
}
pub fn open_table(self_: PyRef<'_, Self>, name: String) -> PyResult<&PyAny> {
#[pyo3(signature = (name, storage_options = None))]
pub fn open_table(
self_: PyRef<'_, Self>,
name: String,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<&PyAny> {
let inner = self_.get_inner()?.clone();
let mut builder = inner.open_table(name);
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
future_into_py(self_.py(), async move {
let table = inner.open_table(&name).execute().await.infer_error()?;
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))
})
}
@@ -162,6 +175,7 @@ pub fn connect(
region: Option<String>,
host_override: Option<String>,
read_consistency_interval: Option<f64>,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<&PyAny> {
future_into_py(py, async move {
let mut builder = lancedb::connect(&uri);
@@ -178,6 +192,9 @@ pub fn connect(
let read_consistency_interval = Duration::from_secs_f64(read_consistency_interval);
builder = builder.read_consistency_interval(read_consistency_interval);
}
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
Ok(Connection::new(builder.execute().await.infer_error()?))
})
}