mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
This switches the default FTS to native lance FTS for Python sync table API, the other APIs have switched to native implementation already <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - The default behavior for creating a full-text search index now uses the new implementation rather than the legacy one. - **Bug Fixes** - Improved handling and error messages for phrase queries in full-text search. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
279 lines
8.0 KiB
Python
279 lines
8.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
|
|
import asyncio
|
|
import copy
|
|
from datetime import timedelta
|
|
import threading
|
|
|
|
import pytest
|
|
import pyarrow as pa
|
|
import lancedb
|
|
|
|
|
|
# Every key below is accepted by ``storage_options``. The endpoints point at a
# local S3/DynamoDB emulator (port 4566 — TODO confirm localstack) and the
# credentials are dummy values for that emulator.
CONFIG = dict(
    allow_http="true",
    aws_access_key_id="ACCESSKEY",
    aws_secret_access_key="SECRETKEY",
    aws_endpoint="http://localhost:4566",
    dynamodb_endpoint="http://localhost:4566",
    aws_region="us-east-1",
)
|
|
|
|
|
|
def get_boto3_client(*args, **kwargs):
    """Return a boto3 client preconfigured with the test credentials in CONFIG.

    Positional arguments (e.g. the service name) and extra keyword arguments
    are forwarded straight to ``boto3.client``.
    """
    import boto3

    credentials = {
        "region_name": CONFIG["aws_region"],
        "aws_access_key_id": CONFIG["aws_access_key_id"],
        "aws_secret_access_key": CONFIG["aws_secret_access_key"],
    }
    return boto3.client(*args, **credentials, **kwargs)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def s3_bucket():
    """Yield the name of a freshly created S3 bucket; delete it on teardown."""
    name = "lance-integtest"
    client = get_boto3_client("s3", endpoint_url=CONFIG["aws_endpoint"])
    # A previous run may have left the bucket behind — start from a clean slate.
    try:
        delete_bucket(client, name)
    except client.exceptions.NoSuchBucket:
        pass
    client.create_bucket(Bucket=name)
    yield name
    # Teardown: remove the bucket and everything in it.
    delete_bucket(client, name)
|
|
|
|
|
|
def delete_bucket(s3, bucket_name):
    """Delete ``bucket_name`` after removing every object it contains.

    Uses a ``list_objects_v2`` paginator so buckets holding more than the
    1000-key page cap are fully emptied; the unpaginated ``list_objects``
    call would silently miss keys beyond the first page and the final
    ``delete_bucket`` would then fail on a non-empty bucket.

    Raises the client's ``NoSuchBucket`` error if the bucket does not exist.
    """
    # Objects must go first: S3 refuses to delete a non-empty bucket.
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name):
        for obj in page.get("Contents", []):
            s3.delete_object(Bucket=bucket_name, Key=obj["Key"])
    s3.delete_bucket(Bucket=bucket_name)
|
|
|
|
|
|
@pytest.mark.s3_test
def test_s3_lifecycle(s3_bucket: str):
    """Exercise the async table lifecycle (create/overwrite/append/open/drop) on S3."""
    storage_options = copy.copy(CONFIG)
    uri = f"s3://{s3_bucket}/test_lifecycle"
    data = pa.table({"x": [1, 2, 3]})

    async def scenario():
        db = await lancedb.connect_async(uri, storage_options=storage_options)

        # A table created from a schema alone starts empty.
        table = await db.create_table("test", schema=data.schema)
        assert await table.count_rows() == 0

        # Overwriting replaces it with the three-row dataset.
        table = await db.create_table("test", data, mode="overwrite")
        assert await table.count_rows() == 3

        # Appending the same batch doubles the row count.
        await table.add(data, mode="append")
        assert await table.count_rows() == 6

        # Re-opening the table sees the same rows.
        table = await db.open_table("test")
        assert await table.count_rows() == 6

        await db.drop_table("test")
        await db.drop_database()

    asyncio.run(scenario())
|
|
|
|
|
|
@pytest.fixture()
def kms_key():
    """Create a throwaway KMS key and schedule its deletion after the test."""
    client = get_boto3_client("kms", endpoint_url=CONFIG["aws_endpoint"])
    key_id = client.create_key()["KeyMetadata"]["KeyId"]
    yield key_id
    # 7 days is the minimum pending window KMS accepts for deletion.
    client.schedule_key_deletion(KeyId=key_id, PendingWindowInDays=7)
|
|
|
|
|
|
def validate_objects_encrypted(bucket: str, path: str, kms_key: str):
    """Assert every object under ``path`` is SSE-KMS encrypted with ``kms_key``.

    Paginates the listing so prefixes with more than the 1000-key page cap
    are fully checked, and fails with a clear message when the prefix holds
    no objects at all (the previous ``["Contents"]`` lookup raised a bare
    ``KeyError`` in that case).
    """
    s3 = get_boto3_client("s3", endpoint_url=CONFIG["aws_endpoint"])
    paginator = s3.get_paginator("list_objects_v2")
    checked = 0
    for page in paginator.paginate(Bucket=bucket, Prefix=path):
        for obj in page.get("Contents", []):
            info = s3.head_object(Bucket=bucket, Key=obj["Key"])
            assert info["ServerSideEncryption"] == "aws:kms", (
                "object %s not encrypted" % obj["Key"]
            )
            assert info["SSEKMSKeyId"].endswith(kms_key), (
                "object %s not encrypted with correct key" % obj["Key"]
            )
            checked += 1
    # An empty listing must not vacuously pass the encryption check.
    assert checked > 0, "no objects found under prefix %s" % path
|
|
|
|
|
|
@pytest.mark.s3_test
def test_s3_sse(s3_bucket: str, kms_key: str):
    """Verify SSE-KMS can be enabled per-table and per-connection."""
    storage_options = copy.copy(CONFIG)
    uri = f"s3://{s3_bucket}/test_lifecycle"
    data = pa.table({"x": [1, 2, 3]})

    async def scenario():
        sse_options = {
            "aws_server_side_encryption": "aws:kms",
            "aws_sse_kms_key_id": kms_key,
        }

        # Encryption requested at table-creation time.
        db = await lancedb.connect_async(uri, storage_options=storage_options)
        table = await db.create_table(
            "table1",
            schema=data.schema,
            storage_options=sse_options,
        )
        await table.add(data)
        await table.update({"x": "1"})
        validate_objects_encrypted(s3_bucket, "test_lifecycle/table1.lance", kms_key)

        # Encryption requested at connection level applies to new tables too.
        db = await lancedb.connect_async(
            uri,
            storage_options={**storage_options, **sse_options},
        )
        table = await db.create_table("table2", schema=data.schema)
        await table.add(data)
        await table.update({"x": "1"})
        validate_objects_encrypted(s3_bucket, "test_lifecycle/table2.lance", kms_key)

    asyncio.run(scenario())
|
|
|
|
|
|
@pytest.fixture(scope="module")
def commit_table():
    """Provision the DynamoDB table lance uses for S3 commit coordination."""
    table_name = "lance-integtest"
    ddb = get_boto3_client("dynamodb", endpoint_url=CONFIG["dynamodb_endpoint"])
    # Drop any table left behind by an earlier run.
    try:
        ddb.delete_table(TableName=table_name)
    except ddb.exceptions.ResourceNotFoundException:
        pass
    # Commit entries are keyed by (base_uri, version).
    key_schema = [
        {"AttributeName": "base_uri", "KeyType": "HASH"},
        {"AttributeName": "version", "KeyType": "RANGE"},
    ]
    attributes = [
        {"AttributeName": "base_uri", "AttributeType": "S"},
        {"AttributeName": "version", "AttributeType": "N"},
    ]
    ddb.create_table(
        TableName=table_name,
        KeySchema=key_schema,
        AttributeDefinitions=attributes,
        ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
    )
    yield table_name
    ddb.delete_table(TableName=table_name)
|
|
|
|
|
|
@pytest.mark.s3_test
def test_s3_dynamodb(s3_bucket: str, commit_table: str):
    """Concurrent async writers commit safely through the DynamoDB commit table."""
    storage_options = copy.copy(CONFIG)
    uri = f"s3+ddb://{s3_bucket}/test?ddbTableName={commit_table}"
    data = pa.table({"x": [1, 2, 3]})

    async def scenario():
        db = await lancedb.connect_async(
            uri,
            storage_options=storage_options,
            read_consistency_interval=timedelta(0),
        )
        table = await db.create_table("test", data)

        async def insert():
            # Each writer opens an independent table handle so the
            # appends genuinely race against each other.
            writer = await db.open_table("test")
            await writer.add(data, mode="append")

        # Five concurrent writers.
        await asyncio.gather(*(insert() for _ in range(5)))

        # 3 initial rows + 5 writers x 3 rows each.
        assert await table.count_rows() == 3 * 6

    asyncio.run(scenario())
|
|
|
|
|
|
@pytest.mark.s3_test
def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
    """Concurrent sync writers via DynamoDB, plus FTS, listing and drop behavior."""
    # The sync API does not accept storage_options, so the configuration is
    # supplied through environment variables instead.
    for key, value in CONFIG.items():
        monkeypatch.setenv(key.upper(), value)

    uri = f"s3+ddb://{s3_bucket}/test2?ddbTableName={commit_table}"
    data = pa.table({"x": ["a", "b", "c"]})

    db = lancedb.connect(
        uri,
        read_consistency_interval=timedelta(0),
    )
    table = db.create_table("test_ddb_sync", data)

    def insert():
        # A fresh handle per thread so the appends truly race.
        db.open_table("test_ddb_sync").add(data, mode="append")

    # Five concurrent writers.
    threads = [threading.Thread(target=insert) for _ in range(5)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # 3 initial rows + 5 appends of 3 rows each.
    assert table.count_rows() == 3 * 6

    # Tantivy-based FTS indices are not supported on object storage.
    with pytest.raises(
        NotImplementedError,
        match="Full-text search is only supported on the local filesystem",
    ):
        table.create_fts_index("x", use_tantivy=True)

    # Listing and dropping still work through the DynamoDB commit table.
    assert db.table_names() == ["test_ddb_sync"]
    db.drop_table("test_ddb_sync")
    assert db.table_names() == []
    db.drop_database()
|
|
|
|
|
|
@pytest.mark.s3_test
def test_s3_dynamodb_drop_all_tables(s3_bucket: str, commit_table: str, monkeypatch):
    """drop_all_tables clears every table and leaves DynamoDB state reusable."""
    # Sync API picks up storage configuration from environment variables.
    for key, value in CONFIG.items():
        monkeypatch.setenv(key.upper(), value)

    uri = f"s3+ddb://{s3_bucket}/test2?ddbTableName={commit_table}"
    db = lancedb.connect(uri, read_consistency_interval=timedelta(0))
    data = pa.table({"x": ["a", "b", "c"]})

    for name in ("foo", "bar"):
        db.create_table(name, data)
    assert db.table_names() == ["bar", "foo"]

    # Dropping all tables should clear multiple tables at once.
    db.drop_all_tables()
    assert db.table_names() == []

    # Re-creating a previously dropped name proves the commit records were
    # cleaned out of DynamoDB as well.
    db.create_table("foo", data)
    assert db.table_names() == ["foo"]

    db.drop_all_tables()
|