feat: expose storage options in LanceDB (#1204)

Exposes `storage_options` in LanceDB. This is provided for the Python async
API, Node `lancedb`, and Node `vectordb` (and Rust, of course). The Python
synchronous API is omitted because it is not compatible with the PyArrow
filesystems we currently use there. In the future, we will move the sync
API to wrap the async one, at which point it will gain support for
`storage_options`.
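
Roughly, the new parameter looks like this in the Python async API (a sketch; the bucket and table names here are hypothetical):

```python
import asyncio

import lancedb


async def main():
    # Connection-level options apply to everything opened through this connection.
    db = await lancedb.connect_async(
        "s3://my-bucket/db",
        storage_options={"timeout": "60s"},
    )
    # Table-level options inherit from the connection and can override it.
    table = await db.open_table("my_table", storage_options={"timeout": "120s"})


asyncio.run(main())
```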

1. Fixes #1168
2. Closes #1165
3. Closes #1082
4. Closes #439
5. Closes #897
6. Closes #642
7. Closes #281
8. Closes #114
9. Closes #990
10. Deprecates `awsCredentials` and `awsRegion`. Users are encouraged
to use `storageOptions` instead.
Will Jones authored on 2024-04-10 10:12:04 -07:00, committed by GitHub.
Parent: 25dea4e859. Commit: 1d23af213b.
31 changed files with 3128 additions and 262 deletions


@@ -107,6 +107,7 @@ jobs:
AWS_ENDPOINT: http://localhost:4566
# this one is for dynamodb
DYNAMODB_ENDPOINT: http://localhost:4566
ALLOW_HTTP: true
steps:
- uses: actions/checkout@v4
with:


@@ -85,7 +85,12 @@ jobs:
run: |
npm ci
npm run build
- name: Setup localstack
working-directory: .
run: docker compose up --detach --wait
- name: Test
env:
S3_TEST: "1"
run: npm run test
macos:
timeout-minutes: 30


@@ -99,6 +99,8 @@ jobs:
workspaces: python
- uses: ./.github/workflows/build_linux_wheel
- uses: ./.github/workflows/run_tests
with:
integration: true
# Make sure wheels are not included in the Rust cache
- name: Delete wheels
run: rm -rf target/wheels
@@ -190,4 +192,4 @@ jobs:
pip install -e .[tests]
pip install tantivy
- name: Run tests
run: pytest -m "not slow" -x -v --durations=30 python/tests
run: pytest -m "not slow and not s3_test" -x -v --durations=30 python/tests


@@ -5,6 +5,10 @@ inputs:
python-minor-version:
required: true
description: "8 9 10 11 12"
integration:
required: false
description: "Run integration tests"
default: "false"
runs:
using: "composite"
steps:
@@ -12,6 +16,16 @@ runs:
shell: bash
run: |
pip3 install $(ls target/wheels/lancedb-*.whl)[tests,dev]
- name: pytest
- name: Setup localstack for integration tests
if: ${{ inputs.integration == 'true' }}
shell: bash
working-directory: .
run: docker compose up --detach --wait
- name: pytest (with integration)
shell: bash
if: ${{ inputs.integration == 'true' }}
run: pytest -m "not slow" -x -v --durations=30 python/python/tests
- name: pytest (no integration tests)
shell: bash
if: ${{ inputs.integration != 'true' }}
run: pytest -m "not slow and not s3_test" -x -v --durations=30 python/python/tests


@@ -76,6 +76,9 @@ jobs:
sudo apt install -y protobuf-compiler libssl-dev
- name: Build
run: cargo build --all-features
- name: Start S3 integration test environment
working-directory: .
run: docker compose up --detach --wait
- name: Run tests
run: cargo test --all-features
- name: Run examples
@@ -105,7 +108,8 @@ jobs:
- name: Build
run: cargo build --all-features
- name: Run tests
run: cargo test --all-features
# Run with everything except the integration tests.
run: cargo test --features remote,fp16kernels
windows:
runs-on: windows-2022
steps:


@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.10.9", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.9" }
lance-linalg = { "version" = "=0.10.9" }
lance-testing = { "version" = "=0.10.9" }
lance = { "version" = "=0.10.10", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.10" }
lance-linalg = { "version" = "=0.10.10" }
lance-testing = { "version" = "=0.10.10" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"


@@ -1,18 +1,18 @@
version: "3.9"
services:
localstack:
image: localstack/localstack:0.14
image: localstack/localstack:3.3
ports:
- 4566:4566
environment:
- SERVICES=s3,dynamodb
- SERVICES=s3,dynamodb,kms
- DEBUG=1
- LS_LOG=trace
- DOCKER_HOST=unix:///var/run/docker.sock
- AWS_ACCESS_KEY_ID=ACCESSKEY
- AWS_SECRET_ACCESS_KEY=SECRETKEY
healthcheck:
test: [ "CMD", "curl", "-f", "http://localhost:4566/health" ]
test: [ "CMD", "curl", "-s", "http://localhost:4566/_localstack/health" ]
interval: 5s
retries: 3
start_period: 10s


@@ -55,18 +55,139 @@ LanceDB OSS supports object stores such as AWS S3 (and compatible stores), Azure
const db = await lancedb.connect("az://bucket/path");
```
In most cases, when running in the respective cloud and permissions are set up correctly, no additional configuration is required. When running outside of the respective cloud, authentication credentials must be provided using environment variables. In general, these environment variables are the same as those used by the respective cloud SDKs. The sections below describe the environment variables that can be used to configure each object store.
In most cases, when running in the respective cloud and permissions are set up correctly, no additional configuration is required. When running outside of the respective cloud, authentication credentials must be provided. Credentials and other configuration options can be set in two ways: by setting environment variables, or by passing a `storage_options` object to the `connect` function.
LanceDB OSS uses the [object-store](https://docs.rs/object_store/latest/object_store/) Rust crate for object store access. There are general environment variables that can be used to configure the object store, such as the request timeout and proxy configuration. See the [object_store ClientConfigKey](https://docs.rs/object_store/latest/object_store/enum.ClientConfigKey.html) doc for available configuration options. The environment variables that can be set are the snake-cased versions of these variable names. For example, to set `ProxyUrl` use the environment variable `PROXY_URL`. (Don't let the Rust docs intimidate you! We link to them so you can see an up-to-date list of the available options.)
For example, to increase the request timeout to 60 seconds, you can set the `TIMEOUT` environment variable to `60s`:
```bash
export TIMEOUT=60s
```
!!! note "`storage_options` availability"
The `storage_options` parameter is only available in the Python *async* API and the JavaScript API.
It is not yet supported in the Python synchronous API.
If you only want this to apply to one particular connection, you can pass the `storage_options` argument when opening the connection:
=== "Python"
```python
import lancedb
db = await lancedb.connect_async(
"s3://bucket/path",
storage_options={"timeout": "60s"}
)
```
=== "JavaScript"
```javascript
const lancedb = require("lancedb");
const db = await lancedb.connect("s3://bucket/path",
{storageOptions: {timeout: "60s"}});
```
Getting even more specific, you can set the `timeout` for only a particular table:
=== "Python"
<!-- skip-test -->
```python
import lancedb
db = await lancedb.connect_async("s3://bucket/path")
table = await db.create_table(
"table",
[{"a": 1, "b": 2}],
storage_options={"timeout": "60s"}
)
```
=== "JavaScript"
<!-- skip-test -->
```javascript
const lancedb = require("lancedb");
const db = await lancedb.connect("s3://bucket/path");
const table = await db.createTable(
"table",
[{ a: 1, b: 2}],
{storageOptions: {timeout: "60s"}}
);
```
!!! info "Storage option casing"
The storage option keys are case-insensitive. So `connect_timeout` and `CONNECT_TIMEOUT` are the same setting. Usually lowercase is used in the `storage_options` argument and uppercase is used for environment variables. In the `lancedb` Node package, the keys can also be provided in `camelCase` capitalization. For example, `connectTimeout` is equivalent to `connect_timeout`.
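For example (a sketch; the bucket is hypothetical), the following two calls configure the same option:
<!-- skip-test -->
```python
import lancedb

# Keys are case-insensitive, so both spellings set the connect timeout.
db1 = await lancedb.connect_async(
    "s3://bucket/path", storage_options={"connect_timeout": "10s"}
)
db2 = await lancedb.connect_async(
    "s3://bucket/path", storage_options={"CONNECT_TIMEOUT": "10s"}
)
```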
### General configuration
There are several options that can be set for all object stores, mostly related to network client configuration.
<!-- from here: https://docs.rs/object_store/latest/object_store/enum.ClientConfigKey.html -->
| Key | Description |
|----------------------------|--------------------------------------------------------------------------------------------------|
| `allow_http` | Allow non-TLS, i.e. non-HTTPS connections. Default: `False`. |
| `allow_invalid_certificates`| Skip certificate validation on HTTPS connections. Default: `False`. |
| `connect_timeout` | Timeout for only the connect phase of a Client. Default: `5s`. |
| `timeout` | Timeout for the entire request, from connection until the response body has finished. Default: `30s`. |
| `user_agent` | User agent string to use in requests. |
| `proxy_url` | URL of a proxy server to use for requests. Default: `None`. |
| `proxy_ca_certificate` | PEM-formatted CA certificate for proxy connections. |
| `proxy_excludes` | List of hosts that bypass the proxy. This is a comma-separated list of domains and IP masks. Any subdomain of the provided domain will be bypassed. For example, `example.com, 192.168.1.0/24` would bypass `https://api.example.com`, `https://www.example.com`, and any IP in the range `192.168.1.0/24`. |
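As an illustration (a sketch; the proxy URL is a placeholder), several of these client options can be combined on a single connection:
<!-- skip-test -->
```python
import lancedb

# Tighten timeouts and route all requests through an internal proxy.
db = await lancedb.connect_async(
    "s3://bucket/path",
    storage_options={
        "connect_timeout": "10s",
        "timeout": "60s",
        "proxy_url": "http://proxy.internal:8080",
    },
)
```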
### AWS S3
To configure credentials for AWS S3, you can use the `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN` environment variables.
To configure credentials for AWS S3, you can use the `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN` keys. Region can also be set, but it is not mandatory when using AWS.
These can be set as environment variables or passed in the `storage_options` parameter:
=== "Python"
```python
import lancedb
db = await lancedb.connect_async(
"s3://bucket/path",
storage_options={
"aws_access_key_id": "my-access-key",
"aws_secret_access_key": "my-secret-key",
"aws_session_token": "my-session-token",
}
)
```
=== "JavaScript"
```javascript
const lancedb = require("lancedb");
const db = await lancedb.connect(
"s3://bucket/path",
{
storageOptions: {
awsAccessKeyId: "my-access-key",
awsSecretAccessKey: "my-secret-key",
awsSessionToken: "my-session-token",
}
}
);
```
Alternatively, if you are using AWS SSO, you can use the `AWS_PROFILE` and `AWS_DEFAULT_REGION` environment variables.
You can see a full list of environment variables [here](https://docs.rs/object_store/latest/object_store/aws/struct.AmazonS3Builder.html#method.from_env).
The following keys can be used either as environment variables or as keys in the `storage_options` parameter:
| Key | Description |
|------------------------------------|------------------------------------------------------------------------------------------------------|
| `aws_region` / `region` | The AWS region the bucket is in. This can be automatically detected when using AWS S3, but must be specified for S3-compatible stores. |
| `aws_access_key_id` / `access_key_id` | The AWS access key ID to use. |
| `aws_secret_access_key` / `secret_access_key` | The AWS secret access key to use. |
| `aws_session_token` / `session_token` | The AWS session token to use. |
| `aws_endpoint` / `endpoint` | The endpoint to use for S3-compatible stores. |
| `aws_virtual_hosted_style_request` / `virtual_hosted_style_request` | Whether to use virtual hosted-style requests, where the bucket name is part of the endpoint. Meant to be used with `aws_endpoint`. Default: `False`. |
| `aws_s3_express` / `s3_express` | Whether to use S3 Express One Zone endpoints. Default: `False`. See more details below. |
| `aws_server_side_encryption` | The server-side encryption algorithm to use. Must be one of `"AES256"`, `"aws:kms"`, or `"aws:kms:dsse"`. Default: `None`. |
| `aws_sse_kms_key_id` | The KMS key ID to use for server-side encryption. If set, `aws_server_side_encryption` must be `"aws:kms"` or `"aws:kms:dsse"`. |
| `aws_sse_bucket_key_enabled` | Whether to use bucket keys for server-side encryption. |
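For example (a sketch; the KMS key ID is a placeholder), server-side encryption can be enabled for a single table like this:
<!-- skip-test -->
```python
import lancedb

db = await lancedb.connect_async("s3://bucket/path")
# Encrypt only this table's objects with a customer-managed KMS key.
table = await db.create_table(
    "table",
    [{"a": 1, "b": 2}],
    storage_options={
        "aws_server_side_encryption": "aws:kms",
        "aws_sse_kms_key_id": "some-kms-key-id",
    },
)
```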
!!! tip "Automatic cleanup for failed writes"
@@ -146,22 +267,174 @@ For **read-only access**, LanceDB will need a policy such as:
#### S3-compatible stores
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify two environment variables: `AWS_ENDPOINT` and `AWS_DEFAULT_REGION`. `AWS_ENDPOINT` should be the URL of the S3-compatible store, and `AWS_DEFAULT_REGION` should be the region to use.
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify both region and endpoint:
=== "Python"
```python
import lancedb
db = await lancedb.connect_async(
"s3://bucket/path",
storage_options={
"region": "us-east-1",
"endpoint": "http://minio:9000",
}
)
```
=== "JavaScript"
```javascript
const lancedb = require("lancedb");
const db = await lancedb.connect(
"s3://bucket/path",
{
storageOptions: {
region: "us-east-1",
endpoint: "http://minio:9000",
}
}
);
```
This can also be done with the `AWS_ENDPOINT` and `AWS_DEFAULT_REGION` environment variables.
#### S3 Express
LanceDB supports [S3 Express One Zone](https://aws.amazon.com/s3/storage-classes/express-one-zone/) endpoints, but requires additional configuration. Note that S3 Express endpoints only support connecting from an EC2 instance within the same region.
To configure LanceDB to use an S3 Express endpoint, you must set the storage option `s3_express`. The bucket name in your table URI should **include the availability-zone suffix** (for example, `--use1-az4--x-s3`).
=== "Python"
```python
import lancedb
db = await lancedb.connect_async(
"s3://my-bucket--use1-az4--x-s3/path",
storage_options={
"region": "us-east-1",
"s3_express": "true",
}
)
```
=== "JavaScript"
```javascript
const lancedb = require("lancedb");
const db = await lancedb.connect(
"s3://my-bucket--use1-az4--x-s3/path",
{
storageOptions: {
region: "us-east-1",
s3Express: "true",
}
}
);
```
<!-- TODO: we should also document the use of S3 Express once we fully support it -->
### Google Cloud Storage
GCS credentials are configured by setting the `GOOGLE_SERVICE_ACCOUNT` environment variable to the path of a JSON file containing the service account credentials. There are several aliases for this environment variable, documented [here](https://docs.rs/object_store/latest/object_store/gcp/struct.GoogleCloudStorageBuilder.html#method.from_env).
GCS credentials are configured by setting the `GOOGLE_SERVICE_ACCOUNT` environment variable to the path of a JSON file containing the service account credentials. Alternatively, you can pass the path to the JSON file in the `storage_options`:
=== "Python"
<!-- skip-test -->
```python
import lancedb
db = await lancedb.connect_async(
"gs://my-bucket/my-database",
storage_options={
"service_account": "path/to/service-account.json",
}
)
```
=== "JavaScript"
```javascript
const lancedb = require("lancedb");
const db = await lancedb.connect(
"gs://my-bucket/my-database",
{
storageOptions: {
serviceAccount: "path/to/service-account.json",
}
}
);
```
!!! info "HTTP/2 support"
By default, GCS uses HTTP/1 for communication rather than HTTP/2, since this significantly improves maximum throughput. However, if you wish to use HTTP/2, you can set the environment variable `HTTP1_ONLY` to `false`.
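If you prefer to scope this to a single connection, the same setting should also be expressible through `storage_options` (an assumption based on the general client options above; `http1_only` is the snake-cased form of the `Http1Only` client key):
<!-- skip-test -->
```python
import lancedb

# Assumption: the general client option `http1_only` is honored for GCS,
# mirroring the HTTP1_ONLY environment variable described above.
db = await lancedb.connect_async(
    "gs://my-bucket/my-database",
    storage_options={"http1_only": "false"},
)
```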
The following keys can be used either as environment variables or as keys in the `storage_options` parameter:
<!-- source: https://docs.rs/object_store/latest/object_store/gcp/enum.GoogleConfigKey.html -->
| Key | Description |
|---------------------------------------|----------------------------------------------|
| `google_service_account` / `service_account` | Path to the service account JSON file. |
| `google_service_account_key` | The serialized service account key. |
| `google_application_credentials` | Path to the application credentials. |
### Azure Blob Storage
Azure Blob Storage credentials can be configured by setting the `AZURE_STORAGE_ACCOUNT_NAME` and ``AZURE_STORAGE_ACCOUNT_KEY`` environment variables. The full list of environment variables that can be set are documented [here](https://docs.rs/object_store/latest/object_store/azure/struct.MicrosoftAzureBuilder.html#method.from_env).
Azure Blob Storage credentials can be configured by setting the `AZURE_STORAGE_ACCOUNT_NAME` and `AZURE_STORAGE_ACCOUNT_KEY` environment variables. Alternatively, you can pass the account name and key in the `storage_options` parameter:
=== "Python"
<!-- skip-test -->
```python
import lancedb
db = await lancedb.connect_async(
"az://my-container/my-database",
storage_options={
"account_name": "some-account",
"account_key": "some-key",
}
)
```
=== "JavaScript"
```javascript
const lancedb = require("lancedb");
const db = await lancedb.connect(
"az://my-container/my-database",
{
storageOptions: {
accountName: "some-account",
accountKey: "some-key",
}
}
);
```
These keys can be used either as environment variables or as keys in the `storage_options` parameter:
<!-- source: https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html -->
| Key | Description |
|---------------------------------------|--------------------------------------------------------------------------------------------------|
| `azure_storage_account_name` | The name of the Azure storage account. |
| `azure_storage_account_key` | The master key of the Azure storage account. |
| `azure_client_id` | Service principal client ID for authorizing requests. |
| `azure_client_secret` | Service principal client secret for authorizing requests. |
| `azure_tenant_id` | Tenant ID used in OAuth flows. |
| `azure_storage_sas_key` | Shared access signature. The signature is expected to be percent-encoded, as provided in the Azure storage explorer or Azure portal. |
| `azure_storage_token` | Bearer token. |
| `azure_storage_use_emulator` | Use object store with the Azurite storage emulator. |
| `azure_endpoint` | Override the endpoint used to communicate with blob storage. |
| `azure_use_fabric_endpoint` | Use object store with the URL scheme account.dfs.fabric.microsoft.com. |
| `azure_msi_endpoint` | Endpoint to request an IMDS managed identity token. |
| `azure_object_id` | Object ID for use with managed identity authentication. |
| `azure_msi_resource_id` | MSI resource ID for use with managed identity authentication. |
| `azure_federated_token_file` | File containing a token for Azure AD workload identity federation. |
| `azure_use_azure_cli` | Use the Azure CLI to acquire access tokens. |
| `azure_disable_tagging` | Disables tagging objects. This can be desirable if not supported by the backing store. |
<!-- TODO: demonstrate how to configure networked file systems for optimal performance -->


@@ -1,5 +1,5 @@
import glob
from typing import Iterator
from typing import Iterator, List
from pathlib import Path
glob_string = "../src/**/*.md"
@@ -50,11 +50,24 @@ def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
yield line[strip_length:]
def wrap_async(lines: List[str]) -> List[str]:
# Indent all the lines
lines = [" " + line for line in lines]
# Put all lines in `async def main():`
lines = ["async def main():\n"] + lines
# Put `import asyncio\n asyncio.run(main())` at the end
lines = lines + ["\n", "import asyncio\n", "asyncio.run(main())\n"]
return lines
for file in filter(lambda file: file not in excluded_files, files):
with open(file, "r") as f:
lines = list(yield_lines(iter(f), "```", "```"))
if len(lines) > 0:
if any("await" in line for line in lines):
lines = wrap_async(lines)
print(lines)
out_path = (
Path(python_folder)


@@ -78,12 +78,25 @@ export interface ConnectionOptions {
/** User provided AWS credentials.
*
* If not provided, LanceDB will use the default credentials provider chain.
*
* @deprecated Pass `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token`
* through `storageOptions` instead.
*/
awsCredentials?: AwsCredentials
/** AWS region to connect to. Default is {@link defaultAwsRegion}. */
/** AWS region to connect to. Default is {@link defaultAwsRegion}
*
* @deprecated Pass `region` through `storageOptions` instead.
*/
awsRegion?: string
/**
* User provided options for object storage. For example, S3 credentials or request timeouts.
*
* The various options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>
/**
* API key for the remote connections
*
@@ -176,7 +189,6 @@ export async function connect (
if (typeof arg === 'string') {
opts = { uri: arg }
} else {
// opts = { uri: arg.uri, awsCredentials = arg.awsCredentials }
const keys = Object.keys(arg)
if (keys.length === 1 && keys[0] === 'uri' && typeof arg.uri === 'string') {
opts = { uri: arg.uri }
@@ -198,12 +210,26 @@ export async function connect (
// Remote connection
return new RemoteConnection(opts)
}
const storageOptions = opts.storageOptions ?? {};
if (opts.awsCredentials?.accessKeyId !== undefined) {
storageOptions.aws_access_key_id = opts.awsCredentials.accessKeyId
}
if (opts.awsCredentials?.secretKey !== undefined) {
storageOptions.aws_secret_access_key = opts.awsCredentials.secretKey
}
if (opts.awsCredentials?.sessionToken !== undefined) {
storageOptions.aws_session_token = opts.awsCredentials.sessionToken
}
if (opts.awsRegion !== undefined) {
storageOptions.region = opts.awsRegion
}
// It's a pain to pass a record to Rust, so we convert it to an array of key-value pairs
const storageOptionsArr = Object.entries(storageOptions);
const db = await databaseNew(
opts.uri,
opts.awsCredentials?.accessKeyId,
opts.awsCredentials?.secretKey,
opts.awsCredentials?.sessionToken,
opts.awsRegion,
storageOptionsArr,
opts.readConsistencyInterval
)
return new LocalConnection(db, opts)
@@ -720,7 +746,6 @@ export class LocalConnection implements Connection {
const tbl = await databaseOpenTable.call(
this._db,
name,
...getAwsArgs(this._options())
)
if (embeddings !== undefined) {
return new LocalTable(tbl, name, this._options(), embeddings)


@@ -75,6 +75,19 @@ describe('LanceDB client', function () {
assert.equal(con.uri, uri)
})
it('should accept custom storage options', async function () {
const uri = await createTestDB()
const storageOptions = {
region: 'us-west-2',
timeout: '30s'
};
const con = await lancedb.connect({
uri,
storageOptions
})
assert.equal(con.uri, uri)
})
it('should return the existing table names', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)


@@ -0,0 +1,219 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* eslint-disable @typescript-eslint/naming-convention */
import { connect } from "../dist";
import {
CreateBucketCommand,
DeleteBucketCommand,
DeleteObjectCommand,
HeadObjectCommand,
ListObjectsV2Command,
S3Client,
} from "@aws-sdk/client-s3";
import {
CreateKeyCommand,
ScheduleKeyDeletionCommand,
KMSClient,
} from "@aws-sdk/client-kms";
// Skip these tests unless the S3_TEST environment variable is set
const maybeDescribe = process.env.S3_TEST ? describe : describe.skip;
// These are all keys that are accepted by storage_options
const CONFIG = {
allowHttp: "true",
awsAccessKeyId: "ACCESSKEY",
awsSecretAccessKey: "SECRETKEY",
awsEndpoint: "http://127.0.0.1:4566",
awsRegion: "us-east-1",
};
class S3Bucket {
name: string;
constructor(name: string) {
this.name = name;
}
static s3Client() {
return new S3Client({
region: CONFIG.awsRegion,
credentials: {
accessKeyId: CONFIG.awsAccessKeyId,
secretAccessKey: CONFIG.awsSecretAccessKey,
},
endpoint: CONFIG.awsEndpoint,
});
}
public static async create(name: string): Promise<S3Bucket> {
const client = this.s3Client();
// Delete the bucket if it already exists
try {
await this.deleteBucket(client, name);
} catch (e) {
// It's fine if the bucket doesn't exist
}
await client.send(new CreateBucketCommand({ Bucket: name }));
return new S3Bucket(name);
}
public async delete() {
const client = S3Bucket.s3Client();
await S3Bucket.deleteBucket(client, this.name);
}
static async deleteBucket(client: S3Client, name: string) {
// Must delete all objects before we can delete the bucket
const objects = await client.send(
new ListObjectsV2Command({ Bucket: name }),
);
if (objects.Contents) {
for (const object of objects.Contents) {
await client.send(
new DeleteObjectCommand({ Bucket: name, Key: object.Key }),
);
}
}
await client.send(new DeleteBucketCommand({ Bucket: name }));
}
public async assertAllEncrypted(path: string, keyId: string) {
const client = S3Bucket.s3Client();
const objects = await client.send(
new ListObjectsV2Command({ Bucket: this.name, Prefix: path }),
);
if (objects.Contents) {
for (const object of objects.Contents) {
const metadata = await client.send(
new HeadObjectCommand({ Bucket: this.name, Key: object.Key }),
);
expect(metadata.ServerSideEncryption).toBe("aws:kms");
expect(metadata.SSEKMSKeyId).toContain(keyId);
}
}
}
}
class KmsKey {
keyId: string;
constructor(keyId: string) {
this.keyId = keyId;
}
static kmsClient() {
return new KMSClient({
region: CONFIG.awsRegion,
credentials: {
accessKeyId: CONFIG.awsAccessKeyId,
secretAccessKey: CONFIG.awsSecretAccessKey,
},
endpoint: CONFIG.awsEndpoint,
});
}
public static async create(): Promise<KmsKey> {
const client = this.kmsClient();
const key = await client.send(new CreateKeyCommand({}));
const keyId = key?.KeyMetadata?.KeyId;
if (!keyId) {
throw new Error("Failed to create KMS key");
}
return new KmsKey(keyId);
}
public async delete() {
const client = KmsKey.kmsClient();
await client.send(new ScheduleKeyDeletionCommand({ KeyId: this.keyId }));
}
}
maybeDescribe("storage_options", () => {
let bucket: S3Bucket;
let kmsKey: KmsKey;
beforeAll(async () => {
bucket = await S3Bucket.create("lancedb");
kmsKey = await KmsKey.create();
});
afterAll(async () => {
await kmsKey.delete();
await bucket.delete();
});
it("can be used to configure auth and endpoints", async () => {
const uri = `s3://${bucket.name}/test`;
const db = await connect(uri, { storageOptions: CONFIG });
let table = await db.createTable("test", [{ a: 1, b: 2 }]);
let rowCount = await table.countRows();
expect(rowCount).toBe(1);
let tableNames = await db.tableNames();
expect(tableNames).toEqual(["test"]);
table = await db.openTable("test");
rowCount = await table.countRows();
expect(rowCount).toBe(1);
await table.add([
{ a: 2, b: 3 },
{ a: 3, b: 4 },
]);
rowCount = await table.countRows();
expect(rowCount).toBe(3);
await db.dropTable("test");
tableNames = await db.tableNames();
expect(tableNames).toEqual([]);
});
it("can configure encryption at connection and table level", async () => {
const uri = `s3://${bucket.name}/test`;
let db = await connect(uri, { storageOptions: CONFIG });
let table = await db.createTable("table1", [{ a: 1, b: 2 }], {
storageOptions: {
awsServerSideEncryption: "aws:kms",
awsSseKmsKeyId: kmsKey.keyId,
},
});
let rowCount = await table.countRows();
expect(rowCount).toBe(1);
await table.add([{ a: 2, b: 3 }]);
await bucket.assertAllEncrypted("test/table1.lance", kmsKey.keyId);
// Now with encryption settings at connection level
db = await connect(uri, {
storageOptions: {
...CONFIG,
awsServerSideEncryption: "aws:kms",
awsSseKmsKeyId: kmsKey.keyId,
},
});
table = await db.createTable("table2", [{ a: 1, b: 2 }]);
rowCount = await table.countRows();
expect(rowCount).toBe(1);
await table.add([{ a: 2, b: 3 }]);
await bucket.assertAllEncrypted("test/table2.lance", kmsKey.keyId);
});
});


@@ -13,10 +13,32 @@
// limitations under the License.
import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow";
import { Connection as LanceDbConnection } from "./native";
import { ConnectionOptions, Connection as LanceDbConnection } from "./native";
import { Table } from "./table";
import { Table as ArrowTable, Schema } from "apache-arrow";
/**
* Connect to a LanceDB instance at the given URI.
*
 * Accepted formats:
*
* - `/path/to/database` - local database
* - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
* - `db://host:port` - remote database (LanceDB cloud)
* @param {string} uri - The uri of the database. If the database uri starts
* with `db://` then it connects to a remote database.
* @see {@link ConnectionOptions} for more details on the URI format.
*/
export async function connect(
uri: string,
opts?: Partial<ConnectionOptions>,
): Promise<Connection> {
opts = opts ?? {};
opts.storageOptions = cleanseStorageOptions(opts.storageOptions);
const nativeConn = await LanceDbConnection.new(uri, opts);
return new Connection(nativeConn);
}
export interface CreateTableOptions {
/**
* The mode to use when creating the table.
@@ -33,6 +55,28 @@ export interface CreateTableOptions {
* then no error will be raised.
*/
existOk: boolean;
/**
* Configuration for object storage.
*
* Options already set on the connection will be inherited by the table,
* but can be overridden here.
*
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>;
}
export interface OpenTableOptions {
/**
* Configuration for object storage.
*
* Options already set on the connection will be inherited by the table,
* but can be overridden here.
*
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/
storageOptions?: Record<string, string>;
}
export interface TableNamesOptions {
@@ -109,8 +153,14 @@ export class Connection {
* Open a table in the database.
* @param {string} name - The name of the table
*/
async openTable(name: string): Promise<Table> {
const innerTable = await this.inner.openTable(name);
async openTable(
name: string,
options?: Partial<OpenTableOptions>,
): Promise<Table> {
const innerTable = await this.inner.openTable(
name,
cleanseStorageOptions(options?.storageOptions),
);
return new Table(innerTable);
}
@@ -139,7 +189,12 @@ export class Connection {
table = makeArrowTable(data);
}
const buf = await fromTableToBuffer(table);
const innerTable = await this.inner.createTable(name, buf, mode);
const innerTable = await this.inner.createTable(
name,
buf,
mode,
cleanseStorageOptions(options?.storageOptions),
);
return new Table(innerTable);
}
@@ -162,7 +217,12 @@ export class Connection {
const table = makeEmptyTable(schema);
const buf = await fromTableToBuffer(table);
const innerTable = await this.inner.createEmptyTable(name, buf, mode);
const innerTable = await this.inner.createEmptyTable(
name,
buf,
mode,
cleanseStorageOptions(options?.storageOptions),
);
return new Table(innerTable);
}
@@ -174,3 +234,43 @@ export class Connection {
return this.inner.dropTable(name);
}
}
/**
* Takes storage options and makes all the keys snake case.
*/
function cleanseStorageOptions(
options?: Record<string, string>,
): Record<string, string> | undefined {
if (options === undefined) {
return undefined;
}
const result: Record<string, string> = {};
for (const [key, value] of Object.entries(options)) {
if (value !== undefined) {
const newKey = camelToSnakeCase(key);
result[newKey] = value;
}
}
return result;
}
/**
* Convert a string to snake case. It might already be snake case, in which case it is
* returned unchanged.
*/
function camelToSnakeCase(camel: string): string {
if (camel.includes("_")) {
// Assume if there is at least one underscore, it is already snake case
return camel;
}
if (camel.toLocaleUpperCase() === camel) {
// Assume if the string is all uppercase, it is already snake case
return camel;
}
let result = camel.replace(/[A-Z]/g, (letter) => `_${letter.toLowerCase()}`);
if (result.startsWith("_")) {
result = result.slice(1);
}
return result;
}


@@ -12,12 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Connection } from "./connection";
import {
Connection as LanceDbConnection,
ConnectionOptions,
} from "./native.js";
export {
WriteOptions,
WriteMode,
@@ -32,6 +26,7 @@ export {
VectorColumnOptions,
} from "./arrow";
export {
connect,
Connection,
CreateTableOptions,
TableNamesOptions,
@@ -46,24 +41,3 @@ export {
export { Index, IndexOptions, IvfPqOptions } from "./indices";
export { Table, AddDataOptions, IndexConfig, UpdateOptions } from "./table";
export * as embedding from "./embedding";
/**
* Connect to a LanceDB instance at the given URI.
*
 * Accepted formats:
*
* - `/path/to/database` - local database
* - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
* - `db://host:port` - remote database (LanceDB cloud)
* @param {string} uri - The uri of the database. If the database uri starts
* with `db://` then it connects to a remote database.
* @see {@link ConnectionOptions} for more details on the URI format.
*/
export async function connect(
uri: string,
opts?: Partial<ConnectionOptions>,
): Promise<Connection> {
opts = opts ?? {};
const nativeConn = await LanceDbConnection.new(uri, opts);
return new Connection(nativeConn);
}

nodejs/package-lock.json (generated, 1636 lines): diff suppressed because it is too large.


@@ -18,6 +18,8 @@
},
"license": "Apache 2.0",
"devDependencies": {
"@aws-sdk/client-s3": "^3.33.0",
"@aws-sdk/client-kms": "^3.33.0",
"@napi-rs/cli": "^2.18.0",
"@types/jest": "^29.1.2",
"@types/tmp": "^0.2.6",
@@ -63,6 +65,7 @@
"lint": "eslint lancedb && eslint __test__",
"prepublishOnly": "napi prepublish -t npm",
"test": "npm run build && jest --verbose",
"integration": "S3_TEST=1 npm run test",
"universal": "napi universal",
"version": "napi version"
},


@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use napi::bindgen_prelude::*;
use napi_derive::*;
@@ -64,6 +66,11 @@ impl Connection {
builder =
builder.read_consistency_interval(std::time::Duration::from_secs_f64(interval));
}
if let Some(storage_options) = options.storage_options {
for (key, value) in storage_options {
builder = builder.storage_option(key, value);
}
}
Ok(Self::inner_new(
builder
.execute()
@@ -118,14 +125,18 @@ impl Connection {
name: String,
buf: Buffer,
mode: String,
storage_options: Option<HashMap<String, String>>,
) -> napi::Result<Table> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let mode = Self::parse_create_mode_str(&mode)?;
let tbl = self
.get_inner()?
.create_table(&name, batches)
.mode(mode)
let mut builder = self.get_inner()?.create_table(&name, batches).mode(mode);
if let Some(storage_options) = storage_options {
for (key, value) in storage_options {
builder = builder.storage_option(key, value);
}
}
let tbl = builder
.execute()
.await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
@@ -138,15 +149,22 @@ impl Connection {
name: String,
schema_buf: Buffer,
mode: String,
storage_options: Option<HashMap<String, String>>,
) -> napi::Result<Table> {
let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e))
})?;
let mode = Self::parse_create_mode_str(&mode)?;
let tbl = self
let mut builder = self
.get_inner()?
.create_empty_table(&name, schema)
.mode(mode)
.mode(mode);
if let Some(storage_options) = storage_options {
for (key, value) in storage_options {
builder = builder.storage_option(key, value);
}
}
let tbl = builder
.execute()
.await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
@@ -154,10 +172,18 @@ impl Connection {
}
#[napi]
pub async fn open_table(&self, name: String) -> napi::Result<Table> {
let tbl = self
.get_inner()?
.open_table(&name)
pub async fn open_table(
&self,
name: String,
storage_options: Option<HashMap<String, String>>,
) -> napi::Result<Table> {
let mut builder = self.get_inner()?.open_table(&name);
if let Some(storage_options) = storage_options {
for (key, value) in storage_options {
builder = builder.storage_option(key, value);
}
}
let tbl = builder
.execute()
.await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;


@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use connection::Connection;
use std::collections::HashMap;
use napi_derive::*;
mod connection;
@@ -38,6 +39,10 @@ pub struct ConnectionOptions {
/// Note: this consistency only applies to read operations. Write operations are
/// always consistent.
pub read_consistency_interval: Option<f64>,
/// (For LanceDB OSS only): configuration for object storage.
///
/// The available options are described at https://lancedb.github.io/lancedb/guides/storage/
pub storage_options: Option<HashMap<String, String>>,
}
/// Write mode for writing a table.
@@ -54,7 +59,7 @@ pub struct WriteOptions {
pub mode: Option<WriteMode>,
}
#[napi]
pub async fn connect(uri: String, options: ConnectionOptions) -> napi::Result<Connection> {
Connection::new(uri, options).await
#[napi(object)]
pub struct OpenTableOptions {
pub storage_options: Option<HashMap<String, String>>,
}


@@ -41,7 +41,7 @@ To build the python package you can use maturin:
```bash
# This will build the rust bindings and place them in the appropriate place
# in your venv or conda environment
matruin develop
maturin develop
```
To run the unit tests:


@@ -3,7 +3,7 @@ name = "lancedb"
version = "0.6.7"
dependencies = [
"deprecation",
"pylance==0.10.9",
"pylance==0.10.10",
"ratelimiter~=1.0",
"requests>=2.31.0",
"retry>=0.9.2",
@@ -49,6 +49,7 @@ repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = [
"aiohttp",
"boto3",
"pandas>=1.4",
"pytest",
"pytest-mock",
@@ -98,4 +99,5 @@ addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio",
"s3_test"
]


@@ -15,7 +15,7 @@ import importlib.metadata
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Optional, Union
from typing import Dict, Optional, Union
__version__ = importlib.metadata.version("lancedb")
@@ -118,6 +118,7 @@ async def connect_async(
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
storage_options: Optional[Dict[str, str]] = None,
) -> AsyncConnection:
"""Connect to a LanceDB database.
@@ -144,6 +145,9 @@ async def connect_async(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
storage_options: dict, optional
Additional options for the storage backend. See available options at
https://lancedb.github.io/lancedb/guides/storage/
Examples
--------
@@ -172,6 +176,7 @@ async def connect_async(
region,
host_override,
read_consistency_interval_secs,
storage_options,
)
)


@@ -19,10 +19,18 @@ class Connection(object):
self, start_after: Optional[str], limit: Optional[int]
) -> list[str]: ...
async def create_table(
self, name: str, mode: str, data: pa.RecordBatchReader
self,
name: str,
mode: str,
data: pa.RecordBatchReader,
storage_options: Optional[Dict[str, str]] = None,
) -> Table: ...
async def create_empty_table(
self, name: str, mode: str, schema: pa.Schema
self,
name: str,
mode: str,
schema: pa.Schema,
storage_options: Optional[Dict[str, str]] = None,
) -> Table: ...
class Table:


@@ -18,7 +18,7 @@ import inspect
import os
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Union
import pyarrow as pa
from overrides import EnforceOverrides, override
@@ -533,6 +533,7 @@ class AsyncConnection(object):
exist_ok: Optional[bool] = None,
on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None,
) -> AsyncTable:
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
@@ -570,6 +571,12 @@ class AsyncConnection(object):
One of "error", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
storage_options: dict, optional
Additional options for the storage backend. Options already set on the
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
Returns
-------
@@ -729,30 +736,40 @@ class AsyncConnection(object):
mode = "exist_ok"
if data is None:
new_table = await self._inner.create_empty_table(name, mode, schema)
new_table = await self._inner.create_empty_table(
name, mode, schema, storage_options=storage_options
)
else:
data = data_to_reader(data, schema)
new_table = await self._inner.create_table(
name,
mode,
data,
storage_options=storage_options,
)
return AsyncTable(new_table)
async def open_table(self, name: str) -> Table:
async def open_table(
self, name: str, storage_options: Optional[Dict[str, str]] = None
) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
storage_options: dict, optional
Additional options for the storage backend. Options already set on the
connection will be inherited by the table, but can be overridden here.
See available options at
https://lancedb.github.io/lancedb/guides/storage/
Returns
-------
A LanceTable object representing the table.
"""
table = await self._inner.open_table(name)
table = await self._inner.open_table(name, storage_options)
return AsyncTable(table)
async def drop_table(self, name: str):


@@ -0,0 +1,158 @@
# Copyright 2024 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import pytest
import pyarrow as pa
import lancedb
# These are all keys that are accepted by storage_options
CONFIG = {
"allow_http": "true",
"aws_access_key_id": "ACCESSKEY",
"aws_secret_access_key": "SECRETKEY",
"aws_endpoint": "http://localhost:4566",
"aws_region": "us-east-1",
}
def get_boto3_client(*args, **kwargs):
import boto3
return boto3.client(
*args,
region_name=CONFIG["aws_region"],
aws_access_key_id=CONFIG["aws_access_key_id"],
aws_secret_access_key=CONFIG["aws_secret_access_key"],
**kwargs,
)
@pytest.fixture(scope="module")
def s3_bucket():
s3 = get_boto3_client("s3", endpoint_url=CONFIG["aws_endpoint"])
bucket_name = "lance-integtest"
# if bucket exists, delete it
try:
delete_bucket(s3, bucket_name)
except s3.exceptions.NoSuchBucket:
pass
s3.create_bucket(Bucket=bucket_name)
yield bucket_name
delete_bucket(s3, bucket_name)
def delete_bucket(s3, bucket_name):
# Delete all objects first
for obj in s3.list_objects(Bucket=bucket_name).get("Contents", []):
s3.delete_object(Bucket=bucket_name, Key=obj["Key"])
s3.delete_bucket(Bucket=bucket_name)
@pytest.mark.s3_test
def test_s3_lifecycle(s3_bucket: str):
storage_options = copy.copy(CONFIG)
uri = f"s3://{s3_bucket}/test_lifecycle"
data = pa.table({"x": [1, 2, 3]})
async def test():
db = await lancedb.connect_async(uri, storage_options=storage_options)
table = await db.create_table("test", schema=data.schema)
assert await table.count_rows() == 0
table = await db.create_table("test", data, mode="overwrite")
assert await table.count_rows() == 3
await table.add(data, mode="append")
assert await table.count_rows() == 6
table = await db.open_table("test")
assert await table.count_rows() == 6
await db.drop_table("test")
await db.drop_database()
asyncio.run(test())
@pytest.fixture()
def kms_key():
kms = get_boto3_client("kms", endpoint_url=CONFIG["aws_endpoint"])
key_id = kms.create_key()["KeyMetadata"]["KeyId"]
yield key_id
kms.schedule_key_deletion(KeyId=key_id, PendingWindowInDays=7)
def validate_objects_encrypted(bucket: str, path: str, kms_key: str):
s3 = get_boto3_client("s3", endpoint_url=CONFIG["aws_endpoint"])
objects = s3.list_objects_v2(Bucket=bucket, Prefix=path)["Contents"]
for obj in objects:
info = s3.head_object(Bucket=bucket, Key=obj["Key"])
assert info["ServerSideEncryption"] == "aws:kms", (
"object %s not encrypted" % obj["Key"]
)
assert info["SSEKMSKeyId"].endswith(kms_key), (
"object %s not encrypted with correct key" % obj["Key"]
)
@pytest.mark.s3_test
def test_s3_sse(s3_bucket: str, kms_key: str):
storage_options = copy.copy(CONFIG)
uri = f"s3://{s3_bucket}/test_lifecycle"
data = pa.table({"x": [1, 2, 3]})
async def test():
# Create a table with SSE
db = await lancedb.connect_async(uri, storage_options=storage_options)
table = await db.create_table(
"table1",
schema=data.schema,
storage_options={
"aws_server_side_encryption": "aws:kms",
"aws_sse_kms_key_id": kms_key,
},
)
await table.add(data)
await table.update({"x": "1"})
path = "test_lifecycle/table1.lance"
validate_objects_encrypted(s3_bucket, path, kms_key)
# Test we can set encryption at connection level too.
db = await lancedb.connect_async(
uri,
storage_options=dict(
aws_server_side_encryption="aws:kms",
aws_sse_kms_key_id=kms_key,
**storage_options,
),
)
table = await db.create_table("table2", schema=data.schema)
await table.add(data)
await table.update({"x": "1"})
path = "test_lifecycle/table2.lance"
validate_objects_encrypted(s3_bucket, path, kms_key)
asyncio.run(test())


@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{sync::Arc, time::Duration};
use std::{collections::HashMap, sync::Arc, time::Duration};
use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::FromPyArrow};
use lancedb::connection::{Connection as LanceConnection, CreateTableMode};
@@ -90,19 +90,21 @@ impl Connection {
name: String,
mode: &str,
data: &PyAny,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone();
let mode = Self::parse_create_mode_str(mode)?;
let batches = ArrowArrayStreamReader::from_pyarrow(data)?;
let mut builder = inner.create_table(name, batches).mode(mode);
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
future_into_py(self_.py(), async move {
let table = inner
.create_table(name, batches)
.mode(mode)
.execute()
.await
.infer_error()?;
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))
})
}
@@ -112,6 +114,7 @@ impl Connection {
name: String,
mode: &str,
schema: &PyAny,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone();
@@ -119,21 +122,31 @@ impl Connection {
let schema = Schema::from_pyarrow(schema)?;
let mut builder = inner.create_empty_table(name, Arc::new(schema)).mode(mode);
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
future_into_py(self_.py(), async move {
let table = inner
.create_empty_table(name, Arc::new(schema))
.mode(mode)
.execute()
.await
.infer_error()?;
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))
})
}
pub fn open_table(self_: PyRef<'_, Self>, name: String) -> PyResult<&PyAny> {
#[pyo3(signature = (name, storage_options = None))]
pub fn open_table(
self_: PyRef<'_, Self>,
name: String,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<&PyAny> {
let inner = self_.get_inner()?.clone();
let mut builder = inner.open_table(name);
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
future_into_py(self_.py(), async move {
let table = inner.open_table(&name).execute().await.infer_error()?;
let table = builder.execute().await.infer_error()?;
Ok(Table::new(table))
})
}
@@ -162,6 +175,7 @@ pub fn connect(
region: Option<String>,
host_override: Option<String>,
read_consistency_interval: Option<f64>,
storage_options: Option<HashMap<String, String>>,
) -> PyResult<&PyAny> {
future_into_py(py, async move {
let mut builder = lancedb::connect(&uri);
@@ -178,6 +192,9 @@ pub fn connect(
let read_consistency_interval = Duration::from_secs_f64(read_consistency_interval);
builder = builder.read_consistency_interval(read_consistency_interval);
}
if let Some(storage_options) = storage_options {
builder = builder.storage_options(storage_options);
}
Ok(Connection::new(builder.execute().await.infer_error()?))
})
}


@@ -12,19 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use async_trait::async_trait;
use lance::io::ObjectStoreParams;
use neon::prelude::*;
use object_store::aws::{AwsCredential, AwsCredentialProvider};
use object_store::CredentialProvider;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;
use lancedb::connect;
use lancedb::connection::Connection;
use lancedb::table::ReadParams;
use crate::error::ResultExt;
use crate::query::JsQuery;
@@ -44,33 +37,6 @@ struct JsDatabase {
impl Finalize for JsDatabase {}
// TODO: object_store didn't export this type so I copied it.
// Make a request to object_store to export this type
#[derive(Debug)]
pub struct StaticCredentialProvider<T> {
credential: Arc<T>,
}
impl<T> StaticCredentialProvider<T> {
pub fn new(credential: T) -> Self {
Self {
credential: Arc::new(credential),
}
}
}
#[async_trait]
impl<T> CredentialProvider for StaticCredentialProvider<T>
where
T: std::fmt::Debug + Send + Sync,
{
type Credential = T;
async fn get_credential(&self) -> object_store::Result<Arc<T>> {
Ok(Arc::clone(&self.credential))
}
}
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
static LOG: OnceCell<()> = OnceCell::new();
@@ -82,29 +48,28 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
let path = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = get_aws_creds(&mut cx, 1)?;
let region = get_aws_region(&mut cx, 4)?;
let read_consistency_interval = cx
.argument_opt(2)
.and_then(|arg| arg.downcast::<JsNumber, _>(&mut cx).ok())
.map(|v| v.value(&mut cx))
.map(std::time::Duration::from_secs_f64);
let storage_options_js = cx.argument::<JsArray>(1)?.to_vec(&mut cx)?;
let mut storage_options: Vec<(String, String)> = Vec::with_capacity(storage_options_js.len());
for handle in storage_options_js {
let obj = handle.downcast::<JsArray, _>(&mut cx).unwrap();
let key = obj.get::<JsString, _, _>(&mut cx, 0)?.value(&mut cx);
let value = obj.get::<JsString, _, _>(&mut cx, 1)?.value(&mut cx);
storage_options.push((key, value));
}
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let mut conn_builder = connect(&path);
if let Some(region) = region {
conn_builder = conn_builder.region(&region);
}
if let Some(aws_creds) = aws_creds {
conn_builder = conn_builder.aws_creds(AwsCredential {
key_id: aws_creds.key_id,
secret_key: aws_creds.secret_key,
token: aws_creds.token,
});
}
let mut conn_builder = connect(&path).storage_options(storage_options);
if let Some(interval) = read_consistency_interval {
conn_builder = conn_builder.read_consistency_interval(interval);
}
@@ -143,93 +108,19 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
Ok(promise)
}
/// Get AWS creds arguments from the context
/// Consumes 3 arguments
fn get_aws_creds(
cx: &mut FunctionContext,
arg_starting_location: i32,
) -> NeonResult<Option<AwsCredential>> {
let secret_key_id = cx
.argument_opt(arg_starting_location)
.filter(|arg| arg.is_a::<JsString, _>(cx))
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.map(|v| v.value(cx));
let secret_key = cx
.argument_opt(arg_starting_location + 1)
.filter(|arg| arg.is_a::<JsString, _>(cx))
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.map(|v| v.value(cx));
let temp_token = cx
.argument_opt(arg_starting_location + 2)
.filter(|arg| arg.is_a::<JsString, _>(cx))
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.map(|v| v.value(cx));
match (secret_key_id, secret_key, temp_token) {
(Some(key_id), Some(key), optional_token) => Ok(Some(AwsCredential {
key_id,
secret_key: key,
token: optional_token,
})),
(None, None, None) => Ok(None),
_ => cx.throw_error("Invalid credentials configuration"),
}
}
fn get_aws_credential_provider(
cx: &mut FunctionContext,
arg_starting_location: i32,
) -> NeonResult<Option<AwsCredentialProvider>> {
Ok(get_aws_creds(cx, arg_starting_location)?.map(|aws_cred| {
Arc::new(StaticCredentialProvider::new(aws_cred))
as Arc<dyn CredentialProvider<Credential = AwsCredential>>
}))
}
/// Get AWS region arguments from the context
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
let region = cx
.argument_opt(arg_location)
.filter(|arg| arg.is_a::<JsString, _>(cx))
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx));
match region {
Some(Ok(region)) => Ok(Some(region.value(cx))),
None => Ok(None),
Some(Err(e)) => Err(e),
}
}
fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let db = cx
.this()
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = get_aws_credential_provider(&mut cx, 1)?;
let aws_region = get_aws_region(&mut cx, 4)?;
let params = ReadParams {
store_options: Some(ObjectStoreParams::with_aws_credentials(
aws_creds, aws_region,
)),
..ReadParams::default()
};
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let database = db.database.clone();
let (deferred, promise) = cx.promise();
rt.spawn(async move {
let table_rst = database
.open_table(&table_name)
.lance_read_params(params)
.execute()
.await;
let table_rst = database.open_table(&table_name).execute().await;
deferred.settle_with(&channel, move |mut cx| {
let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);


@@ -17,7 +17,6 @@ use std::ops::Deref;
use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::optimize::CompactionOptions;
use lance::dataset::{ColumnAlteration, NewColumnTransform, WriteMode, WriteParams};
use lance::io::ObjectStoreParams;
use lancedb::table::{OptimizeAction, WriteOptions};
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
@@ -26,7 +25,7 @@ use neon::prelude::*;
use neon::types::buffer::TypedArray;
use crate::error::ResultExt;
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
use crate::{convert, runtime, JsDatabase};
pub struct JsTable {
pub table: LanceDbTable,
@@ -59,6 +58,10 @@ impl JsTable {
return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes")
}
};
let params = WriteParams {
mode,
..WriteParams::default()
};
let rt = runtime(&mut cx)?;
let channel = cx.channel();
@@ -66,17 +69,6 @@ impl JsTable {
let (deferred, promise) = cx.promise();
let database = db.database.clone();
let aws_creds = get_aws_credential_provider(&mut cx, 3)?;
let aws_region = get_aws_region(&mut cx, 6)?;
let params = WriteParams {
store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_creds, aws_region,
)),
mode,
..WriteParams::default()
};
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let table_rst = database
@@ -112,13 +104,8 @@ impl JsTable {
"overwrite" => WriteMode::Overwrite,
s => return cx.throw_error(format!("invalid write mode {}", s)),
};
let aws_creds = get_aws_credential_provider(&mut cx, 2)?;
let aws_region = get_aws_region(&mut cx, 5)?;
let params = WriteParams {
store_params: Some(ObjectStoreParams::with_aws_credentials(
aws_creds, aws_region,
)),
mode: write_mode,
..WriteParams::default()
};


@@ -46,8 +46,13 @@ tempfile = "3.5.0"
rand = { version = "0.8.3", features = ["small_rng"] }
uuid = { version = "1.7.0", features = ["v4"] }
walkdir = "2"
# For s3 integration tests (dev deps aren't allowed to be optional atm)
aws-sdk-s3 = { version = "1.0" }
aws-sdk-kms = { version = "1.0" }
aws-config = { version = "1.0" }
[features]
default = ["remote"]
remote = ["dep:reqwest"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []

View File

@@ -14,6 +14,7 @@
//! LanceDB Database
use std::collections::HashMap;
use std::fs::create_dir_all;
use std::path::Path;
use std::sync::Arc;
@@ -22,9 +23,7 @@ use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::SchemaRef;
use lance::dataset::{ReadParams, WriteMode};
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
use object_store::{
aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
};
use object_store::{aws::AwsCredential, local::LocalFileSystem};
use snafu::prelude::*;
use crate::arrow::IntoArrow;
@@ -208,6 +207,50 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
self.mode = mode;
self
}
/// Set an option for the storage layer.
///
/// Options already set on the connection will be inherited by the table,
/// but can be overridden here.
///
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
let store_options = self
.write_options
.lance_write_params
.get_or_insert(Default::default())
.store_params
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
store_options.insert(key.into(), value.into());
self
}
/// Set multiple options for the storage layer.
///
/// Options already set on the connection will be inherited by the table,
/// but can be overridden here.
///
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
pub fn storage_options(
mut self,
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
) -> Self {
let store_options = self
.write_options
.lance_write_params
.get_or_insert(Default::default())
.store_params
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
for (key, value) in pairs {
store_options.insert(key.into(), value.into());
}
self
}
}
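// A minimal usage sketch of the per-table storage options above, assuming a
// `lancedb::Connection` named `db`; the table name and endpoint are placeholder
// values mirroring the S3 integration test added in this commit.
async fn example_create_table(db: &lancedb::Connection) -> lancedb::Result<()> {
    use std::sync::Arc;
    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
    use arrow_schema::{DataType, Field, Schema};

    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();
    let data = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let _table = db
        .create_table("example_table", data)
        // Set one option at a time...
        .storage_option("allow_http", "true")
        // ...or several via any iterator of key/value pairs.
        .storage_options([("endpoint", "http://127.0.0.1:4566")])
        .execute()
        .await?;
    Ok(())
}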
#[derive(Clone, Debug)]
@@ -252,6 +295,48 @@ impl OpenTableBuilder {
self
}
/// Set an option for the storage layer.
///
/// Options already set on the connection will be inherited by the table,
/// but can be overridden here.
///
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
let storage_options = self
.lance_read_params
.get_or_insert(Default::default())
.store_options
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
storage_options.insert(key.into(), value.into());
self
}
/// Set multiple options for the storage layer.
///
/// Options already set on the connection will be inherited by the table,
/// but can be overridden here.
///
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
pub fn storage_options(
mut self,
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
) -> Self {
let storage_options = self
.lance_read_params
.get_or_insert(Default::default())
.store_options
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
for (key, value) in pairs {
storage_options.insert(key.into(), value.into());
}
self
}
/// Open the table
pub async fn execute(self) -> Result<Table> {
self.parent.clone().do_open_table(self).await
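// A minimal sketch of per-open overrides layered on top of whatever the
// connection set, assuming a `lancedb::Connection` named `db`; the option
// value is a placeholder.
async fn example_open_table(db: &lancedb::Connection) -> lancedb::Result<()> {
    let _table = db
        .open_table("example_table")
        .storage_option("allow_http", "true")
        .execute()
        .await?;
    Ok(())
}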
@@ -385,8 +470,7 @@ pub struct ConnectBuilder {
    /// LanceDB Cloud host override, only required if using an on-premises LanceDB Cloud instance
host_override: Option<String>,
/// User provided AWS credentials
aws_creds: Option<AwsCredential>,
storage_options: HashMap<String, String>,
/// The interval at which to check for updates from other processes.
///
@@ -409,8 +493,8 @@ impl ConnectBuilder {
api_key: None,
region: None,
host_override: None,
aws_creds: None,
read_consistency_interval: None,
storage_options: HashMap::new(),
}
}
@@ -430,8 +514,37 @@ impl ConnectBuilder {
}
/// [`AwsCredential`] to use when connecting to S3.
#[deprecated(note = "Pass through storage_options instead")]
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
self.aws_creds = Some(aws_creds);
self.storage_options
.insert("aws_access_key_id".into(), aws_creds.key_id.clone());
self.storage_options
.insert("aws_secret_access_key".into(), aws_creds.secret_key.clone());
if let Some(token) = &aws_creds.token {
self.storage_options
.insert("aws_session_token".into(), token.clone());
}
self
}
/// Set an option for the storage layer.
///
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
pub fn storage_option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.storage_options.insert(key.into(), value.into());
self
}
/// Set multiple options for the storage layer.
///
/// See available options at <https://lancedb.github.io/lancedb/guides/storage/>
pub fn storage_options(
mut self,
pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
) -> Self {
for (key, value) in pairs {
self.storage_options.insert(key.into(), value.into());
}
self
}
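// A minimal sketch of connection-level options, which every table created or
// opened from the connection inherits. The bucket URI and values are
// placeholders; the credential keys match those the deprecated `aws_creds`
// shim above writes into `storage_options`.
async fn example_connect() -> lancedb::Result<lancedb::Connection> {
    lancedb::connect("s3://example-bucket")
        .storage_options([
            ("aws_access_key_id", "ACCESS_KEY"),
            ("aws_secret_access_key", "SECRET_KEY"),
            ("endpoint", "http://127.0.0.1:4566"),
            ("allow_http", "true"),
        ])
        .execute()
        .await
}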
@@ -522,6 +635,9 @@ struct Database {
pub(crate) store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
read_consistency_interval: Option<std::time::Duration>,
// Storage options to be inherited by tables created from this connection
storage_options: HashMap<String, String>,
}
impl std::fmt::Display for Database {
@@ -604,20 +720,11 @@ impl Database {
};
let plain_uri = url.to_string();
let os_params: ObjectStoreParams = if let Some(aws_creds) = &options.aws_creds {
let credential_provider: Arc<
dyn CredentialProvider<Credential = AwsCredential>,
> = Arc::new(StaticCredentialProvider::new(AwsCredential {
key_id: aws_creds.key_id.clone(),
secret_key: aws_creds.secret_key.clone(),
token: aws_creds.token.clone(),
}));
ObjectStoreParams::with_aws_credentials(
Some(credential_provider),
options.region.clone(),
)
} else {
ObjectStoreParams::default()
let storage_options = options.storage_options.clone();
let os_params = ObjectStoreParams {
storage_options: Some(storage_options.clone()),
..Default::default()
};
let (object_store, base_path) =
ObjectStore::from_uri_and_params(&plain_uri, &os_params).await?;
@@ -641,6 +748,7 @@ impl Database {
object_store,
store_wrapper: write_store_wrapper,
read_consistency_interval: options.read_consistency_interval,
storage_options,
})
}
Err(_) => Self::open_path(uri, options.read_consistency_interval).await,
@@ -662,6 +770,7 @@ impl Database {
object_store,
store_wrapper: None,
read_consistency_interval,
storage_options: HashMap::new(),
})
}
@@ -734,11 +843,26 @@ impl ConnectionInternal for Database {
async fn do_create_table(
&self,
options: CreateTableBuilder<false, NoData>,
mut options: CreateTableBuilder<false, NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<Table> {
let table_uri = self.table_uri(&options.name)?;
// Inherit storage options from the connection
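        // (options already set on the table builder take precedence)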
let storage_options = options
.write_options
.lance_write_params
.get_or_insert_with(Default::default)
.store_params
.get_or_insert_with(Default::default)
.storage_options
.get_or_insert_with(Default::default);
for (key, value) in self.storage_options.iter() {
if !storage_options.contains_key(key) {
storage_options.insert(key.clone(), value.clone());
}
}
let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
if matches!(&options.mode, CreateTableMode::Overwrite) {
write_params.mode = WriteMode::Overwrite;
@@ -768,8 +892,23 @@ impl ConnectionInternal for Database {
}
}
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<Table> {
async fn do_open_table(&self, mut options: OpenTableBuilder) -> Result<Table> {
let table_uri = self.table_uri(&options.name)?;
// Inherit storage options from the connection
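        // (options already set on the open builder take precedence)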
let storage_options = options
.lance_read_params
.get_or_insert_with(Default::default)
.store_options
.get_or_insert_with(Default::default)
.storage_options
.get_or_insert_with(Default::default);
for (key, value) in self.storage_options.iter() {
if !storage_options.contains_key(key) {
storage_options.insert(key.clone(), value.clone());
}
}
let native_table = Arc::new(
NativeTable::open_with_params(
&table_uri,
@@ -801,7 +940,10 @@ impl ConnectionInternal for Database {
}
async fn drop_db(&self) -> Result<()> {
todo!()
self.object_store
.remove_dir_all(self.base_path.clone())
.await?;
Ok(())
}
}

View File

@@ -14,6 +14,7 @@
//! LanceDB Table APIs
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
@@ -757,6 +758,8 @@ pub struct NativeTable {
// the object store wrapper to use on write path
store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
storage_options: HashMap<String, String>,
    // This comes from the connection options. We store it here so we can pass it
    // down to the dataset when we recreate it (for example, in checkout_latest).
read_consistency_interval: Option<std::time::Duration>,
@@ -822,6 +825,13 @@ impl NativeTable {
None => params,
};
let storage_options = params
.store_options
.clone()
.unwrap_or_default()
.storage_options
.unwrap_or_default();
let dataset = DatasetBuilder::from_uri(uri)
.with_read_params(params)
.load()
@@ -840,6 +850,7 @@ impl NativeTable {
uri: uri.to_string(),
dataset,
store_wrapper: write_store_wrapper,
storage_options,
read_consistency_interval,
})
}
@@ -908,6 +919,13 @@ impl NativeTable {
None => params,
};
let storage_options = params
.store_params
.clone()
.unwrap_or_default()
.storage_options
.unwrap_or_default();
let dataset = Dataset::write(batches, uri, Some(params))
.await
.map_err(|e| match e {
@@ -921,6 +939,7 @@ impl NativeTable {
uri: uri.to_string(),
dataset: DatasetConsistencyWrapper::new_latest(dataset, read_consistency_interval),
store_wrapper: write_store_wrapper,
storage_options,
read_consistency_interval,
})
}
@@ -1312,7 +1331,7 @@ impl TableInternal for NativeTable {
add: AddDataBuilder<NoData>,
data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
let mut lance_params = add.write_options.lance_write_params.unwrap_or(WriteParams {
mode: match add.mode {
AddDataMode::Append => WriteMode::Append,
AddDataMode::Overwrite => WriteMode::Overwrite,
@@ -1320,6 +1339,18 @@ impl TableInternal for NativeTable {
..Default::default()
});
        // Bring in storage options from the table; options already set on the write params take precedence
let storage_options = lance_params
.store_params
.get_or_insert(Default::default())
.storage_options
.get_or_insert(Default::default());
for (key, value) in self.storage_options.iter() {
if !storage_options.contains_key(key) {
storage_options.insert(key.clone(), value.clone());
}
}
// patch the params if we have a write store wrapper
let lance_params = match self.store_wrapper.clone() {
Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?,

View File

@@ -0,0 +1,290 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![cfg(feature = "s3-test")]
use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema};
use aws_config::{BehaviorVersion, ConfigLoader, Region, SdkConfig};
use aws_sdk_s3::{config::Credentials, types::ServerSideEncryption, Client as S3Client};
use lancedb::Result;
const CONFIG: &[(&str, &str)] = &[
("access_key_id", "ACCESS_KEY"),
("secret_access_key", "SECRET_KEY"),
("endpoint", "http://127.0.0.1:4566"),
("allow_http", "true"),
];
async fn aws_config() -> SdkConfig {
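    // CONFIG order: [0] access key id, [1] secret access key, [2] endpoint, [3] allow_http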
let credentials = Credentials::new(CONFIG[0].1, CONFIG[1].1, None, None, "static");
ConfigLoader::default()
.credentials_provider(credentials)
.endpoint_url(CONFIG[2].1)
.behavior_version(BehaviorVersion::latest())
.region(Region::new("us-east-1"))
.load()
.await
}
struct S3Bucket(String);
impl S3Bucket {
async fn new(bucket: &str) -> Self {
let config = aws_config().await;
let client = S3Client::new(&config);
// In case it wasn't deleted earlier
Self::delete_bucket(client.clone(), bucket).await;
client.create_bucket().bucket(bucket).send().await.unwrap();
Self(bucket.to_string())
}
async fn delete_bucket(client: S3Client, bucket: &str) {
// Before we delete the bucket, we need to delete all objects in it
let res = client
.list_objects_v2()
.bucket(bucket)
.send()
.await
.map_err(|err| err.into_service_error());
match res {
Err(e) if e.is_no_such_bucket() => return,
Err(e) => panic!("Failed to list objects in bucket: {}", e),
_ => {}
}
let objects = res.unwrap().contents.unwrap_or_default();
for object in objects {
client
.delete_object()
.bucket(bucket)
.key(object.key.unwrap())
.send()
.await
.unwrap();
}
client.delete_bucket().bucket(bucket).send().await.unwrap();
}
}
impl Drop for S3Bucket {
fn drop(&mut self) {
let bucket_name = self.0.clone();
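        // Best-effort async cleanup: the spawned task may not finish before the test process exits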
tokio::task::spawn(async move {
let config = aws_config().await;
let client = S3Client::new(&config);
Self::delete_bucket(client, &bucket_name).await;
});
}
}
fn test_data() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, false),
]));
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec!["a", "b", "c"])),
],
)
.unwrap()
}
#[tokio::test]
async fn test_minio_lifecycle() -> Result<()> {
    // test create, open, add, drop, and list against localstack's S3 emulator
let bucket = S3Bucket::new("test-bucket").await;
let uri = format!("s3://{}", bucket.0);
let db = lancedb::connect(&uri)
.storage_options(CONFIG.iter().cloned())
.execute()
.await?;
let data = test_data();
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
let table = db.create_table("test_table", data).execute().await?;
let row_count = table.count_rows(None).await?;
assert_eq!(row_count, 3);
let table_names = db.table_names().execute().await?;
assert_eq!(table_names, vec!["test_table"]);
// Re-open the table
let table = db.open_table("test_table").execute().await?;
let row_count = table.count_rows(None).await?;
assert_eq!(row_count, 3);
let data = test_data();
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
table.add(data).execute().await?;
db.drop_table("test_table").await?;
Ok(())
}
struct KMSKey(String);
impl KMSKey {
async fn new() -> Self {
let config = aws_config().await;
let client = aws_sdk_kms::Client::new(&config);
let key = client
.create_key()
.description("test key")
.send()
.await
.unwrap()
.key_metadata
.unwrap()
.key_id;
Self(key)
}
}
impl Drop for KMSKey {
fn drop(&mut self) {
let key_id = self.0.clone();
tokio::task::spawn(async move {
let config = aws_config().await;
let client = aws_sdk_kms::Client::new(&config);
client
.schedule_key_deletion()
.key_id(&key_id)
.send()
.await
.unwrap();
});
}
}
async fn validate_objects_encrypted(bucket: &str, path: &str, kms_key_id: &str) {
// Get S3 client
let config = aws_config().await;
let client = S3Client::new(&config);
    // list the objects at the path
let objects = client
.list_objects_v2()
.bucket(bucket)
.prefix(path)
.send()
.await
.unwrap()
.contents
.unwrap();
let mut errors = vec![];
let mut correctly_encrypted = vec![];
    // For each object, call HeadObject and check its encryption settings
for object in &objects {
let head = client
.head_object()
.bucket(bucket)
.key(object.key().unwrap())
.send()
.await
.unwrap();
// Verify the object is encrypted
if head.server_side_encryption() != Some(&ServerSideEncryption::AwsKms) {
errors.push(format!("Object {} is not encrypted", object.key().unwrap()));
continue;
}
if !(head
.ssekms_key_id()
.map(|arn| arn.ends_with(kms_key_id))
.unwrap_or(false))
{
errors.push(format!(
"Object {} has wrong key id: {:?}, vs expected: {}",
object.key().unwrap(),
head.ssekms_key_id(),
kms_key_id
));
continue;
}
correctly_encrypted.push(object.key().unwrap().to_string());
}
if !errors.is_empty() {
panic!(
"{} of {} correctly encrypted: {:?}\n{} of {} not correct: {:?}",
correctly_encrypted.len(),
objects.len(),
correctly_encrypted,
errors.len(),
objects.len(),
errors
);
}
}
#[tokio::test]
async fn test_encryption() -> Result<()> {
    // test server-side encryption against localstack's S3 and KMS emulators
let bucket = S3Bucket::new("test-encryption").await;
let key = KMSKey::new().await;
let uri = format!("s3://{}", bucket.0);
let db = lancedb::connect(&uri)
.storage_options(CONFIG.iter().cloned())
.execute()
.await?;
// Create a table with encryption
let data = test_data();
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
let mut builder = db.create_table("test_table", data);
for (key, value) in CONFIG {
builder = builder.storage_option(*key, *value);
}
let table = builder
.storage_option("aws_server_side_encryption", "aws:kms")
.storage_option("aws_sse_kms_key_id", &key.0)
.execute()
.await?;
validate_objects_encrypted(&bucket.0, "test_table", &key.0).await;
table.delete("a = 1").await?;
validate_objects_encrypted(&bucket.0, "test_table", &key.0).await;
// Test we can set encryption at the connection level.
let db = lancedb::connect(&uri)
.storage_options(CONFIG.iter().cloned())
.storage_option("aws_server_side_encryption", "aws:kms")
.storage_option("aws_sse_kms_key_id", &key.0)
.execute()
.await?;
let table = db.open_table("test_table").execute().await?;
let data = test_data();
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
table.add(data).execute().await?;
validate_objects_encrypted(&bucket.0, "test_table", &key.0).await;
Ok(())
}