mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
35 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1d791a299 | ||
|
|
8da74dcb37 | ||
|
|
3c7419b392 | ||
|
|
e612686fdb | ||
|
|
e77d57a5b6 | ||
|
|
9391ad1450 | ||
|
|
79960b254e | ||
|
|
d19c64e29b | ||
|
|
06d5612443 | ||
|
|
45f96f4151 | ||
|
|
f744b785f8 | ||
|
|
2e3f745820 | ||
|
|
683aaed716 | ||
|
|
48f7b20daa | ||
|
|
4dd399ca29 | ||
|
|
e6f1da31dc | ||
|
|
a9ea785b15 | ||
|
|
cc38453391 | ||
|
|
47747287b6 | ||
|
|
0847e666a0 | ||
|
|
981f8427e6 | ||
|
|
f6846004ca | ||
|
|
faf8973624 | ||
|
|
fabe37274f | ||
|
|
6839ac3509 | ||
|
|
b88422e515 | ||
|
|
8d60685ede | ||
|
|
04285a4a4e | ||
|
|
d4a41b5663 | ||
|
|
adc3daa462 | ||
|
|
acbfa6c012 | ||
|
|
d602e9f98c | ||
|
|
ad09234d59 | ||
|
|
0c34ffb252 | ||
|
|
d9f333d828 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.21.2"
|
||||
current_version = "0.22.1-beta.0"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
1
.github/workflows/docs_test.yml
vendored
1
.github/workflows/docs_test.yml
vendored
@@ -48,6 +48,7 @@ jobs:
|
||||
uses: swatinem/rust-cache@v2
|
||||
- name: Build Python
|
||||
working-directory: docs/test
|
||||
timeout-minutes: 60
|
||||
run:
|
||||
python -m pip install --extra-index-url https://pypi.fury.io/lancedb/ -r requirements.txt
|
||||
- name: Create test files
|
||||
|
||||
66
CLAUDE.md
66
CLAUDE.md
@@ -13,10 +13,68 @@ Project layout:
|
||||
|
||||
Common commands:
|
||||
|
||||
* Check for compiler errors: `cargo check --features remote --tests --examples`
|
||||
* Run tests: `cargo test --features remote --tests`
|
||||
* Run specific test: `cargo test --features remote -p <package_name> --test <test_name>`
|
||||
* Lint: `cargo clippy --features remote --tests --examples`
|
||||
* Check for compiler errors: `cargo check --quiet --features remote --tests --examples`
|
||||
* Run tests: `cargo test --quiet --features remote --tests`
|
||||
* Run specific test: `cargo test --quiet --features remote -p <package_name> --test <test_name>`
|
||||
* Lint: `cargo clippy --quiet --features remote --tests --examples`
|
||||
* Format: `cargo fmt --all`
|
||||
|
||||
Before committing changes, run formatting.
|
||||
|
||||
## Coding tips
|
||||
|
||||
* When writing Rust doctests for things that require a connection or table reference,
|
||||
write them as a function instead of a fully executable test. This allows type checking
|
||||
to run but avoids needing a full test environment. For example:
|
||||
```rust
|
||||
/// ```
|
||||
/// use lance_index::scalar::FullTextSearchQuery;
|
||||
/// use lancedb::query::{QueryBase, ExecutableQuery};
|
||||
///
|
||||
/// # use lancedb::Table;
|
||||
/// # async fn query(table: &Table) -> Result<(), Box<dyn std::error::Error>> {
|
||||
/// let results = table.query()
|
||||
/// .full_text_search(FullTextSearchQuery::new("hello world".into()))
|
||||
/// .execute()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
```
|
||||
|
||||
## Example plan: adding a new method on Table
|
||||
|
||||
Adding a new method involves first adding it to the Rust core, then exposing it
|
||||
in the Python and TypeScript bindings. There are both local and remote tables.
|
||||
Remote tables are implemented via a HTTP API and require the `remote` cargo
|
||||
feature flag to be enabled. Python has both sync and async methods.
|
||||
|
||||
Rust core changes:
|
||||
|
||||
1. Add method on `Table` struct in `rust/lancedb/src/table.rs` (calls `BaseTable` trait).
|
||||
2. Add method to `BaseTable` trait in `rust/lancedb/src/table.rs`.
|
||||
3. Implement new trait method on `NativeTable` in `rust/lancedb/src/table.rs`.
|
||||
* Test with unit test in `rust/lancedb/src/table.rs`.
|
||||
4. Implement new trait method on `RemoteTable` in `rust/lancedb/src/remote/table.rs`.
|
||||
* Test with unit test in `rust/lancedb/src/remote/table.rs` against mocked endpoint.
|
||||
|
||||
Python bindings changes:
|
||||
|
||||
1. Add PyO3 method binding in `python/src/table.rs`. Run `make develop` to compile bindings.
|
||||
2. Add types for PyO3 method in `python/python/lancedb/_lancedb.pyi`.
|
||||
3. Add method to `AsyncTable` class in `python/python/lancedb/table.py`.
|
||||
4. Add abstract method to `Table` abstract base class in `python/python/lancedb/table.py`.
|
||||
5. Add concrete sync method to `LanceTable` class in `python/python/lancedb/table.py`.
|
||||
* Should use `LOOP.run()` to call the corresponding `AsyncTable` method.
|
||||
6. Add concrete sync method to `RemoteTable` class in `python/python/lancedb/remote/table.py`.
|
||||
7. Add unit test in `python/tests/test_table.py`.
|
||||
|
||||
TypeScript bindings changes:
|
||||
|
||||
1. Add napi-rs method binding on `Table` in `nodejs/src/table.rs`.
|
||||
2. Run `npm run build` to generate TypeScript definitions.
|
||||
3. Add typescript method on abstract class `Table` in `nodejs/src/table.ts`.
|
||||
4. Add concrete method on `LocalTable` class in `nodejs/src/native_table.ts`.
|
||||
* Note: despite the name, this class is also used for remote tables.
|
||||
5. Add test in `nodejs/__test__/table.test.ts`.
|
||||
6. Run `npm run docs` to generate TypeScript documentation.
|
||||
|
||||
1778
Cargo.lock
generated
1778
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
28
Cargo.toml
28
Cargo.toml
@@ -15,14 +15,14 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.78.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.33.0", "features" = ["dynamodb"] }
|
||||
lance-io = "=0.33.0"
|
||||
lance-index = "=0.33.0"
|
||||
lance-linalg = "=0.33.0"
|
||||
lance-table = "=0.33.0"
|
||||
lance-testing = "=0.33.0"
|
||||
lance-datafusion = "=0.33.0"
|
||||
lance-encoding = "=0.33.0"
|
||||
lance = { "version" = "=0.35.0", default-features = false, "features" = ["dynamodb"], "tag" = "v0.35.0-beta.4", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-io = { "version" = "=0.35.0", default-features = false, "tag" = "v0.35.0-beta.4", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-index = { "version" = "=0.35.0", "tag" = "v0.35.0-beta.4", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-linalg = { "version" = "=0.35.0", "tag" = "v0.35.0-beta.4", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-table = { "version" = "=0.35.0", "tag" = "v0.35.0-beta.4", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-testing = { "version" = "=0.35.0", "tag" = "v0.35.0-beta.4", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-datafusion = { "version" = "=0.35.0", "tag" = "v0.35.0-beta.4", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-encoding = { "version" = "=0.35.0", "tag" = "v0.35.0-beta.4", "git" = "https://github.com/lancedb/lance.git" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "55.1", optional = false }
|
||||
arrow-array = "55.1"
|
||||
@@ -33,12 +33,12 @@ arrow-schema = "55.1"
|
||||
arrow-arith = "55.1"
|
||||
arrow-cast = "55.1"
|
||||
async-trait = "0"
|
||||
datafusion = { version = "48.0", default-features = false }
|
||||
datafusion-catalog = "48.0"
|
||||
datafusion-common = { version = "48.0", default-features = false }
|
||||
datafusion-execution = "48.0"
|
||||
datafusion-expr = "48.0"
|
||||
datafusion-physical-plan = "48.0"
|
||||
datafusion = { version = "49.0", default-features = false }
|
||||
datafusion-catalog = "49.0"
|
||||
datafusion-common = { version = "49.0", default-features = false }
|
||||
datafusion-execution = "49.0"
|
||||
datafusion-expr = "49.0"
|
||||
datafusion-physical-plan = "49.0"
|
||||
env_logger = "0.11"
|
||||
half = { "version" = "2.6.0", default-features = false, features = [
|
||||
"num-traits",
|
||||
|
||||
@@ -54,6 +54,52 @@ def extract_features(line: str) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def extract_default_features(line: str) -> bool:
|
||||
"""
|
||||
Checks if default-features = false is present in a line in Cargo.toml.
|
||||
Example: 'lance = { "version" = "=0.29.0", default-features = false, "features" = ["dynamodb"] }'
|
||||
Returns: True if default-features = false is present, False otherwise
|
||||
"""
|
||||
import re
|
||||
|
||||
match = re.search(r'default-features\s*=\s*false', line)
|
||||
return match is not None
|
||||
|
||||
|
||||
def dict_to_toml_line(package_name: str, config: dict) -> str:
|
||||
"""
|
||||
Converts a configuration dictionary to a TOML dependency line.
|
||||
Dictionary insertion order is preserved (Python 3.7+), so the caller
|
||||
controls the order of fields in the output.
|
||||
|
||||
Args:
|
||||
package_name: The name of the package (e.g., "lance", "lance-io")
|
||||
config: Dictionary with keys like "version", "path", "git", "tag", "features", "default-features"
|
||||
The order of keys in this dict determines the order in the output.
|
||||
|
||||
Returns:
|
||||
A properly formatted TOML line with a trailing newline
|
||||
"""
|
||||
# If only version is specified, use simple format
|
||||
if len(config) == 1 and "version" in config:
|
||||
return f'{package_name} = "{config["version"]}"\n'
|
||||
|
||||
# Otherwise, use inline table format
|
||||
parts = []
|
||||
for key, value in config.items():
|
||||
if key == "default-features" and not value:
|
||||
parts.append("default-features = false")
|
||||
elif key == "features":
|
||||
parts.append(f'"features" = {json.dumps(value)}')
|
||||
elif isinstance(value, str):
|
||||
parts.append(f'"{key}" = "{value}"')
|
||||
else:
|
||||
# This shouldn't happen with our current usage
|
||||
parts.append(f'"{key}" = {json.dumps(value)}')
|
||||
|
||||
return f'{package_name} = {{ {", ".join(parts)} }}\n'
|
||||
|
||||
|
||||
def update_cargo_toml(line_updater):
|
||||
"""
|
||||
Updates the Cargo.toml file by applying the line_updater function to each line.
|
||||
@@ -67,20 +113,27 @@ def update_cargo_toml(line_updater):
|
||||
is_parsing_lance_line = False
|
||||
for line in lines:
|
||||
if line.startswith("lance"):
|
||||
# Update the line using the provided function
|
||||
if line.strip().endswith("}"):
|
||||
# Check if this is a single-line or multi-line entry
|
||||
# Single-line entries either:
|
||||
# 1. End with } (complete inline table)
|
||||
# 2. End with " (simple version string)
|
||||
# Multi-line entries start with { but don't end with }
|
||||
if line.strip().endswith("}") or line.strip().endswith('"'):
|
||||
# Single-line entry - process immediately
|
||||
new_lines.append(line_updater(line))
|
||||
else:
|
||||
elif "{" in line and not line.strip().endswith("}"):
|
||||
# Multi-line entry - start accumulating
|
||||
lance_line = line
|
||||
is_parsing_lance_line = True
|
||||
else:
|
||||
# Single-line entry without quotes or braces (shouldn't happen but handle it)
|
||||
new_lines.append(line_updater(line))
|
||||
elif is_parsing_lance_line:
|
||||
lance_line += line
|
||||
if line.strip().endswith("}"):
|
||||
new_lines.append(line_updater(lance_line))
|
||||
lance_line = ""
|
||||
is_parsing_lance_line = False
|
||||
else:
|
||||
print("doesn't end with }:", line)
|
||||
else:
|
||||
# Keep the line unchanged
|
||||
new_lines.append(line)
|
||||
@@ -92,18 +145,25 @@ def update_cargo_toml(line_updater):
|
||||
def set_stable_version(version: str):
|
||||
"""
|
||||
Sets lines to
|
||||
lance = { "version" = "=0.29.0", "features" = ["dynamodb"] }
|
||||
lance-io = "=0.29.0"
|
||||
lance = { "version" = "=0.29.0", default-features = false, "features" = ["dynamodb"] }
|
||||
lance-io = { "version" = "=0.29.0", default-features = false }
|
||||
...
|
||||
"""
|
||||
|
||||
def line_updater(line: str) -> str:
|
||||
package_name = line.split("=", maxsplit=1)[0].strip()
|
||||
|
||||
# Build config in desired order: version, default-features, features
|
||||
config = {"version": f"={version}"}
|
||||
|
||||
if extract_default_features(line):
|
||||
config["default-features"] = False
|
||||
|
||||
features = extract_features(line)
|
||||
if features:
|
||||
return f'{package_name} = {{ "version" = "={version}", "features" = {json.dumps(features)} }}\n'
|
||||
else:
|
||||
return f'{package_name} = "={version}"\n'
|
||||
config["features"] = features
|
||||
|
||||
return dict_to_toml_line(package_name, config)
|
||||
|
||||
update_cargo_toml(line_updater)
|
||||
|
||||
@@ -111,19 +171,29 @@ def set_stable_version(version: str):
|
||||
def set_preview_version(version: str):
|
||||
"""
|
||||
Sets lines to
|
||||
lance = { "version" = "=0.29.0", "features" = ["dynamodb"], tag = "v0.29.0-beta.2", git="https://github.com/lancedb/lance.git" }
|
||||
lance-io = { version = "=0.29.0", tag = "v0.29.0-beta.2", git="https://github.com/lancedb/lance.git" }
|
||||
lance = { "version" = "=0.29.0", default-features = false, "features" = ["dynamodb"], "tag" = "v0.29.0-beta.2", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-io = { "version" = "=0.29.0", default-features = false, "tag" = "v0.29.0-beta.2", "git" = "https://github.com/lancedb/lance.git" }
|
||||
...
|
||||
"""
|
||||
|
||||
def line_updater(line: str) -> str:
|
||||
package_name = line.split("=", maxsplit=1)[0].strip()
|
||||
features = extract_features(line)
|
||||
base_version = version.split("-")[0] # Get the base version without beta suffix
|
||||
|
||||
# Build config in desired order: version, default-features, features, tag, git
|
||||
config = {"version": f"={base_version}"}
|
||||
|
||||
if extract_default_features(line):
|
||||
config["default-features"] = False
|
||||
|
||||
features = extract_features(line)
|
||||
if features:
|
||||
return f'{package_name} = {{ "version" = "={base_version}", "features" = {json.dumps(features)}, "tag" = "v{version}", "git" = "https://github.com/lancedb/lance.git" }}\n'
|
||||
else:
|
||||
return f'{package_name} = {{ "version" = "={base_version}", "tag" = "v{version}", "git" = "https://github.com/lancedb/lance.git" }}\n'
|
||||
config["features"] = features
|
||||
|
||||
config["tag"] = f"v{version}"
|
||||
config["git"] = "https://github.com/lancedb/lance.git"
|
||||
|
||||
return dict_to_toml_line(package_name, config)
|
||||
|
||||
update_cargo_toml(line_updater)
|
||||
|
||||
@@ -131,18 +201,25 @@ def set_preview_version(version: str):
|
||||
def set_local_version():
|
||||
"""
|
||||
Sets lines to
|
||||
lance = { path = "../lance/rust/lance", features = ["dynamodb"] }
|
||||
lance-io = { path = "../lance/rust/lance-io" }
|
||||
lance = { "path" = "../lance/rust/lance", default-features = false, "features" = ["dynamodb"] }
|
||||
lance-io = { "path" = "../lance/rust/lance-io", default-features = false }
|
||||
...
|
||||
"""
|
||||
|
||||
def line_updater(line: str) -> str:
|
||||
package_name = line.split("=", maxsplit=1)[0].strip()
|
||||
|
||||
# Build config in desired order: path, default-features, features
|
||||
config = {"path": f"../lance/rust/{package_name}"}
|
||||
|
||||
if extract_default_features(line):
|
||||
config["default-features"] = False
|
||||
|
||||
features = extract_features(line)
|
||||
if features:
|
||||
return f'{package_name} = {{ "path" = "../lance/rust/{package_name}", "features" = {json.dumps(features)} }}\n'
|
||||
else:
|
||||
return f'{package_name} = {{ "path" = "../lance/rust/{package_name}" }}\n'
|
||||
config["features"] = features
|
||||
|
||||
return dict_to_toml_line(package_name, config)
|
||||
|
||||
update_cargo_toml(line_updater)
|
||||
|
||||
|
||||
@@ -45,6 +45,8 @@ Any attempt to use the connection after it is closed will result in an error.
|
||||
|
||||
### createEmptyTable()
|
||||
|
||||
#### createEmptyTable(name, schema, options)
|
||||
|
||||
```ts
|
||||
abstract createEmptyTable(
|
||||
name,
|
||||
@@ -54,7 +56,7 @@ abstract createEmptyTable(
|
||||
|
||||
Creates a new empty Table
|
||||
|
||||
#### Parameters
|
||||
##### Parameters
|
||||
|
||||
* **name**: `string`
|
||||
The name of the table.
|
||||
@@ -63,8 +65,39 @@ Creates a new empty Table
|
||||
The schema of the table
|
||||
|
||||
* **options?**: `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
|
||||
Additional options (backwards compatibility)
|
||||
|
||||
#### Returns
|
||||
##### Returns
|
||||
|
||||
`Promise`<[`Table`](Table.md)>
|
||||
|
||||
#### createEmptyTable(name, schema, namespace, options)
|
||||
|
||||
```ts
|
||||
abstract createEmptyTable(
|
||||
name,
|
||||
schema,
|
||||
namespace?,
|
||||
options?): Promise<Table>
|
||||
```
|
||||
|
||||
Creates a new empty Table
|
||||
|
||||
##### Parameters
|
||||
|
||||
* **name**: `string`
|
||||
The name of the table.
|
||||
|
||||
* **schema**: [`SchemaLike`](../type-aliases/SchemaLike.md)
|
||||
The schema of the table
|
||||
|
||||
* **namespace?**: `string`[]
|
||||
The namespace to create the table in (defaults to root namespace)
|
||||
|
||||
* **options?**: `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
|
||||
Additional options
|
||||
|
||||
##### Returns
|
||||
|
||||
`Promise`<[`Table`](Table.md)>
|
||||
|
||||
@@ -72,10 +105,10 @@ Creates a new empty Table
|
||||
|
||||
### createTable()
|
||||
|
||||
#### createTable(options)
|
||||
#### createTable(options, namespace)
|
||||
|
||||
```ts
|
||||
abstract createTable(options): Promise<Table>
|
||||
abstract createTable(options, namespace?): Promise<Table>
|
||||
```
|
||||
|
||||
Creates a new Table and initialize it with new data.
|
||||
@@ -85,6 +118,9 @@ Creates a new Table and initialize it with new data.
|
||||
* **options**: `object` & `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
|
||||
The options object.
|
||||
|
||||
* **namespace?**: `string`[]
|
||||
The namespace to create the table in (defaults to root namespace)
|
||||
|
||||
##### Returns
|
||||
|
||||
`Promise`<[`Table`](Table.md)>
|
||||
@@ -110,6 +146,38 @@ Creates a new Table and initialize it with new data.
|
||||
to be inserted into the table
|
||||
|
||||
* **options?**: `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
|
||||
Additional options (backwards compatibility)
|
||||
|
||||
##### Returns
|
||||
|
||||
`Promise`<[`Table`](Table.md)>
|
||||
|
||||
#### createTable(name, data, namespace, options)
|
||||
|
||||
```ts
|
||||
abstract createTable(
|
||||
name,
|
||||
data,
|
||||
namespace?,
|
||||
options?): Promise<Table>
|
||||
```
|
||||
|
||||
Creates a new Table and initialize it with new data.
|
||||
|
||||
##### Parameters
|
||||
|
||||
* **name**: `string`
|
||||
The name of the table.
|
||||
|
||||
* **data**: [`TableLike`](../type-aliases/TableLike.md) \| `Record`<`string`, `unknown`>[]
|
||||
Non-empty Array of Records
|
||||
to be inserted into the table
|
||||
|
||||
* **namespace?**: `string`[]
|
||||
The namespace to create the table in (defaults to root namespace)
|
||||
|
||||
* **options?**: `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
|
||||
Additional options
|
||||
|
||||
##### Returns
|
||||
|
||||
@@ -134,11 +202,16 @@ Return a brief description of the connection
|
||||
### dropAllTables()
|
||||
|
||||
```ts
|
||||
abstract dropAllTables(): Promise<void>
|
||||
abstract dropAllTables(namespace?): Promise<void>
|
||||
```
|
||||
|
||||
Drop all tables in the database.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **namespace?**: `string`[]
|
||||
The namespace to drop tables from (defaults to root namespace).
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`void`>
|
||||
@@ -148,7 +221,7 @@ Drop all tables in the database.
|
||||
### dropTable()
|
||||
|
||||
```ts
|
||||
abstract dropTable(name): Promise<void>
|
||||
abstract dropTable(name, namespace?): Promise<void>
|
||||
```
|
||||
|
||||
Drop an existing table.
|
||||
@@ -158,6 +231,9 @@ Drop an existing table.
|
||||
* **name**: `string`
|
||||
The name of the table to drop.
|
||||
|
||||
* **namespace?**: `string`[]
|
||||
The namespace of the table (defaults to root namespace).
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`void`>
|
||||
@@ -181,7 +257,10 @@ Return true if the connection has not been closed
|
||||
### openTable()
|
||||
|
||||
```ts
|
||||
abstract openTable(name, options?): Promise<Table>
|
||||
abstract openTable(
|
||||
name,
|
||||
namespace?,
|
||||
options?): Promise<Table>
|
||||
```
|
||||
|
||||
Open a table in the database.
|
||||
@@ -191,7 +270,11 @@ Open a table in the database.
|
||||
* **name**: `string`
|
||||
The name of the table
|
||||
|
||||
* **namespace?**: `string`[]
|
||||
The namespace of the table (defaults to root namespace)
|
||||
|
||||
* **options?**: `Partial`<[`OpenTableOptions`](../interfaces/OpenTableOptions.md)>
|
||||
Additional options
|
||||
|
||||
#### Returns
|
||||
|
||||
@@ -201,6 +284,8 @@ Open a table in the database.
|
||||
|
||||
### tableNames()
|
||||
|
||||
#### tableNames(options)
|
||||
|
||||
```ts
|
||||
abstract tableNames(options?): Promise<string[]>
|
||||
```
|
||||
@@ -209,12 +294,35 @@ List all the table names in this database.
|
||||
|
||||
Tables will be returned in lexicographical order.
|
||||
|
||||
#### Parameters
|
||||
##### Parameters
|
||||
|
||||
* **options?**: `Partial`<[`TableNamesOptions`](../interfaces/TableNamesOptions.md)>
|
||||
options to control the
|
||||
paging / start point (backwards compatibility)
|
||||
|
||||
##### Returns
|
||||
|
||||
`Promise`<`string`[]>
|
||||
|
||||
#### tableNames(namespace, options)
|
||||
|
||||
```ts
|
||||
abstract tableNames(namespace?, options?): Promise<string[]>
|
||||
```
|
||||
|
||||
List all the table names in this database.
|
||||
|
||||
Tables will be returned in lexicographical order.
|
||||
|
||||
##### Parameters
|
||||
|
||||
* **namespace?**: `string`[]
|
||||
The namespace to list tables from (defaults to root namespace)
|
||||
|
||||
* **options?**: `Partial`<[`TableNamesOptions`](../interfaces/TableNamesOptions.md)>
|
||||
options to control the
|
||||
paging / start point
|
||||
|
||||
#### Returns
|
||||
##### Returns
|
||||
|
||||
`Promise`<`string`[]>
|
||||
|
||||
85
docs/src/js/classes/HeaderProvider.md
Normal file
85
docs/src/js/classes/HeaderProvider.md
Normal file
@@ -0,0 +1,85 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / HeaderProvider
|
||||
|
||||
# Class: `abstract` HeaderProvider
|
||||
|
||||
Abstract base class for providing custom headers for each request.
|
||||
|
||||
Users can implement this interface to provide dynamic headers for various purposes
|
||||
such as authentication (OAuth tokens, API keys), request tracking (correlation IDs),
|
||||
custom metadata, or any other header-based requirements. The provider is called
|
||||
before each request to ensure fresh header values are always used.
|
||||
|
||||
## Examples
|
||||
|
||||
Simple JWT token provider:
|
||||
```typescript
|
||||
class JWTProvider extends HeaderProvider {
|
||||
constructor(private token: string) {
|
||||
super();
|
||||
}
|
||||
|
||||
getHeaders(): Record<string, string> {
|
||||
return { authorization: `Bearer ${this.token}` };
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Provider with request tracking:
|
||||
```typescript
|
||||
class RequestTrackingProvider extends HeaderProvider {
|
||||
constructor(private sessionId: string) {
|
||||
super();
|
||||
}
|
||||
|
||||
getHeaders(): Record<string, string> {
|
||||
return {
|
||||
"X-Session-Id": this.sessionId,
|
||||
"X-Request-Id": `req-${Date.now()}`
|
||||
};
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Extended by
|
||||
|
||||
- [`StaticHeaderProvider`](StaticHeaderProvider.md)
|
||||
- [`OAuthHeaderProvider`](OAuthHeaderProvider.md)
|
||||
|
||||
## Constructors
|
||||
|
||||
### new HeaderProvider()
|
||||
|
||||
```ts
|
||||
new HeaderProvider(): HeaderProvider
|
||||
```
|
||||
|
||||
#### Returns
|
||||
|
||||
[`HeaderProvider`](HeaderProvider.md)
|
||||
|
||||
## Methods
|
||||
|
||||
### getHeaders()
|
||||
|
||||
```ts
|
||||
abstract getHeaders(): Record<string, string>
|
||||
```
|
||||
|
||||
Get the latest headers to be added to requests.
|
||||
|
||||
This method is called before each request to the remote LanceDB server.
|
||||
Implementations should return headers that will be merged with existing headers.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Record`<`string`, `string`>
|
||||
|
||||
Dictionary of header names to values to add to the request.
|
||||
|
||||
#### Throws
|
||||
|
||||
If unable to fetch headers, the exception will be propagated and the request will fail.
|
||||
29
docs/src/js/classes/NativeJsHeaderProvider.md
Normal file
29
docs/src/js/classes/NativeJsHeaderProvider.md
Normal file
@@ -0,0 +1,29 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / NativeJsHeaderProvider
|
||||
|
||||
# Class: NativeJsHeaderProvider
|
||||
|
||||
JavaScript HeaderProvider implementation that wraps a JavaScript callback.
|
||||
This is the only native header provider - all header provider implementations
|
||||
should provide a JavaScript function that returns headers.
|
||||
|
||||
## Constructors
|
||||
|
||||
### new NativeJsHeaderProvider()
|
||||
|
||||
```ts
|
||||
new NativeJsHeaderProvider(getHeadersCallback): NativeJsHeaderProvider
|
||||
```
|
||||
|
||||
Create a new JsHeaderProvider from a JavaScript callback
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **getHeadersCallback**
|
||||
|
||||
#### Returns
|
||||
|
||||
[`NativeJsHeaderProvider`](NativeJsHeaderProvider.md)
|
||||
108
docs/src/js/classes/OAuthHeaderProvider.md
Normal file
108
docs/src/js/classes/OAuthHeaderProvider.md
Normal file
@@ -0,0 +1,108 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / OAuthHeaderProvider
|
||||
|
||||
# Class: OAuthHeaderProvider
|
||||
|
||||
Example implementation: OAuth token provider with automatic refresh.
|
||||
|
||||
This is an example implementation showing how to manage OAuth tokens
|
||||
with automatic refresh when they expire.
|
||||
|
||||
## Example
|
||||
|
||||
```typescript
|
||||
async function fetchToken(): Promise<TokenResponse> {
|
||||
const response = await fetch("https://oauth.example.com/token", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
grant_type: "client_credentials",
|
||||
client_id: "your-client-id",
|
||||
client_secret: "your-client-secret"
|
||||
}),
|
||||
headers: { "Content-Type": "application/json" }
|
||||
});
|
||||
const data = await response.json();
|
||||
return {
|
||||
accessToken: data.access_token,
|
||||
expiresIn: data.expires_in
|
||||
};
|
||||
}
|
||||
|
||||
const provider = new OAuthHeaderProvider(fetchToken);
|
||||
const headers = provider.getHeaders();
|
||||
// Returns: {"authorization": "Bearer <your-token>"}
|
||||
```
|
||||
|
||||
## Extends
|
||||
|
||||
- [`HeaderProvider`](HeaderProvider.md)
|
||||
|
||||
## Constructors
|
||||
|
||||
### new OAuthHeaderProvider()
|
||||
|
||||
```ts
|
||||
new OAuthHeaderProvider(tokenFetcher, refreshBufferSeconds): OAuthHeaderProvider
|
||||
```
|
||||
|
||||
Initialize the OAuth provider.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **tokenFetcher**
|
||||
Function to fetch new tokens. Should return object with 'accessToken' and optionally 'expiresIn'.
|
||||
|
||||
* **refreshBufferSeconds**: `number` = `300`
|
||||
Seconds before expiry to refresh token. Default 300 (5 minutes).
|
||||
|
||||
#### Returns
|
||||
|
||||
[`OAuthHeaderProvider`](OAuthHeaderProvider.md)
|
||||
|
||||
#### Overrides
|
||||
|
||||
[`HeaderProvider`](HeaderProvider.md).[`constructor`](HeaderProvider.md#constructors)
|
||||
|
||||
## Methods
|
||||
|
||||
### getHeaders()
|
||||
|
||||
```ts
|
||||
getHeaders(): Record<string, string>
|
||||
```
|
||||
|
||||
Get OAuth headers, refreshing token if needed.
|
||||
Note: This is synchronous for now as the Rust implementation expects sync.
|
||||
In a real implementation, this would need to handle async properly.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Record`<`string`, `string`>
|
||||
|
||||
Headers with Bearer token authorization.
|
||||
|
||||
#### Throws
|
||||
|
||||
If unable to fetch or refresh token.
|
||||
|
||||
#### Overrides
|
||||
|
||||
[`HeaderProvider`](HeaderProvider.md).[`getHeaders`](HeaderProvider.md#getheaders)
|
||||
|
||||
***
|
||||
|
||||
### refreshToken()
|
||||
|
||||
```ts
|
||||
refreshToken(): Promise<void>
|
||||
```
|
||||
|
||||
Manually refresh the token.
|
||||
Call this before using getHeaders() to ensure token is available.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`void`>
|
||||
70
docs/src/js/classes/StaticHeaderProvider.md
Normal file
70
docs/src/js/classes/StaticHeaderProvider.md
Normal file
@@ -0,0 +1,70 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / StaticHeaderProvider
|
||||
|
||||
# Class: StaticHeaderProvider
|
||||
|
||||
Example implementation: A simple header provider that returns static headers.
|
||||
|
||||
This is an example implementation showing how to create a HeaderProvider
|
||||
for cases where headers don't change during the session.
|
||||
|
||||
## Example
|
||||
|
||||
```typescript
|
||||
const provider = new StaticHeaderProvider({
|
||||
authorization: "Bearer my-token",
|
||||
"X-Custom-Header": "custom-value"
|
||||
});
|
||||
const headers = provider.getHeaders();
|
||||
// Returns: {authorization: 'Bearer my-token', 'X-Custom-Header': 'custom-value'}
|
||||
```
|
||||
|
||||
## Extends
|
||||
|
||||
- [`HeaderProvider`](HeaderProvider.md)
|
||||
|
||||
## Constructors
|
||||
|
||||
### new StaticHeaderProvider()
|
||||
|
||||
```ts
|
||||
new StaticHeaderProvider(headers): StaticHeaderProvider
|
||||
```
|
||||
|
||||
Initialize with static headers.
|
||||
|
||||
#### Parameters
|
||||
|
||||
* **headers**: `Record`<`string`, `string`>
|
||||
Headers to return for every request.
|
||||
|
||||
#### Returns
|
||||
|
||||
[`StaticHeaderProvider`](StaticHeaderProvider.md)
|
||||
|
||||
#### Overrides
|
||||
|
||||
[`HeaderProvider`](HeaderProvider.md).[`constructor`](HeaderProvider.md#constructors)
|
||||
|
||||
## Methods
|
||||
|
||||
### getHeaders()
|
||||
|
||||
```ts
|
||||
getHeaders(): Record<string, string>
|
||||
```
|
||||
|
||||
Return the static headers.
|
||||
|
||||
#### Returns
|
||||
|
||||
`Record`<`string`, `string`>
|
||||
|
||||
Copy of the static headers.
|
||||
|
||||
#### Overrides
|
||||
|
||||
[`HeaderProvider`](HeaderProvider.md).[`getHeaders`](HeaderProvider.md#getheaders)
|
||||
@@ -6,13 +6,14 @@
|
||||
|
||||
# Function: connect()
|
||||
|
||||
## connect(uri, options, session)
|
||||
## connect(uri, options, session, headerProvider)
|
||||
|
||||
```ts
|
||||
function connect(
|
||||
uri,
|
||||
options?,
|
||||
session?): Promise<Connection>
|
||||
session?,
|
||||
headerProvider?): Promise<Connection>
|
||||
```
|
||||
|
||||
Connect to a LanceDB instance at the given URI.
|
||||
@@ -34,6 +35,8 @@ Accepted formats:
|
||||
|
||||
* **session?**: [`Session`](../classes/Session.md)
|
||||
|
||||
* **headerProvider?**: [`HeaderProvider`](../classes/HeaderProvider.md) \| () => `Record`<`string`, `string`> \| () => `Promise`<`Record`<`string`, `string`>>
|
||||
|
||||
### Returns
|
||||
|
||||
`Promise`<[`Connection`](../classes/Connection.md)>
|
||||
@@ -55,6 +58,18 @@ const conn = await connect(
|
||||
});
|
||||
```
|
||||
|
||||
Using with a header provider for per-request authentication:
|
||||
```ts
|
||||
const provider = new StaticHeaderProvider({
|
||||
"X-API-Key": "my-key"
|
||||
});
|
||||
const conn = await connectWithHeaderProvider(
|
||||
"db://host:port",
|
||||
options,
|
||||
provider
|
||||
);
|
||||
```
|
||||
|
||||
## connect(options)
|
||||
|
||||
```ts
|
||||
|
||||
@@ -20,16 +20,20 @@
|
||||
- [BooleanQuery](classes/BooleanQuery.md)
|
||||
- [BoostQuery](classes/BoostQuery.md)
|
||||
- [Connection](classes/Connection.md)
|
||||
- [HeaderProvider](classes/HeaderProvider.md)
|
||||
- [Index](classes/Index.md)
|
||||
- [MakeArrowTableOptions](classes/MakeArrowTableOptions.md)
|
||||
- [MatchQuery](classes/MatchQuery.md)
|
||||
- [MergeInsertBuilder](classes/MergeInsertBuilder.md)
|
||||
- [MultiMatchQuery](classes/MultiMatchQuery.md)
|
||||
- [NativeJsHeaderProvider](classes/NativeJsHeaderProvider.md)
|
||||
- [OAuthHeaderProvider](classes/OAuthHeaderProvider.md)
|
||||
- [PhraseQuery](classes/PhraseQuery.md)
|
||||
- [Query](classes/Query.md)
|
||||
- [QueryBase](classes/QueryBase.md)
|
||||
- [RecordBatchIterator](classes/RecordBatchIterator.md)
|
||||
- [Session](classes/Session.md)
|
||||
- [StaticHeaderProvider](classes/StaticHeaderProvider.md)
|
||||
- [Table](classes/Table.md)
|
||||
- [TagContents](classes/TagContents.md)
|
||||
- [Tags](classes/Tags.md)
|
||||
@@ -74,6 +78,7 @@
|
||||
- [TableNamesOptions](interfaces/TableNamesOptions.md)
|
||||
- [TableStatistics](interfaces/TableStatistics.md)
|
||||
- [TimeoutConfig](interfaces/TimeoutConfig.md)
|
||||
- [TokenResponse](interfaces/TokenResponse.md)
|
||||
- [UpdateOptions](interfaces/UpdateOptions.md)
|
||||
- [UpdateResult](interfaces/UpdateResult.md)
|
||||
- [Version](interfaces/Version.md)
|
||||
|
||||
@@ -16,6 +16,14 @@ optional extraHeaders: Record<string, string>;
|
||||
|
||||
***
|
||||
|
||||
### idDelimiter?
|
||||
|
||||
```ts
|
||||
optional idDelimiter: string;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### retryConfig?
|
||||
|
||||
```ts
|
||||
|
||||
@@ -26,6 +26,18 @@ will be used to determine the most useful kind of index to create.
|
||||
|
||||
***
|
||||
|
||||
### name?
|
||||
|
||||
```ts
|
||||
optional name: string;
|
||||
```
|
||||
|
||||
Optional custom name for the index.
|
||||
|
||||
If not provided, a default name will be generated based on the column name.
|
||||
|
||||
***
|
||||
|
||||
### replace?
|
||||
|
||||
```ts
|
||||
@@ -42,8 +54,27 @@ The default is true
|
||||
|
||||
***
|
||||
|
||||
### train?
|
||||
|
||||
```ts
|
||||
optional train: boolean;
|
||||
```
|
||||
|
||||
Whether to train the index with existing data.
|
||||
|
||||
If true (default), the index will be trained with existing data in the table.
|
||||
If false, the index will be created empty and populated as new data is added.
|
||||
|
||||
Note: This option is only supported for scalar indices. Vector indices always train.
|
||||
|
||||
***
|
||||
|
||||
### waitTimeoutSeconds?
|
||||
|
||||
```ts
|
||||
optional waitTimeoutSeconds: number;
|
||||
```
|
||||
|
||||
Timeout in seconds to wait for index creation to complete.
|
||||
|
||||
If not specified, the method will return immediately after starting the index creation.
|
||||
|
||||
25
docs/src/js/interfaces/TokenResponse.md
Normal file
25
docs/src/js/interfaces/TokenResponse.md
Normal file
@@ -0,0 +1,25 @@
|
||||
[**@lancedb/lancedb**](../README.md) • **Docs**
|
||||
|
||||
***
|
||||
|
||||
[@lancedb/lancedb](../globals.md) / TokenResponse
|
||||
|
||||
# Interface: TokenResponse
|
||||
|
||||
Token response from OAuth provider.
|
||||
|
||||
## Properties
|
||||
|
||||
### accessToken
|
||||
|
||||
```ts
|
||||
accessToken: string;
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
### expiresIn?
|
||||
|
||||
```ts
|
||||
optional expiresIn: number;
|
||||
```
|
||||
@@ -15,7 +15,7 @@ publish = false
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
lancedb = { path = "../../../rust/lancedb" }
|
||||
lancedb = { path = "../../../rust/lancedb", default-features = false }
|
||||
lance = { workspace = true }
|
||||
arrow = { workspace = true, features = ["ffi"] }
|
||||
arrow-schema.workspace = true
|
||||
@@ -25,3 +25,6 @@ snafu.workspace = true
|
||||
lazy_static.workspace = true
|
||||
serde = { version = "^1" }
|
||||
serde_json = { version = "1" }
|
||||
|
||||
[features]
|
||||
default = ["lancedb/default"]
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.21.2-final.0</version>
|
||||
<version>0.22.1-beta.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
<parent>
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.21.2-final.0</version>
|
||||
<version>0.22.1-beta.0</version>
|
||||
<relativePath>../pom.xml</relativePath>
|
||||
</parent>
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
<groupId>com.lancedb</groupId>
|
||||
<artifactId>lancedb-parent</artifactId>
|
||||
<version>0.21.2-final.0</version>
|
||||
<version>0.22.1-beta.0</version>
|
||||
<packaging>pom</packaging>
|
||||
<name>${project.artifactId}</name>
|
||||
<description>LanceDB Java SDK Parent POM</description>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "lancedb-nodejs"
|
||||
edition.workspace = true
|
||||
version = "0.21.2"
|
||||
version = "0.22.1-beta.0"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
@@ -18,7 +18,7 @@ arrow-array.workspace = true
|
||||
arrow-schema.workspace = true
|
||||
env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
lancedb = { path = "../rust/lancedb" }
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
napi = { version = "2.16.8", default-features = false, features = [
|
||||
"napi9",
|
||||
"async"
|
||||
@@ -36,6 +36,6 @@ aws-lc-rs = "=1.13.0"
|
||||
napi-build = "2.1"
|
||||
|
||||
[features]
|
||||
default = ["remote"]
|
||||
default = ["remote", "lancedb/default"]
|
||||
fp16kernels = ["lancedb/fp16kernels"]
|
||||
remote = ["lancedb/remote"]
|
||||
|
||||
@@ -3,7 +3,50 @@
|
||||
|
||||
import * as http from "http";
|
||||
import { RequestListener } from "http";
|
||||
import { Connection, ConnectionOptions, connect } from "../lancedb";
|
||||
import {
|
||||
ClientConfig,
|
||||
Connection,
|
||||
ConnectionOptions,
|
||||
NativeJsHeaderProvider,
|
||||
TlsConfig,
|
||||
connect,
|
||||
} from "../lancedb";
|
||||
import {
|
||||
HeaderProvider,
|
||||
OAuthHeaderProvider,
|
||||
StaticHeaderProvider,
|
||||
} from "../lancedb/header";
|
||||
|
||||
// Test-only header providers
|
||||
class CustomProvider extends HeaderProvider {
|
||||
getHeaders(): Record<string, string> {
|
||||
return { "X-Custom": "custom-value" };
|
||||
}
|
||||
}
|
||||
|
||||
class ErrorProvider extends HeaderProvider {
|
||||
private errorMessage: string;
|
||||
public callCount: number = 0;
|
||||
|
||||
constructor(errorMessage: string = "Test error") {
|
||||
super();
|
||||
this.errorMessage = errorMessage;
|
||||
}
|
||||
|
||||
getHeaders(): Record<string, string> {
|
||||
this.callCount++;
|
||||
throw new Error(this.errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
class ConcurrentProvider extends HeaderProvider {
|
||||
private counter: number = 0;
|
||||
|
||||
getHeaders(): Record<string, string> {
|
||||
this.counter++;
|
||||
return { "X-Request-Id": String(this.counter) };
|
||||
}
|
||||
}
|
||||
|
||||
async function withMockDatabase(
|
||||
listener: RequestListener,
|
||||
@@ -148,4 +191,431 @@ describe("remote connection", () => {
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe("TlsConfig", () => {
|
||||
it("should create TlsConfig with all fields", () => {
|
||||
const tlsConfig: TlsConfig = {
|
||||
certFile: "/path/to/cert.pem",
|
||||
keyFile: "/path/to/key.pem",
|
||||
sslCaCert: "/path/to/ca.pem",
|
||||
assertHostname: false,
|
||||
};
|
||||
|
||||
expect(tlsConfig.certFile).toBe("/path/to/cert.pem");
|
||||
expect(tlsConfig.keyFile).toBe("/path/to/key.pem");
|
||||
expect(tlsConfig.sslCaCert).toBe("/path/to/ca.pem");
|
||||
expect(tlsConfig.assertHostname).toBe(false);
|
||||
});
|
||||
|
||||
it("should create TlsConfig with partial fields", () => {
|
||||
const tlsConfig: TlsConfig = {
|
||||
certFile: "/path/to/cert.pem",
|
||||
keyFile: "/path/to/key.pem",
|
||||
};
|
||||
|
||||
expect(tlsConfig.certFile).toBe("/path/to/cert.pem");
|
||||
expect(tlsConfig.keyFile).toBe("/path/to/key.pem");
|
||||
expect(tlsConfig.sslCaCert).toBeUndefined();
|
||||
expect(tlsConfig.assertHostname).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should create ClientConfig with TlsConfig", () => {
|
||||
const tlsConfig: TlsConfig = {
|
||||
certFile: "/path/to/cert.pem",
|
||||
keyFile: "/path/to/key.pem",
|
||||
sslCaCert: "/path/to/ca.pem",
|
||||
assertHostname: true,
|
||||
};
|
||||
|
||||
const clientConfig: ClientConfig = {
|
||||
userAgent: "test-agent",
|
||||
tlsConfig: tlsConfig,
|
||||
};
|
||||
|
||||
expect(clientConfig.userAgent).toBe("test-agent");
|
||||
expect(clientConfig.tlsConfig).toBeDefined();
|
||||
expect(clientConfig.tlsConfig?.certFile).toBe("/path/to/cert.pem");
|
||||
expect(clientConfig.tlsConfig?.keyFile).toBe("/path/to/key.pem");
|
||||
expect(clientConfig.tlsConfig?.sslCaCert).toBe("/path/to/ca.pem");
|
||||
expect(clientConfig.tlsConfig?.assertHostname).toBe(true);
|
||||
});
|
||||
|
||||
it("should handle empty TlsConfig", () => {
|
||||
const tlsConfig: TlsConfig = {};
|
||||
|
||||
expect(tlsConfig.certFile).toBeUndefined();
|
||||
expect(tlsConfig.keyFile).toBeUndefined();
|
||||
expect(tlsConfig.sslCaCert).toBeUndefined();
|
||||
expect(tlsConfig.assertHostname).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should accept TlsConfig in connection options", () => {
|
||||
const tlsConfig: TlsConfig = {
|
||||
certFile: "/path/to/cert.pem",
|
||||
keyFile: "/path/to/key.pem",
|
||||
sslCaCert: "/path/to/ca.pem",
|
||||
assertHostname: false,
|
||||
};
|
||||
|
||||
// Just verify that the ClientConfig accepts the TlsConfig
|
||||
const clientConfig: ClientConfig = {
|
||||
tlsConfig: tlsConfig,
|
||||
};
|
||||
|
||||
const connectionOptions: ConnectionOptions = {
|
||||
apiKey: "fake",
|
||||
clientConfig: clientConfig,
|
||||
};
|
||||
|
||||
// Verify the configuration structure is correct
|
||||
expect(connectionOptions.clientConfig).toBeDefined();
|
||||
expect(connectionOptions.clientConfig?.tlsConfig).toBeDefined();
|
||||
expect(connectionOptions.clientConfig?.tlsConfig?.certFile).toBe(
|
||||
"/path/to/cert.pem",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("header providers", () => {
|
||||
it("should work with StaticHeaderProvider", async () => {
|
||||
const provider = new StaticHeaderProvider({
|
||||
authorization: "Bearer test-token",
|
||||
"X-Custom": "value",
|
||||
});
|
||||
|
||||
const headers = provider.getHeaders();
|
||||
expect(headers).toEqual({
|
||||
authorization: "Bearer test-token",
|
||||
"X-Custom": "value",
|
||||
});
|
||||
|
||||
// Test that it returns a copy
|
||||
headers["X-Modified"] = "modified";
|
||||
const headers2 = provider.getHeaders();
|
||||
expect(headers2).not.toHaveProperty("X-Modified");
|
||||
});
|
||||
|
||||
it("should pass headers from StaticHeaderProvider to requests", async () => {
|
||||
const provider = new StaticHeaderProvider({
|
||||
"X-Custom-Auth": "secret-token",
|
||||
"X-Request-Source": "test-suite",
|
||||
});
|
||||
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
expect(req.headers["x-custom-auth"]).toEqual("secret-token");
|
||||
expect(req.headers["x-request-source"]).toEqual("test-suite");
|
||||
|
||||
const body = JSON.stringify({ tables: [] });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async () => {
|
||||
// Use actual header provider mechanism instead of extraHeaders
|
||||
const conn = await connect(
|
||||
"db://dev",
|
||||
{
|
||||
apiKey: "fake",
|
||||
hostOverride: "http://localhost:8000",
|
||||
},
|
||||
undefined, // session
|
||||
provider, // headerProvider
|
||||
);
|
||||
|
||||
const tableNames = await conn.tableNames();
|
||||
expect(tableNames).toEqual([]);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("should work with CustomProvider", () => {
|
||||
const provider = new CustomProvider();
|
||||
const headers = provider.getHeaders();
|
||||
expect(headers).toEqual({ "X-Custom": "custom-value" });
|
||||
});
|
||||
|
||||
it("should handle ErrorProvider errors", () => {
|
||||
const provider = new ErrorProvider("Authentication failed");
|
||||
|
||||
expect(() => provider.getHeaders()).toThrow("Authentication failed");
|
||||
expect(provider.callCount).toBe(1);
|
||||
|
||||
// Test that error is thrown each time
|
||||
expect(() => provider.getHeaders()).toThrow("Authentication failed");
|
||||
expect(provider.callCount).toBe(2);
|
||||
});
|
||||
|
||||
it("should work with ConcurrentProvider", () => {
|
||||
const provider = new ConcurrentProvider();
|
||||
|
||||
const headers1 = provider.getHeaders();
|
||||
const headers2 = provider.getHeaders();
|
||||
const headers3 = provider.getHeaders();
|
||||
|
||||
expect(headers1).toEqual({ "X-Request-Id": "1" });
|
||||
expect(headers2).toEqual({ "X-Request-Id": "2" });
|
||||
expect(headers3).toEqual({ "X-Request-Id": "3" });
|
||||
});
|
||||
|
||||
describe("OAuthHeaderProvider", () => {
|
||||
it("should initialize correctly", () => {
|
||||
const fetcher = () => ({
|
||||
accessToken: "token123",
|
||||
expiresIn: 3600,
|
||||
});
|
||||
|
||||
const provider = new OAuthHeaderProvider(fetcher);
|
||||
expect(provider).toBeInstanceOf(HeaderProvider);
|
||||
});
|
||||
|
||||
it("should fetch token on first use", async () => {
|
||||
let callCount = 0;
|
||||
const fetcher = () => {
|
||||
callCount++;
|
||||
return {
|
||||
accessToken: "token123",
|
||||
expiresIn: 3600,
|
||||
};
|
||||
};
|
||||
|
||||
const provider = new OAuthHeaderProvider(fetcher);
|
||||
|
||||
// Need to manually refresh first due to sync limitation
|
||||
await provider.refreshToken();
|
||||
|
||||
const headers = provider.getHeaders();
|
||||
expect(headers).toEqual({ authorization: "Bearer token123" });
|
||||
expect(callCount).toBe(1);
|
||||
|
||||
// Second call should not fetch again
|
||||
const headers2 = provider.getHeaders();
|
||||
expect(headers2).toEqual({ authorization: "Bearer token123" });
|
||||
expect(callCount).toBe(1);
|
||||
});
|
||||
|
||||
it("should handle tokens without expiry", async () => {
|
||||
const fetcher = () => ({
|
||||
accessToken: "permanent_token",
|
||||
});
|
||||
|
||||
const provider = new OAuthHeaderProvider(fetcher);
|
||||
await provider.refreshToken();
|
||||
|
||||
const headers = provider.getHeaders();
|
||||
expect(headers).toEqual({ authorization: "Bearer permanent_token" });
|
||||
});
|
||||
|
||||
it("should throw error when access_token is missing", async () => {
|
||||
const fetcher = () =>
|
||||
({
|
||||
expiresIn: 3600,
|
||||
}) as { accessToken?: string; expiresIn?: number };
|
||||
|
||||
const provider = new OAuthHeaderProvider(
|
||||
fetcher as () => {
|
||||
accessToken: string;
|
||||
expiresIn?: number;
|
||||
},
|
||||
);
|
||||
|
||||
await expect(provider.refreshToken()).rejects.toThrow(
|
||||
"Token fetcher did not return 'accessToken'",
|
||||
);
|
||||
});
|
||||
|
||||
it("should handle async token fetchers", async () => {
|
||||
const fetcher = async () => {
|
||||
// Simulate async operation
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
return {
|
||||
accessToken: "async_token",
|
||||
expiresIn: 3600,
|
||||
};
|
||||
};
|
||||
|
||||
const provider = new OAuthHeaderProvider(fetcher);
|
||||
await provider.refreshToken();
|
||||
|
||||
const headers = provider.getHeaders();
|
||||
expect(headers).toEqual({ authorization: "Bearer async_token" });
|
||||
});
|
||||
});
|
||||
|
||||
it("should merge header provider headers with extra headers", async () => {
|
||||
const provider = new StaticHeaderProvider({
|
||||
"X-From-Provider": "provider-value",
|
||||
});
|
||||
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
expect(req.headers["x-from-provider"]).toEqual("provider-value");
|
||||
expect(req.headers["x-extra-header"]).toEqual("extra-value");
|
||||
|
||||
const body = JSON.stringify({ tables: [] });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async () => {
|
||||
// Use header provider with additional extraHeaders
|
||||
const conn = await connect(
|
||||
"db://dev",
|
||||
{
|
||||
apiKey: "fake",
|
||||
hostOverride: "http://localhost:8000",
|
||||
clientConfig: {
|
||||
extraHeaders: {
|
||||
"X-Extra-Header": "extra-value",
|
||||
},
|
||||
},
|
||||
},
|
||||
undefined, // session
|
||||
provider, // headerProvider
|
||||
);
|
||||
|
||||
const tableNames = await conn.tableNames();
|
||||
expect(tableNames).toEqual([]);
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("header provider integration", () => {
|
||||
it("should work with TypeScript StaticHeaderProvider", async () => {
|
||||
let requestCount = 0;
|
||||
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
requestCount++;
|
||||
|
||||
// Check headers are present on each request
|
||||
expect(req.headers["authorization"]).toEqual("Bearer test-token-123");
|
||||
expect(req.headers["x-custom"]).toEqual("custom-value");
|
||||
|
||||
// Return different responses based on the endpoint
|
||||
if (req.url === "/v1/table/test_table/describe/") {
|
||||
const body = JSON.stringify({
|
||||
name: "test_table",
|
||||
schema: { fields: [] },
|
||||
});
|
||||
res
|
||||
.writeHead(200, { "Content-Type": "application/json" })
|
||||
.end(body);
|
||||
} else {
|
||||
const body = JSON.stringify({ tables: ["test_table"] });
|
||||
res
|
||||
.writeHead(200, { "Content-Type": "application/json" })
|
||||
.end(body);
|
||||
}
|
||||
},
|
||||
async () => {
|
||||
// Create provider with static headers
|
||||
const provider = new StaticHeaderProvider({
|
||||
authorization: "Bearer test-token-123",
|
||||
"X-Custom": "custom-value",
|
||||
});
|
||||
|
||||
// Connect with the provider
|
||||
const conn = await connect(
|
||||
"db://dev",
|
||||
{
|
||||
apiKey: "fake",
|
||||
hostOverride: "http://localhost:8000",
|
||||
},
|
||||
undefined, // session
|
||||
provider, // headerProvider
|
||||
);
|
||||
|
||||
// Make multiple requests to verify headers are sent each time
|
||||
const tables1 = await conn.tableNames();
|
||||
expect(tables1).toEqual(["test_table"]);
|
||||
|
||||
const tables2 = await conn.tableNames();
|
||||
expect(tables2).toEqual(["test_table"]);
|
||||
|
||||
// Verify headers were sent with each request
|
||||
expect(requestCount).toBeGreaterThanOrEqual(2);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("should work with JavaScript function provider", async () => {
|
||||
let requestId = 0;
|
||||
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
// Check dynamic header is present
|
||||
expect(req.headers["x-request-id"]).toBeDefined();
|
||||
expect(req.headers["x-request-id"]).toMatch(/^req-\d+$/);
|
||||
|
||||
const body = JSON.stringify({ tables: [] });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async () => {
|
||||
// Create a JavaScript function that returns dynamic headers
|
||||
const getHeaders = async () => {
|
||||
requestId++;
|
||||
return {
|
||||
"X-Request-Id": `req-${requestId}`,
|
||||
"X-Timestamp": new Date().toISOString(),
|
||||
};
|
||||
};
|
||||
|
||||
// Connect with the function directly
|
||||
const conn = await connect(
|
||||
"db://dev",
|
||||
{
|
||||
apiKey: "fake",
|
||||
hostOverride: "http://localhost:8000",
|
||||
},
|
||||
undefined, // session
|
||||
getHeaders, // headerProvider
|
||||
);
|
||||
|
||||
// Make requests - each should have different headers
|
||||
const tables = await conn.tableNames();
|
||||
expect(tables).toEqual([]);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("should support OAuth-like token refresh pattern", async () => {
|
||||
let tokenVersion = 0;
|
||||
|
||||
await withMockDatabase(
|
||||
(req, res) => {
|
||||
// Verify authorization header
|
||||
const authHeader = req.headers["authorization"];
|
||||
expect(authHeader).toBeDefined();
|
||||
expect(authHeader).toMatch(/^Bearer token-v\d+$/);
|
||||
|
||||
const body = JSON.stringify({ tables: [] });
|
||||
res.writeHead(200, { "Content-Type": "application/json" }).end(body);
|
||||
},
|
||||
async () => {
|
||||
// Simulate OAuth token fetcher
|
||||
const fetchToken = async () => {
|
||||
tokenVersion++;
|
||||
return {
|
||||
authorization: `Bearer token-v${tokenVersion}`,
|
||||
};
|
||||
};
|
||||
|
||||
// Connect with the function directly
|
||||
const conn = await connect(
|
||||
"db://dev",
|
||||
{
|
||||
apiKey: "fake",
|
||||
hostOverride: "http://localhost:8000",
|
||||
},
|
||||
undefined, // session
|
||||
fetchToken, // headerProvider
|
||||
);
|
||||
|
||||
// Each request will fetch a new token
|
||||
await conn.tableNames();
|
||||
|
||||
// Token should be different on next request
|
||||
await conn.tableNames();
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -857,6 +857,40 @@ describe("When creating an index", () => {
|
||||
expect(stats).toBeUndefined();
|
||||
});
|
||||
|
||||
test("should support name and train parameters", async () => {
|
||||
// Test with custom name
|
||||
await tbl.createIndex("vec", {
|
||||
config: Index.ivfPq({ numPartitions: 4 }),
|
||||
name: "my_custom_vector_index",
|
||||
});
|
||||
|
||||
const indices = await tbl.listIndices();
|
||||
expect(indices).toHaveLength(1);
|
||||
expect(indices[0].name).toBe("my_custom_vector_index");
|
||||
|
||||
// Test scalar index with train=false
|
||||
await tbl.createIndex("id", {
|
||||
config: Index.btree(),
|
||||
name: "btree_empty",
|
||||
train: false,
|
||||
});
|
||||
|
||||
const allIndices = await tbl.listIndices();
|
||||
expect(allIndices).toHaveLength(2);
|
||||
expect(allIndices.some((idx) => idx.name === "btree_empty")).toBe(true);
|
||||
|
||||
// Test with both name and train=true (use tags column)
|
||||
await tbl.createIndex("tags", {
|
||||
config: Index.labelList(),
|
||||
name: "tags_trained",
|
||||
train: true,
|
||||
});
|
||||
|
||||
const finalIndices = await tbl.listIndices();
|
||||
expect(finalIndices).toHaveLength(3);
|
||||
expect(finalIndices.some((idx) => idx.name === "tags_trained")).toBe(true);
|
||||
});
|
||||
|
||||
test("create ivf_flat with binary vectors", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const binarySchema = new Schema([
|
||||
|
||||
@@ -159,17 +159,33 @@ export abstract class Connection {
|
||||
*
|
||||
* Tables will be returned in lexicographical order.
|
||||
* @param {Partial<TableNamesOptions>} options - options to control the
|
||||
* paging / start point
|
||||
* paging / start point (backwards compatibility)
|
||||
*
|
||||
*/
|
||||
abstract tableNames(options?: Partial<TableNamesOptions>): Promise<string[]>;
|
||||
/**
|
||||
* List all the table names in this database.
|
||||
*
|
||||
* Tables will be returned in lexicographical order.
|
||||
* @param {string[]} namespace - The namespace to list tables from (defaults to root namespace)
|
||||
* @param {Partial<TableNamesOptions>} options - options to control the
|
||||
* paging / start point
|
||||
*
|
||||
*/
|
||||
abstract tableNames(
|
||||
namespace?: string[],
|
||||
options?: Partial<TableNamesOptions>,
|
||||
): Promise<string[]>;
|
||||
|
||||
/**
|
||||
* Open a table in the database.
|
||||
* @param {string} name - The name of the table
|
||||
* @param {string[]} namespace - The namespace of the table (defaults to root namespace)
|
||||
* @param {Partial<OpenTableOptions>} options - Additional options
|
||||
*/
|
||||
abstract openTable(
|
||||
name: string,
|
||||
namespace?: string[],
|
||||
options?: Partial<OpenTableOptions>,
|
||||
): Promise<Table>;
|
||||
|
||||
@@ -178,6 +194,7 @@ export abstract class Connection {
|
||||
* @param {object} options - The options object.
|
||||
* @param {string} options.name - The name of the table.
|
||||
* @param {Data} options.data - Non-empty Array of Records to be inserted into the table
|
||||
* @param {string[]} namespace - The namespace to create the table in (defaults to root namespace)
|
||||
*
|
||||
*/
|
||||
abstract createTable(
|
||||
@@ -185,40 +202,72 @@ export abstract class Connection {
|
||||
name: string;
|
||||
data: Data;
|
||||
} & Partial<CreateTableOptions>,
|
||||
namespace?: string[],
|
||||
): Promise<Table>;
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
* @param {string} name - The name of the table.
|
||||
* @param {Record<string, unknown>[] | TableLike} data - Non-empty Array of Records
|
||||
* to be inserted into the table
|
||||
* @param {Partial<CreateTableOptions>} options - Additional options (backwards compatibility)
|
||||
*/
|
||||
abstract createTable(
|
||||
name: string,
|
||||
data: Record<string, unknown>[] | TableLike,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table>;
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
* @param {string} name - The name of the table.
|
||||
* @param {Record<string, unknown>[] | TableLike} data - Non-empty Array of Records
|
||||
* to be inserted into the table
|
||||
* @param {string[]} namespace - The namespace to create the table in (defaults to root namespace)
|
||||
* @param {Partial<CreateTableOptions>} options - Additional options
|
||||
*/
|
||||
abstract createTable(
|
||||
name: string,
|
||||
data: Record<string, unknown>[] | TableLike,
|
||||
namespace?: string[],
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table>;
|
||||
|
||||
/**
|
||||
* Creates a new empty Table
|
||||
* @param {string} name - The name of the table.
|
||||
* @param {Schema} schema - The schema of the table
|
||||
* @param {Partial<CreateTableOptions>} options - Additional options (backwards compatibility)
|
||||
*/
|
||||
abstract createEmptyTable(
|
||||
name: string,
|
||||
schema: import("./arrow").SchemaLike,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table>;
|
||||
/**
|
||||
* Creates a new empty Table
|
||||
* @param {string} name - The name of the table.
|
||||
* @param {Schema} schema - The schema of the table
|
||||
* @param {string[]} namespace - The namespace to create the table in (defaults to root namespace)
|
||||
* @param {Partial<CreateTableOptions>} options - Additional options
|
||||
*/
|
||||
abstract createEmptyTable(
|
||||
name: string,
|
||||
schema: import("./arrow").SchemaLike,
|
||||
namespace?: string[],
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table>;
|
||||
|
||||
/**
|
||||
* Drop an existing table.
|
||||
* @param {string} name The name of the table to drop.
|
||||
* @param {string[]} namespace The namespace of the table (defaults to root namespace).
|
||||
*/
|
||||
abstract dropTable(name: string): Promise<void>;
|
||||
abstract dropTable(name: string, namespace?: string[]): Promise<void>;
|
||||
|
||||
/**
|
||||
* Drop all tables in the database.
|
||||
* @param {string[]} namespace The namespace to drop tables from (defaults to root namespace).
|
||||
*/
|
||||
abstract dropAllTables(): Promise<void>;
|
||||
abstract dropAllTables(namespace?: string[]): Promise<void>;
|
||||
}
|
||||
|
||||
/** @hideconstructor */
|
||||
@@ -243,16 +292,39 @@ export class LocalConnection extends Connection {
|
||||
return this.inner.display();
|
||||
}
|
||||
|
||||
async tableNames(options?: Partial<TableNamesOptions>): Promise<string[]> {
|
||||
return this.inner.tableNames(options?.startAfter, options?.limit);
|
||||
async tableNames(
|
||||
namespaceOrOptions?: string[] | Partial<TableNamesOptions>,
|
||||
options?: Partial<TableNamesOptions>,
|
||||
): Promise<string[]> {
|
||||
// Detect if first argument is namespace array or options object
|
||||
let namespace: string[] | undefined;
|
||||
let tableNamesOptions: Partial<TableNamesOptions> | undefined;
|
||||
|
||||
if (Array.isArray(namespaceOrOptions)) {
|
||||
// First argument is namespace array
|
||||
namespace = namespaceOrOptions;
|
||||
tableNamesOptions = options;
|
||||
} else {
|
||||
// First argument is options object (backwards compatibility)
|
||||
namespace = undefined;
|
||||
tableNamesOptions = namespaceOrOptions;
|
||||
}
|
||||
|
||||
return this.inner.tableNames(
|
||||
namespace ?? [],
|
||||
tableNamesOptions?.startAfter,
|
||||
tableNamesOptions?.limit,
|
||||
);
|
||||
}
|
||||
|
||||
async openTable(
|
||||
name: string,
|
||||
namespace?: string[],
|
||||
options?: Partial<OpenTableOptions>,
|
||||
): Promise<Table> {
|
||||
const innerTable = await this.inner.openTable(
|
||||
name,
|
||||
namespace ?? [],
|
||||
cleanseStorageOptions(options?.storageOptions),
|
||||
options?.indexCacheSize,
|
||||
);
|
||||
@@ -286,14 +358,44 @@ export class LocalConnection extends Connection {
|
||||
nameOrOptions:
|
||||
| string
|
||||
| ({ name: string; data: Data } & Partial<CreateTableOptions>),
|
||||
data?: Record<string, unknown>[] | TableLike,
|
||||
dataOrNamespace?: Record<string, unknown>[] | TableLike | string[],
|
||||
namespaceOrOptions?: string[] | Partial<CreateTableOptions>,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table> {
|
||||
if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) {
|
||||
const { name, data, ...options } = nameOrOptions;
|
||||
|
||||
return this.createTable(name, data, options);
|
||||
// First overload: createTable(options, namespace?)
|
||||
const { name, data, ...createOptions } = nameOrOptions;
|
||||
const namespace = dataOrNamespace as string[] | undefined;
|
||||
return this._createTableImpl(name, data, namespace, createOptions);
|
||||
}
|
||||
|
||||
// Second overload: createTable(name, data, namespace?, options?)
|
||||
const name = nameOrOptions;
|
||||
const data = dataOrNamespace as Record<string, unknown>[] | TableLike;
|
||||
|
||||
// Detect if third argument is namespace array or options object
|
||||
let namespace: string[] | undefined;
|
||||
let createOptions: Partial<CreateTableOptions> | undefined;
|
||||
|
||||
if (Array.isArray(namespaceOrOptions)) {
|
||||
// Third argument is namespace array
|
||||
namespace = namespaceOrOptions;
|
||||
createOptions = options;
|
||||
} else {
|
||||
// Third argument is options object (backwards compatibility)
|
||||
namespace = undefined;
|
||||
createOptions = namespaceOrOptions;
|
||||
}
|
||||
|
||||
return this._createTableImpl(name, data, namespace, createOptions);
|
||||
}
|
||||
|
||||
private async _createTableImpl(
|
||||
name: string,
|
||||
data: Data,
|
||||
namespace?: string[],
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table> {
|
||||
if (data === undefined) {
|
||||
throw new Error("data is required");
|
||||
}
|
||||
@@ -302,9 +404,10 @@ export class LocalConnection extends Connection {
|
||||
const storageOptions = this.getStorageOptions(options);
|
||||
|
||||
const innerTable = await this.inner.createTable(
|
||||
nameOrOptions,
|
||||
name,
|
||||
buf,
|
||||
mode,
|
||||
namespace ?? [],
|
||||
storageOptions,
|
||||
);
|
||||
|
||||
@@ -314,39 +417,55 @@ export class LocalConnection extends Connection {
|
||||
async createEmptyTable(
|
||||
name: string,
|
||||
schema: import("./arrow").SchemaLike,
|
||||
namespaceOrOptions?: string[] | Partial<CreateTableOptions>,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table> {
|
||||
let mode: string = options?.mode ?? "create";
|
||||
const existOk = options?.existOk ?? false;
|
||||
// Detect if third argument is namespace array or options object
|
||||
let namespace: string[] | undefined;
|
||||
let createOptions: Partial<CreateTableOptions> | undefined;
|
||||
|
||||
if (Array.isArray(namespaceOrOptions)) {
|
||||
// Third argument is namespace array
|
||||
namespace = namespaceOrOptions;
|
||||
createOptions = options;
|
||||
} else {
|
||||
// Third argument is options object (backwards compatibility)
|
||||
namespace = undefined;
|
||||
createOptions = namespaceOrOptions;
|
||||
}
|
||||
|
||||
let mode: string = createOptions?.mode ?? "create";
|
||||
const existOk = createOptions?.existOk ?? false;
|
||||
|
||||
if (mode === "create" && existOk) {
|
||||
mode = "exist_ok";
|
||||
}
|
||||
let metadata: Map<string, string> | undefined = undefined;
|
||||
if (options?.embeddingFunction !== undefined) {
|
||||
const embeddingFunction = options.embeddingFunction;
|
||||
if (createOptions?.embeddingFunction !== undefined) {
|
||||
const embeddingFunction = createOptions.embeddingFunction;
|
||||
const registry = getRegistry();
|
||||
metadata = registry.getTableMetadata([embeddingFunction]);
|
||||
}
|
||||
|
||||
const storageOptions = this.getStorageOptions(options);
|
||||
const storageOptions = this.getStorageOptions(createOptions);
|
||||
const table = makeEmptyTable(schema, metadata);
|
||||
const buf = await fromTableToBuffer(table);
|
||||
const innerTable = await this.inner.createEmptyTable(
|
||||
name,
|
||||
buf,
|
||||
mode,
|
||||
namespace ?? [],
|
||||
storageOptions,
|
||||
);
|
||||
return new LocalTable(innerTable);
|
||||
}
|
||||
|
||||
async dropTable(name: string): Promise<void> {
|
||||
return this.inner.dropTable(name);
|
||||
async dropTable(name: string, namespace?: string[]): Promise<void> {
|
||||
return this.inner.dropTable(name, namespace ?? []);
|
||||
}
|
||||
|
||||
async dropAllTables(): Promise<void> {
|
||||
return this.inner.dropAllTables();
|
||||
async dropAllTables(namespace?: string[]): Promise<void> {
|
||||
return this.inner.dropAllTables(namespace ?? []);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
253
nodejs/lancedb/header.ts
Normal file
253
nodejs/lancedb/header.ts
Normal file
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
/**
|
||||
* Header providers for LanceDB remote connections.
|
||||
*
|
||||
* This module provides a flexible header management framework for LanceDB remote
|
||||
* connections, allowing users to implement custom header strategies for
|
||||
* authentication, request tracking, custom metadata, or any other header-based
|
||||
* requirements.
|
||||
*
|
||||
* @module header
|
||||
*/
|
||||
|
||||
/**
|
||||
* Abstract base class for providing custom headers for each request.
|
||||
*
|
||||
* Users can implement this interface to provide dynamic headers for various purposes
|
||||
* such as authentication (OAuth tokens, API keys), request tracking (correlation IDs),
|
||||
* custom metadata, or any other header-based requirements. The provider is called
|
||||
* before each request to ensure fresh header values are always used.
|
||||
*
|
||||
* @example
|
||||
* Simple JWT token provider:
|
||||
* ```typescript
|
||||
* class JWTProvider extends HeaderProvider {
|
||||
* constructor(private token: string) {
|
||||
* super();
|
||||
* }
|
||||
*
|
||||
* getHeaders(): Record<string, string> {
|
||||
* return { authorization: `Bearer ${this.token}` };
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* @example
|
||||
* Provider with request tracking:
|
||||
* ```typescript
|
||||
* class RequestTrackingProvider extends HeaderProvider {
|
||||
* constructor(private sessionId: string) {
|
||||
* super();
|
||||
* }
|
||||
*
|
||||
* getHeaders(): Record<string, string> {
|
||||
* return {
|
||||
* "X-Session-Id": this.sessionId,
|
||||
* "X-Request-Id": `req-${Date.now()}`
|
||||
* };
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export abstract class HeaderProvider {
|
||||
/**
|
||||
* Get the latest headers to be added to requests.
|
||||
*
|
||||
* This method is called before each request to the remote LanceDB server.
|
||||
* Implementations should return headers that will be merged with existing headers.
|
||||
*
|
||||
* @returns Dictionary of header names to values to add to the request.
|
||||
* @throws If unable to fetch headers, the exception will be propagated and the request will fail.
|
||||
*/
|
||||
abstract getHeaders(): Record<string, string>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Example implementation: A simple header provider that returns static headers.
|
||||
*
|
||||
* This is an example implementation showing how to create a HeaderProvider
|
||||
* for cases where headers don't change during the session.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const provider = new StaticHeaderProvider({
|
||||
* authorization: "Bearer my-token",
|
||||
* "X-Custom-Header": "custom-value"
|
||||
* });
|
||||
* const headers = provider.getHeaders();
|
||||
* // Returns: {authorization: 'Bearer my-token', 'X-Custom-Header': 'custom-value'}
|
||||
* ```
|
||||
*/
|
||||
export class StaticHeaderProvider extends HeaderProvider {
|
||||
private _headers: Record<string, string>;
|
||||
|
||||
/**
|
||||
* Initialize with static headers.
|
||||
* @param headers - Headers to return for every request.
|
||||
*/
|
||||
constructor(headers: Record<string, string>) {
|
||||
super();
|
||||
this._headers = { ...headers };
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the static headers.
|
||||
* @returns Copy of the static headers.
|
||||
*/
|
||||
getHeaders(): Record<string, string> {
|
||||
return { ...this._headers };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Token response from OAuth provider.
|
||||
* @public
|
||||
*/
|
||||
export interface TokenResponse {
|
||||
accessToken: string;
|
||||
expiresIn?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Example implementation: OAuth token provider with automatic refresh.
|
||||
*
|
||||
* This is an example implementation showing how to manage OAuth tokens
|
||||
* with automatic refresh when they expire.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* async function fetchToken(): Promise<TokenResponse> {
|
||||
* const response = await fetch("https://oauth.example.com/token", {
|
||||
* method: "POST",
|
||||
* body: JSON.stringify({
|
||||
* grant_type: "client_credentials",
|
||||
* client_id: "your-client-id",
|
||||
* client_secret: "your-client-secret"
|
||||
* }),
|
||||
* headers: { "Content-Type": "application/json" }
|
||||
* });
|
||||
* const data = await response.json();
|
||||
* return {
|
||||
* accessToken: data.access_token,
|
||||
* expiresIn: data.expires_in
|
||||
* };
|
||||
* }
|
||||
*
|
||||
* const provider = new OAuthHeaderProvider(fetchToken);
|
||||
* const headers = provider.getHeaders();
|
||||
* // Returns: {"authorization": "Bearer <your-token>"}
|
||||
* ```
|
||||
*/
|
||||
export class OAuthHeaderProvider extends HeaderProvider {
|
||||
private _tokenFetcher: () => Promise<TokenResponse> | TokenResponse;
|
||||
private _refreshBufferSeconds: number;
|
||||
private _currentToken: string | null = null;
|
||||
private _tokenExpiresAt: number | null = null;
|
||||
private _refreshPromise: Promise<void> | null = null;
|
||||
|
||||
/**
|
||||
* Initialize the OAuth provider.
|
||||
* @param tokenFetcher - Function to fetch new tokens. Should return object with 'accessToken' and optionally 'expiresIn'.
|
||||
* @param refreshBufferSeconds - Seconds before expiry to refresh token. Default 300 (5 minutes).
|
||||
*/
|
||||
constructor(
|
||||
tokenFetcher: () => Promise<TokenResponse> | TokenResponse,
|
||||
refreshBufferSeconds: number = 300,
|
||||
) {
|
||||
super();
|
||||
this._tokenFetcher = tokenFetcher;
|
||||
this._refreshBufferSeconds = refreshBufferSeconds;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if token needs refresh.
|
||||
*/
|
||||
private _needsRefresh(): boolean {
|
||||
if (this._currentToken === null) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (this._tokenExpiresAt === null) {
|
||||
// No expiration info, assume token is valid
|
||||
return false;
|
||||
}
|
||||
|
||||
// Refresh if we're within the buffer time of expiration
|
||||
const now = Date.now() / 1000;
|
||||
return now >= this._tokenExpiresAt - this._refreshBufferSeconds;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh the token if it's expired or close to expiring.
|
||||
*/
|
||||
private async _refreshTokenIfNeeded(): Promise<void> {
|
||||
if (!this._needsRefresh()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If refresh is already in progress, wait for it
|
||||
if (this._refreshPromise) {
|
||||
await this._refreshPromise;
|
||||
return;
|
||||
}
|
||||
|
||||
// Start refresh
|
||||
this._refreshPromise = (async () => {
|
||||
try {
|
||||
const tokenData = await this._tokenFetcher();
|
||||
|
||||
this._currentToken = tokenData.accessToken;
|
||||
if (!this._currentToken) {
|
||||
throw new Error("Token fetcher did not return 'accessToken'");
|
||||
}
|
||||
|
||||
// Set expiration if provided
|
||||
if (tokenData.expiresIn) {
|
||||
this._tokenExpiresAt = Date.now() / 1000 + tokenData.expiresIn;
|
||||
} else {
|
||||
// Token doesn't expire or expiration unknown
|
||||
this._tokenExpiresAt = null;
|
||||
}
|
||||
} finally {
|
||||
this._refreshPromise = null;
|
||||
}
|
||||
})();
|
||||
|
||||
await this._refreshPromise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get OAuth headers, refreshing token if needed.
|
||||
* Note: This is synchronous for now as the Rust implementation expects sync.
|
||||
* In a real implementation, this would need to handle async properly.
|
||||
* @returns Headers with Bearer token authorization.
|
||||
* @throws If unable to fetch or refresh token.
|
||||
*/
|
||||
getHeaders(): Record<string, string> {
|
||||
// For simplicity in this example, we assume the token is already fetched
|
||||
// In a real implementation, this would need to handle the async nature properly
|
||||
if (!this._currentToken && !this._refreshPromise) {
|
||||
// Synchronously trigger refresh - this is a limitation of the current implementation
|
||||
throw new Error(
|
||||
"Token not initialized. Call refreshToken() first or use async initialization.",
|
||||
);
|
||||
}
|
||||
|
||||
if (!this._currentToken) {
|
||||
throw new Error("Failed to obtain OAuth token");
|
||||
}
|
||||
|
||||
return { authorization: `Bearer ${this._currentToken}` };
|
||||
}
|
||||
|
||||
/**
|
||||
* Manually refresh the token.
|
||||
* Call this before using getHeaders() to ensure token is available.
|
||||
*/
|
||||
async refreshToken(): Promise<void> {
|
||||
this._currentToken = null; // Force refresh
|
||||
await this._refreshTokenIfNeeded();
|
||||
}
|
||||
}
|
||||
@@ -10,9 +10,15 @@ import {
|
||||
import {
|
||||
ConnectionOptions,
|
||||
Connection as LanceDbConnection,
|
||||
JsHeaderProvider as NativeJsHeaderProvider,
|
||||
Session,
|
||||
} from "./native.js";
|
||||
|
||||
import { HeaderProvider } from "./header";
|
||||
|
||||
// Re-export native header provider for use with connectWithHeaderProvider
|
||||
export { JsHeaderProvider as NativeJsHeaderProvider } from "./native.js";
|
||||
|
||||
export {
|
||||
AddColumnsSql,
|
||||
ConnectionOptions,
|
||||
@@ -21,6 +27,7 @@ export {
|
||||
ClientConfig,
|
||||
TimeoutConfig,
|
||||
RetryConfig,
|
||||
TlsConfig,
|
||||
OptimizeStats,
|
||||
CompactionStats,
|
||||
RemovalStats,
|
||||
@@ -93,6 +100,13 @@ export {
|
||||
ColumnAlteration,
|
||||
} from "./table";
|
||||
|
||||
export {
|
||||
HeaderProvider,
|
||||
StaticHeaderProvider,
|
||||
OAuthHeaderProvider,
|
||||
TokenResponse,
|
||||
} from "./header";
|
||||
|
||||
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
|
||||
|
||||
export * as embedding from "./embedding";
|
||||
@@ -131,11 +145,27 @@ export { IntoSql, packBits } from "./util";
|
||||
* {storageOptions: {timeout: "60s"}
|
||||
* });
|
||||
* ```
|
||||
* @example
|
||||
* Using with a header provider for per-request authentication:
|
||||
* ```ts
|
||||
* const provider = new StaticHeaderProvider({
|
||||
* "X-API-Key": "my-key"
|
||||
* });
|
||||
* const conn = await connectWithHeaderProvider(
|
||||
* "db://host:port",
|
||||
* options,
|
||||
* provider
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
export async function connect(
|
||||
uri: string,
|
||||
options?: Partial<ConnectionOptions>,
|
||||
session?: Session,
|
||||
headerProvider?:
|
||||
| HeaderProvider
|
||||
| (() => Record<string, string>)
|
||||
| (() => Promise<Record<string, string>>),
|
||||
): Promise<Connection>;
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI.
|
||||
@@ -169,18 +199,58 @@ export async function connect(
|
||||
): Promise<Connection>;
|
||||
export async function connect(
|
||||
uriOrOptions: string | (Partial<ConnectionOptions> & { uri: string }),
|
||||
options?: Partial<ConnectionOptions>,
|
||||
optionsOrSession?: Partial<ConnectionOptions> | Session,
|
||||
sessionOrHeaderProvider?:
|
||||
| Session
|
||||
| HeaderProvider
|
||||
| (() => Record<string, string>)
|
||||
| (() => Promise<Record<string, string>>),
|
||||
headerProvider?:
|
||||
| HeaderProvider
|
||||
| (() => Record<string, string>)
|
||||
| (() => Promise<Record<string, string>>),
|
||||
): Promise<Connection> {
|
||||
let uri: string | undefined;
|
||||
let finalOptions: Partial<ConnectionOptions> = {};
|
||||
let finalHeaderProvider:
|
||||
| HeaderProvider
|
||||
| (() => Record<string, string>)
|
||||
| (() => Promise<Record<string, string>>)
|
||||
| undefined;
|
||||
|
||||
if (typeof uriOrOptions !== "string") {
|
||||
// First overload: connect(options)
|
||||
const { uri: uri_, ...opts } = uriOrOptions;
|
||||
uri = uri_;
|
||||
finalOptions = opts;
|
||||
} else {
|
||||
// Second overload: connect(uri, options?, session?, headerProvider?)
|
||||
uri = uriOrOptions;
|
||||
finalOptions = options || {};
|
||||
|
||||
// Handle optionsOrSession parameter
|
||||
if (optionsOrSession && "inner" in optionsOrSession) {
|
||||
// Second param is session, so no options provided
|
||||
finalOptions = {};
|
||||
} else {
|
||||
// Second param is options
|
||||
finalOptions = (optionsOrSession as Partial<ConnectionOptions>) || {};
|
||||
}
|
||||
|
||||
// Handle sessionOrHeaderProvider parameter
|
||||
if (
|
||||
sessionOrHeaderProvider &&
|
||||
(typeof sessionOrHeaderProvider === "function" ||
|
||||
"getHeaders" in sessionOrHeaderProvider)
|
||||
) {
|
||||
// Third param is header provider
|
||||
finalHeaderProvider = sessionOrHeaderProvider as
|
||||
| HeaderProvider
|
||||
| (() => Record<string, string>)
|
||||
| (() => Promise<Record<string, string>>);
|
||||
} else {
|
||||
// Third param is session, header provider is fourth param
|
||||
finalHeaderProvider = headerProvider;
|
||||
}
|
||||
}
|
||||
|
||||
if (!uri) {
|
||||
@@ -191,6 +261,26 @@ export async function connect(
|
||||
(<ConnectionOptions>finalOptions).storageOptions = cleanseStorageOptions(
|
||||
(<ConnectionOptions>finalOptions).storageOptions,
|
||||
);
|
||||
const nativeConn = await LanceDbConnection.new(uri, finalOptions);
|
||||
|
||||
// Create native header provider if one was provided
|
||||
let nativeProvider: NativeJsHeaderProvider | undefined;
|
||||
if (finalHeaderProvider) {
|
||||
if (typeof finalHeaderProvider === "function") {
|
||||
nativeProvider = new NativeJsHeaderProvider(finalHeaderProvider);
|
||||
} else if (
|
||||
finalHeaderProvider &&
|
||||
typeof finalHeaderProvider.getHeaders === "function"
|
||||
) {
|
||||
nativeProvider = new NativeJsHeaderProvider(async () =>
|
||||
finalHeaderProvider.getHeaders(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const nativeConn = await LanceDbConnection.new(
|
||||
uri,
|
||||
finalOptions,
|
||||
nativeProvider,
|
||||
);
|
||||
return new LocalConnection(nativeConn);
|
||||
}
|
||||
|
||||
@@ -700,5 +700,27 @@ export interface IndexOptions {
|
||||
*/
|
||||
replace?: boolean;
|
||||
|
||||
/**
|
||||
* Timeout in seconds to wait for index creation to complete.
|
||||
*
|
||||
* If not specified, the method will return immediately after starting the index creation.
|
||||
*/
|
||||
waitTimeoutSeconds?: number;
|
||||
|
||||
/**
|
||||
* Optional custom name for the index.
|
||||
*
|
||||
* If not provided, a default name will be generated based on the column name.
|
||||
*/
|
||||
name?: string;
|
||||
|
||||
/**
|
||||
* Whether to train the index with existing data.
|
||||
*
|
||||
* If true (default), the index will be trained with existing data in the table.
|
||||
* If false, the index will be created empty and populated as new data is added.
|
||||
*
|
||||
* Note: This option is only supported for scalar indices. Vector indices always train.
|
||||
*/
|
||||
train?: boolean;
|
||||
}
|
||||
|
||||
@@ -662,6 +662,8 @@ export class LocalTable extends Table {
|
||||
column,
|
||||
options?.replace,
|
||||
options?.waitTimeoutSeconds,
|
||||
options?.name,
|
||||
options?.train,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-musl",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-musl",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-musl.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-arm64-msvc",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
228
nodejs/package-lock.json
generated
228
nodejs/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@lancedb/lancedb",
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -5549,10 +5549,11 @@
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "1.1.11",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz",
|
||||
"integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==",
|
||||
"version": "1.1.12",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz",
|
||||
"integrity": "sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^1.0.0",
|
||||
"concat-map": "0.0.1"
|
||||
@@ -5629,6 +5630,20 @@
|
||||
"integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/call-bind-apply-helpers": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz",
|
||||
"integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"es-errors": "^1.3.0",
|
||||
"function-bind": "^1.1.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/camelcase": {
|
||||
"version": "5.3.1",
|
||||
"resolved": "https://registry.npmjs.org/camelcase/-/camelcase-5.3.1.tgz",
|
||||
@@ -6032,6 +6047,21 @@
|
||||
"node": ">=6.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/dunder-proto": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz",
|
||||
"integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"call-bind-apply-helpers": "^1.0.1",
|
||||
"es-errors": "^1.3.0",
|
||||
"gopd": "^1.2.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/eastasianwidth": {
|
||||
"version": "0.2.0",
|
||||
"resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz",
|
||||
@@ -6071,6 +6101,55 @@
|
||||
"is-arrayish": "^0.2.1"
|
||||
}
|
||||
},
|
||||
"node_modules/es-define-property": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz",
|
||||
"integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/es-errors": {
|
||||
"version": "1.3.0",
|
||||
"resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz",
|
||||
"integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/es-object-atoms": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz",
|
||||
"integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"es-errors": "^1.3.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/es-set-tostringtag": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz",
|
||||
"integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"es-errors": "^1.3.0",
|
||||
"get-intrinsic": "^1.2.6",
|
||||
"has-tostringtag": "^1.0.2",
|
||||
"hasown": "^2.0.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/escalade": {
|
||||
"version": "3.1.1",
|
||||
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz",
|
||||
@@ -6510,13 +6589,16 @@
|
||||
}
|
||||
},
|
||||
"node_modules/form-data": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz",
|
||||
"integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz",
|
||||
"integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"asynckit": "^0.4.0",
|
||||
"combined-stream": "^1.0.8",
|
||||
"es-set-tostringtag": "^2.1.0",
|
||||
"hasown": "^2.0.2",
|
||||
"mime-types": "^2.1.12"
|
||||
},
|
||||
"engines": {
|
||||
@@ -6575,7 +6657,7 @@
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz",
|
||||
"integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==",
|
||||
"dev": true,
|
||||
"devOptional": true,
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
@@ -6598,6 +6680,31 @@
|
||||
"node": "6.* || 8.* || >= 10.*"
|
||||
}
|
||||
},
|
||||
"node_modules/get-intrinsic": {
|
||||
"version": "1.3.0",
|
||||
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz",
|
||||
"integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"call-bind-apply-helpers": "^1.0.2",
|
||||
"es-define-property": "^1.0.1",
|
||||
"es-errors": "^1.3.0",
|
||||
"es-object-atoms": "^1.1.1",
|
||||
"function-bind": "^1.1.2",
|
||||
"get-proto": "^1.0.1",
|
||||
"gopd": "^1.2.0",
|
||||
"has-symbols": "^1.1.0",
|
||||
"hasown": "^2.0.2",
|
||||
"math-intrinsics": "^1.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/get-package-type": {
|
||||
"version": "0.1.0",
|
||||
"resolved": "https://registry.npmjs.org/get-package-type/-/get-package-type-0.1.0.tgz",
|
||||
@@ -6607,6 +6714,20 @@
|
||||
"node": ">=8.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/get-proto": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz",
|
||||
"integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"dunder-proto": "^1.0.1",
|
||||
"es-object-atoms": "^1.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/get-stream": {
|
||||
"version": "6.0.1",
|
||||
"resolved": "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz",
|
||||
@@ -6698,6 +6819,19 @@
|
||||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/gopd": {
|
||||
"version": "1.2.0",
|
||||
"resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz",
|
||||
"integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/graceful-fs": {
|
||||
"version": "4.2.11",
|
||||
"resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz",
|
||||
@@ -6724,11 +6858,41 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/has-symbols": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
|
||||
"integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/has-tostringtag": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz",
|
||||
"integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"has-symbols": "^1.0.3"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/hasown": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.0.tgz",
|
||||
"integrity": "sha512-vUptKVTpIJhcczKBbgnS+RtcuYMB8+oNzPK2/Hp3hanz8JmpATdmmgLgSaadVREkDm+e2giHwY3ZRkyjSIDDFA==",
|
||||
"dev": true,
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz",
|
||||
"integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"function-bind": "^1.1.2"
|
||||
},
|
||||
@@ -7943,6 +8107,16 @@
|
||||
"integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/math-intrinsics": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz",
|
||||
"integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==",
|
||||
"devOptional": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/md5": {
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/md5/-/md5-2.3.0.tgz",
|
||||
@@ -8053,9 +8227,10 @@
|
||||
}
|
||||
},
|
||||
"node_modules/minizlib/node_modules/brace-expansion": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz",
|
||||
"integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==",
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"balanced-match": "^1.0.0"
|
||||
@@ -9201,10 +9376,11 @@
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/tmp": {
|
||||
"version": "0.2.3",
|
||||
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.3.tgz",
|
||||
"integrity": "sha512-nZD7m9iCPC5g0pYmcaxogYKggSfLsdxl8of3Q/oIbqCqLLIO9IAF0GWjX1z9NZRHPiXv8Wex4yDCaZsgEw0Y8w==",
|
||||
"version": "0.2.5",
|
||||
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.5.tgz",
|
||||
"integrity": "sha512-voyz6MApa1rQGUxT3E+BK7/ROe8itEx7vD8/HEvt4xwXucvQ5G5oeEiHkmHZJuBO21RpOf+YYm9MOivj709jow==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=14.14"
|
||||
}
|
||||
@@ -9349,10 +9525,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/typedoc/node_modules/brace-expansion": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz",
|
||||
"integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==",
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^1.0.0"
|
||||
}
|
||||
@@ -9602,10 +9779,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/typescript-eslint/node_modules/brace-expansion": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz",
|
||||
"integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==",
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz",
|
||||
"integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^1.0.0"
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
"ann"
|
||||
],
|
||||
"private": false,
|
||||
"version": "0.21.2",
|
||||
"version": "0.22.1-beta.0",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use lancedb::database::CreateTableMode;
|
||||
use napi::bindgen_prelude::*;
|
||||
use napi_derive::*;
|
||||
|
||||
use crate::error::NapiErrorExt;
|
||||
use crate::header::JsHeaderProvider;
|
||||
use crate::table::Table;
|
||||
use crate::ConnectionOptions;
|
||||
use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection};
|
||||
@@ -45,7 +47,11 @@ impl Connection {
|
||||
impl Connection {
|
||||
/// Create a new Connection instance from the given URI.
|
||||
#[napi(factory)]
|
||||
pub async fn new(uri: String, options: ConnectionOptions) -> napi::Result<Self> {
|
||||
pub async fn new(
|
||||
uri: String,
|
||||
options: ConnectionOptions,
|
||||
header_provider: Option<&JsHeaderProvider>,
|
||||
) -> napi::Result<Self> {
|
||||
let mut builder = ConnectBuilder::new(&uri);
|
||||
if let Some(interval) = options.read_consistency_interval {
|
||||
builder =
|
||||
@@ -57,8 +63,16 @@ impl Connection {
|
||||
}
|
||||
}
|
||||
|
||||
// Create client config, optionally with header provider
|
||||
let client_config = options.client_config.unwrap_or_default();
|
||||
builder = builder.client_config(client_config.into());
|
||||
let mut rust_config: lancedb::remote::ClientConfig = client_config.into();
|
||||
|
||||
if let Some(provider) = header_provider {
|
||||
rust_config.header_provider =
|
||||
Some(Arc::new(provider.clone()) as Arc<dyn lancedb::remote::HeaderProvider>);
|
||||
}
|
||||
|
||||
builder = builder.client_config(rust_config);
|
||||
|
||||
if let Some(api_key) = options.api_key {
|
||||
builder = builder.api_key(&api_key);
|
||||
@@ -100,10 +114,12 @@ impl Connection {
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn table_names(
|
||||
&self,
|
||||
namespace: Vec<String>,
|
||||
start_after: Option<String>,
|
||||
limit: Option<u32>,
|
||||
) -> napi::Result<Vec<String>> {
|
||||
let mut op = self.get_inner()?.table_names();
|
||||
op = op.namespace(namespace);
|
||||
if let Some(start_after) = start_after {
|
||||
op = op.start_after(start_after);
|
||||
}
|
||||
@@ -125,6 +141,7 @@ impl Connection {
|
||||
name: String,
|
||||
buf: Buffer,
|
||||
mode: String,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> napi::Result<Table> {
|
||||
let batches = ipc_file_to_batches(buf.to_vec())
|
||||
@@ -132,6 +149,8 @@ impl Connection {
|
||||
let mode = Self::parse_create_mode_str(&mode)?;
|
||||
let mut builder = self.get_inner()?.create_table(&name, batches).mode(mode);
|
||||
|
||||
builder = builder.namespace(namespace);
|
||||
|
||||
if let Some(storage_options) = storage_options {
|
||||
for (key, value) in storage_options {
|
||||
builder = builder.storage_option(key, value);
|
||||
@@ -147,6 +166,7 @@ impl Connection {
|
||||
name: String,
|
||||
schema_buf: Buffer,
|
||||
mode: String,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> napi::Result<Table> {
|
||||
let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
|
||||
@@ -157,6 +177,9 @@ impl Connection {
|
||||
.get_inner()?
|
||||
.create_empty_table(&name, schema)
|
||||
.mode(mode);
|
||||
|
||||
builder = builder.namespace(namespace);
|
||||
|
||||
if let Some(storage_options) = storage_options {
|
||||
for (key, value) in storage_options {
|
||||
builder = builder.storage_option(key, value);
|
||||
@@ -170,10 +193,14 @@ impl Connection {
|
||||
pub async fn open_table(
|
||||
&self,
|
||||
name: String,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
index_cache_size: Option<u32>,
|
||||
) -> napi::Result<Table> {
|
||||
let mut builder = self.get_inner()?.open_table(&name);
|
||||
|
||||
builder = builder.namespace(namespace);
|
||||
|
||||
if let Some(storage_options) = storage_options {
|
||||
for (key, value) in storage_options {
|
||||
builder = builder.storage_option(key, value);
|
||||
@@ -188,12 +215,18 @@ impl Connection {
|
||||
|
||||
/// Drop table with the name. Or raise an error if the table does not exist.
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn drop_table(&self, name: String) -> napi::Result<()> {
|
||||
self.get_inner()?.drop_table(&name).await.default_error()
|
||||
pub async fn drop_table(&self, name: String, namespace: Vec<String>) -> napi::Result<()> {
|
||||
self.get_inner()?
|
||||
.drop_table(&name, &namespace)
|
||||
.await
|
||||
.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn drop_all_tables(&self) -> napi::Result<()> {
|
||||
self.get_inner()?.drop_all_tables().await.default_error()
|
||||
pub async fn drop_all_tables(&self, namespace: Vec<String>) -> napi::Result<()> {
|
||||
self.get_inner()?
|
||||
.drop_all_tables(&namespace)
|
||||
.await
|
||||
.default_error()
|
||||
}
|
||||
}
|
||||
|
||||
71
nodejs/src/header.rs
Normal file
71
nodejs/src/header.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use napi::{
|
||||
bindgen_prelude::*,
|
||||
threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
|
||||
};
|
||||
use napi_derive::napi;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// JavaScript HeaderProvider implementation that wraps a JavaScript callback.
|
||||
/// This is the only native header provider - all header provider implementations
|
||||
/// should provide a JavaScript function that returns headers.
|
||||
#[napi]
|
||||
pub struct JsHeaderProvider {
|
||||
get_headers_fn: Arc<ThreadsafeFunction<(), ErrorStrategy::CalleeHandled>>,
|
||||
}
|
||||
|
||||
impl Clone for JsHeaderProvider {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
get_headers_fn: self.get_headers_fn.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl JsHeaderProvider {
|
||||
/// Create a new JsHeaderProvider from a JavaScript callback
|
||||
#[napi(constructor)]
|
||||
pub fn new(get_headers_callback: JsFunction) -> Result<Self> {
|
||||
let get_headers_fn = get_headers_callback
|
||||
.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))
|
||||
.map_err(|e| {
|
||||
Error::new(
|
||||
Status::GenericFailure,
|
||||
format!("Failed to create threadsafe function: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
get_headers_fn: Arc::new(get_headers_fn),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
#[async_trait::async_trait]
|
||||
impl lancedb::remote::HeaderProvider for JsHeaderProvider {
|
||||
async fn get_headers(&self) -> lancedb::error::Result<HashMap<String, String>> {
|
||||
// Call the JavaScript function asynchronously
|
||||
let promise: Promise<HashMap<String, String>> =
|
||||
self.get_headers_fn.call_async(Ok(())).await.map_err(|e| {
|
||||
lancedb::error::Error::Runtime {
|
||||
message: format!("Failed to call JavaScript get_headers: {}", e),
|
||||
}
|
||||
})?;
|
||||
|
||||
// Await the promise result
|
||||
promise.await.map_err(|e| lancedb::error::Error::Runtime {
|
||||
message: format!("JavaScript get_headers failed: {}", e),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for JsHeaderProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "JsHeaderProvider")
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ use napi_derive::*;
|
||||
|
||||
mod connection;
|
||||
mod error;
|
||||
mod header;
|
||||
mod index;
|
||||
mod iterator;
|
||||
pub mod merge;
|
||||
|
||||
@@ -480,6 +480,7 @@ impl JsFullTextQuery {
|
||||
}
|
||||
|
||||
#[napi(factory)]
|
||||
#[allow(clippy::use_self)] // NAPI doesn't allow Self here but clippy reports it
|
||||
pub fn boolean_query(queries: Vec<(String, &JsFullTextQuery)>) -> napi::Result<Self> {
|
||||
let mut sub_queries = Vec::with_capacity(queries.len());
|
||||
for (occur, q) in queries {
|
||||
|
||||
@@ -69,6 +69,20 @@ pub struct RetryConfig {
|
||||
pub statuses: Option<Vec<u16>>,
|
||||
}
|
||||
|
||||
/// TLS/mTLS configuration for the remote HTTP client.
|
||||
#[napi(object)]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct TlsConfig {
|
||||
/// Path to the client certificate file (PEM format) for mTLS authentication.
|
||||
pub cert_file: Option<String>,
|
||||
/// Path to the client private key file (PEM format) for mTLS authentication.
|
||||
pub key_file: Option<String>,
|
||||
/// Path to the CA certificate file (PEM format) for server verification.
|
||||
pub ssl_ca_cert: Option<String>,
|
||||
/// Whether to verify the hostname in the server's certificate.
|
||||
pub assert_hostname: Option<bool>,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ClientConfig {
|
||||
@@ -76,6 +90,8 @@ pub struct ClientConfig {
|
||||
pub retry_config: Option<RetryConfig>,
|
||||
pub timeout_config: Option<TimeoutConfig>,
|
||||
pub extra_headers: Option<HashMap<String, String>>,
|
||||
pub id_delimiter: Option<String>,
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
}
|
||||
|
||||
impl From<TimeoutConfig> for lancedb::remote::TimeoutConfig {
|
||||
@@ -106,6 +122,17 @@ impl From<RetryConfig> for lancedb::remote::RetryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TlsConfig> for lancedb::remote::TlsConfig {
|
||||
fn from(config: TlsConfig) -> Self {
|
||||
Self {
|
||||
cert_file: config.cert_file,
|
||||
key_file: config.key_file,
|
||||
ssl_ca_cert: config.ssl_ca_cert,
|
||||
assert_hostname: config.assert_hostname.unwrap_or(true),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
fn from(config: ClientConfig) -> Self {
|
||||
Self {
|
||||
@@ -115,6 +142,9 @@ impl From<ClientConfig> for lancedb::remote::ClientConfig {
|
||||
retry_config: config.retry_config.map(Into::into).unwrap_or_default(),
|
||||
timeout_config: config.timeout_config.map(Into::into).unwrap_or_default(),
|
||||
extra_headers: config.extra_headers.unwrap_or_default(),
|
||||
id_delimiter: config.id_delimiter,
|
||||
tls_config: config.tls_config.map(Into::into),
|
||||
header_provider: None, // the header provider is set separately later
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,7 +94,7 @@ impl napi::bindgen_prelude::FromNapiValue for Session {
|
||||
env: napi::sys::napi_env,
|
||||
napi_val: napi::sys::napi_value,
|
||||
) -> napi::Result<Self> {
|
||||
let object: napi::bindgen_prelude::ClassInstance<Session> =
|
||||
let object: napi::bindgen_prelude::ClassInstance<Self> =
|
||||
napi::bindgen_prelude::ClassInstance::from_napi_value(env, napi_val)?;
|
||||
let copy = object.clone();
|
||||
Ok(copy)
|
||||
|
||||
@@ -114,6 +114,8 @@ impl Table {
|
||||
column: String,
|
||||
replace: Option<bool>,
|
||||
wait_timeout_s: Option<i64>,
|
||||
name: Option<String>,
|
||||
train: Option<bool>,
|
||||
) -> napi::Result<()> {
|
||||
let lancedb_index = if let Some(index) = index {
|
||||
index.consume()?
|
||||
@@ -128,6 +130,12 @@ impl Table {
|
||||
builder =
|
||||
builder.wait_timeout(std::time::Duration::from_secs(timeout.try_into().unwrap()));
|
||||
}
|
||||
if let Some(name) = name {
|
||||
builder = builder.name(name);
|
||||
}
|
||||
if let Some(train) = train {
|
||||
builder = builder.train(train);
|
||||
}
|
||||
builder.execute().await.default_error()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.24.3"
|
||||
current_version = "0.25.1-beta.1"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.24.3"
|
||||
version = "0.25.1-beta.1"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
@@ -15,6 +15,7 @@ crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
arrow = { version = "55.1", features = ["pyarrow"] }
|
||||
async-trait = "0.1"
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
env_logger.workspace = true
|
||||
pyo3 = { version = "0.24", features = ["extension-module", "abi3-py39"] }
|
||||
@@ -33,6 +34,6 @@ pyo3-build-config = { version = "0.24", features = [
|
||||
] }
|
||||
|
||||
[features]
|
||||
default = ["remote"]
|
||||
default = ["remote", "lancedb/default"]
|
||||
fp16kernels = ["lancedb/fp16kernels"]
|
||||
remote = ["lancedb/remote"]
|
||||
|
||||
@@ -10,6 +10,7 @@ dependencies = [
|
||||
"pyarrow>=16",
|
||||
"pydantic>=1.10",
|
||||
"tqdm>=4.27.0",
|
||||
"lance-namespace==0.0.6"
|
||||
]
|
||||
description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
|
||||
@@ -19,6 +19,7 @@ from .remote.db import RemoteDBConnection
|
||||
from .schema import vector
|
||||
from .table import AsyncTable
|
||||
from ._lancedb import Session
|
||||
from .namespace import connect_namespace, LanceNamespaceDBConnection
|
||||
|
||||
|
||||
def connect(
|
||||
@@ -221,6 +222,7 @@ async def connect_async(
|
||||
__all__ = [
|
||||
"connect",
|
||||
"connect_async",
|
||||
"connect_namespace",
|
||||
"AsyncConnection",
|
||||
"AsyncTable",
|
||||
"URI",
|
||||
@@ -228,6 +230,7 @@ __all__ = [
|
||||
"vector",
|
||||
"DBConnection",
|
||||
"LanceDBConnection",
|
||||
"LanceNamespaceDBConnection",
|
||||
"RemoteDBConnection",
|
||||
"Session",
|
||||
"__version__",
|
||||
|
||||
@@ -21,14 +21,28 @@ class Session:
|
||||
|
||||
class Connection(object):
|
||||
uri: str
|
||||
async def is_open(self): ...
|
||||
async def close(self): ...
|
||||
async def list_namespaces(
|
||||
self,
|
||||
namespace: List[str],
|
||||
page_token: Optional[str],
|
||||
limit: Optional[int],
|
||||
) -> List[str]: ...
|
||||
async def create_namespace(self, namespace: List[str]) -> None: ...
|
||||
async def drop_namespace(self, namespace: List[str]) -> None: ...
|
||||
async def table_names(
|
||||
self, start_after: Optional[str], limit: Optional[int]
|
||||
self,
|
||||
namespace: List[str],
|
||||
start_after: Optional[str],
|
||||
limit: Optional[int],
|
||||
) -> list[str]: ...
|
||||
async def create_table(
|
||||
self,
|
||||
name: str,
|
||||
mode: str,
|
||||
data: pa.RecordBatchReader,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
) -> Table: ...
|
||||
async def create_empty_table(
|
||||
@@ -36,10 +50,25 @@ class Connection(object):
|
||||
name: str,
|
||||
mode: str,
|
||||
schema: pa.Schema,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
) -> Table: ...
|
||||
async def rename_table(self, old_name: str, new_name: str) -> None: ...
|
||||
async def drop_table(self, name: str) -> None: ...
|
||||
async def open_table(
|
||||
self,
|
||||
name: str,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> Table: ...
|
||||
async def rename_table(
|
||||
self,
|
||||
cur_name: str,
|
||||
new_name: str,
|
||||
cur_namespace: List[str] = [],
|
||||
new_namespace: List[str] = [],
|
||||
) -> None: ...
|
||||
async def drop_table(self, name: str, namespace: List[str] = []) -> None: ...
|
||||
async def drop_all_tables(self, namespace: List[str] = []) -> None: ...
|
||||
|
||||
class Table:
|
||||
def name(self) -> str: ...
|
||||
@@ -59,6 +88,10 @@ class Table:
|
||||
column: str,
|
||||
index: Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS],
|
||||
replace: Optional[bool],
|
||||
wait_timeout: Optional[object],
|
||||
*,
|
||||
name: Optional[str],
|
||||
train: Optional[bool],
|
||||
): ...
|
||||
async def list_versions(self) -> List[Dict[str, Any]]: ...
|
||||
async def version(self) -> int: ...
|
||||
|
||||
@@ -43,14 +43,70 @@ if TYPE_CHECKING:
|
||||
class DBConnection(EnforceOverrides):
|
||||
"""An active LanceDB connection interface."""
|
||||
|
||||
def list_namespaces(
|
||||
self,
|
||||
namespace: List[str] = [],
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Iterable[str]:
|
||||
"""List immediate child namespace names in the given namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], default []
|
||||
The parent namespace to list namespaces in.
|
||||
Empty list represents root namespace.
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable of str
|
||||
List of immediate child namespace names
|
||||
"""
|
||||
return []
|
||||
|
||||
def create_namespace(self, namespace: List[str]) -> None:
|
||||
"""Create a new namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str]
|
||||
The namespace identifier to create.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Namespace operations are not supported for this connection type"
|
||||
)
|
||||
|
||||
def drop_namespace(self, namespace: List[str]) -> None:
|
||||
"""Drop a namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str]
|
||||
The namespace identifier to drop.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Namespace operations are not supported for this connection type"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
self,
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
) -> Iterable[str]:
|
||||
"""List all tables in this database, in sorted order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], default []
|
||||
The namespace to list tables in.
|
||||
Empty list represents root namespace.
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
Typically, this token is last table name from the previous page.
|
||||
@@ -77,6 +133,7 @@ class DBConnection(EnforceOverrides):
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
@@ -87,6 +144,9 @@ class DBConnection(EnforceOverrides):
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], default []
|
||||
The namespace to create the table in.
|
||||
Empty list represents root namespace.
|
||||
data: The data to initialize the table, *optional*
|
||||
User must provide at least one of `data` or `schema`.
|
||||
Acceptable types are:
|
||||
@@ -238,6 +298,7 @@ class DBConnection(EnforceOverrides):
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> Table:
|
||||
@@ -247,6 +308,9 @@ class DBConnection(EnforceOverrides):
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], optional
|
||||
The namespace to open the table from.
|
||||
None or empty list represents root namespace.
|
||||
index_cache_size: int, default 256
|
||||
**Deprecated**: Use session-level cache configuration instead.
|
||||
Create a Session with custom cache sizes and pass it to lancedb.connect().
|
||||
@@ -272,17 +336,26 @@ class DBConnection(EnforceOverrides):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def drop_table(self, name: str):
|
||||
def drop_table(self, name: str, namespace: List[str] = []):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], default []
|
||||
The namespace to drop the table from.
|
||||
Empty list represents root namespace.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rename_table(self, cur_name: str, new_name: str):
|
||||
def rename_table(
|
||||
self,
|
||||
cur_name: str,
|
||||
new_name: str,
|
||||
cur_namespace: List[str] = [],
|
||||
new_namespace: List[str] = [],
|
||||
):
|
||||
"""Rename a table in the database.
|
||||
|
||||
Parameters
|
||||
@@ -291,6 +364,12 @@ class DBConnection(EnforceOverrides):
|
||||
The current name of the table.
|
||||
new_name: str
|
||||
The new name of the table.
|
||||
cur_namespace: List[str], optional
|
||||
The namespace of the current table.
|
||||
None or empty list represents root namespace.
|
||||
new_namespace: List[str], optional
|
||||
The namespace to move the table to.
|
||||
If not specified, defaults to the same as cur_namespace.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -301,9 +380,15 @@ class DBConnection(EnforceOverrides):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def drop_all_tables(self):
|
||||
def drop_all_tables(self, namespace: List[str] = []):
|
||||
"""
|
||||
Drop all tables from the database
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], optional
|
||||
The namespace to drop all tables from.
|
||||
None or empty list represents root namespace.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -404,18 +489,87 @@ class LanceDBConnection(DBConnection):
|
||||
conn = AsyncConnection(await lancedb_connect(self.uri))
|
||||
return await conn.table_names(start_after=start_after, limit=limit)
|
||||
|
||||
@override
|
||||
def list_namespaces(
|
||||
self,
|
||||
namespace: List[str] = [],
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Iterable[str]:
|
||||
"""List immediate child namespace names in the given namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], optional
|
||||
The parent namespace to list namespaces in.
|
||||
None or empty list represents root namespace.
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable of str
|
||||
List of immediate child namespace names
|
||||
"""
|
||||
return LOOP.run(
|
||||
self._conn.list_namespaces(
|
||||
namespace=namespace, page_token=page_token, limit=limit
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def create_namespace(self, namespace: List[str]) -> None:
|
||||
"""Create a new namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str]
|
||||
The namespace identifier to create.
|
||||
"""
|
||||
LOOP.run(self._conn.create_namespace(namespace=namespace))
|
||||
|
||||
@override
|
||||
def drop_namespace(self, namespace: List[str]) -> None:
|
||||
"""Drop a namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str]
|
||||
The namespace identifier to drop.
|
||||
"""
|
||||
return LOOP.run(self._conn.drop_namespace(namespace=namespace))
|
||||
|
||||
@override
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
self,
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
) -> Iterable[str]:
|
||||
"""Get the names of all tables in the database. The names are sorted.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], optional
|
||||
The namespace to list tables in.
|
||||
page_token: str, optional
|
||||
The token to use for pagination.
|
||||
limit: int, default 10
|
||||
The maximum number of tables to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterator of str.
|
||||
A list of table names.
|
||||
"""
|
||||
return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit))
|
||||
return LOOP.run(
|
||||
self._conn.table_names(
|
||||
namespace=namespace, start_after=page_token, limit=limit
|
||||
)
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.table_names())
|
||||
@@ -435,12 +589,18 @@ class LanceDBConnection(DBConnection):
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], optional
|
||||
The namespace to create the table in.
|
||||
|
||||
See
|
||||
---
|
||||
DBConnection.create_table
|
||||
@@ -459,6 +619,7 @@ class LanceDBConnection(DBConnection):
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
embedding_functions=embedding_functions,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
return tbl
|
||||
@@ -468,6 +629,7 @@ class LanceDBConnection(DBConnection):
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> LanceTable:
|
||||
@@ -477,6 +639,8 @@ class LanceDBConnection(DBConnection):
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], optional
|
||||
The namespace to open the table from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -496,26 +660,68 @@ class LanceDBConnection(DBConnection):
|
||||
return LanceTable.open(
|
||||
self,
|
||||
name,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str, ignore_missing: bool = False):
|
||||
def drop_table(
|
||||
self,
|
||||
name: str,
|
||||
namespace: List[str] = [],
|
||||
ignore_missing: bool = False,
|
||||
):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], optional
|
||||
The namespace to drop the table from.
|
||||
ignore_missing: bool, default False
|
||||
If True, ignore if the table does not exist.
|
||||
"""
|
||||
LOOP.run(self._conn.drop_table(name, ignore_missing=ignore_missing))
|
||||
LOOP.run(
|
||||
self._conn.drop_table(
|
||||
name, namespace=namespace, ignore_missing=ignore_missing
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def drop_all_tables(self):
|
||||
LOOP.run(self._conn.drop_all_tables())
|
||||
def drop_all_tables(self, namespace: List[str] = []):
|
||||
LOOP.run(self._conn.drop_all_tables(namespace=namespace))
|
||||
|
||||
@override
|
||||
def rename_table(
|
||||
self,
|
||||
cur_name: str,
|
||||
new_name: str,
|
||||
cur_namespace: List[str] = [],
|
||||
new_namespace: List[str] = [],
|
||||
):
|
||||
"""Rename a table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cur_name: str
|
||||
The current name of the table.
|
||||
new_name: str
|
||||
The new name of the table.
|
||||
cur_namespace: List[str], optional
|
||||
The namespace of the current table.
|
||||
new_namespace: List[str], optional
|
||||
The namespace to move the table to.
|
||||
"""
|
||||
LOOP.run(
|
||||
self._conn.rename_table(
|
||||
cur_name,
|
||||
new_name,
|
||||
cur_namespace=cur_namespace,
|
||||
new_namespace=new_namespace,
|
||||
)
|
||||
)
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.15.1",
|
||||
@@ -588,13 +794,67 @@ class AsyncConnection(object):
|
||||
def uri(self) -> str:
|
||||
return self._inner.uri
|
||||
|
||||
async def list_namespaces(
|
||||
self,
|
||||
namespace: List[str] = [],
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Iterable[str]:
|
||||
"""List immediate child namespace names in the given namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], optional
|
||||
The parent namespace to list namespaces in.
|
||||
None or empty list represents root namespace.
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable of str
|
||||
List of immediate child namespace names (not full paths)
|
||||
"""
|
||||
return await self._inner.list_namespaces(
|
||||
namespace=namespace, page_token=page_token, limit=limit
|
||||
)
|
||||
|
||||
async def create_namespace(self, namespace: List[str]) -> None:
|
||||
"""Create a new namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str]
|
||||
The namespace identifier to create.
|
||||
"""
|
||||
await self._inner.create_namespace(namespace)
|
||||
|
||||
async def drop_namespace(self, namespace: List[str]) -> None:
|
||||
"""Drop a namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str]
|
||||
The namespace identifier to drop.
|
||||
"""
|
||||
await self._inner.drop_namespace(namespace)
|
||||
|
||||
async def table_names(
|
||||
self, *, start_after: Optional[str] = None, limit: Optional[int] = None
|
||||
self,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
start_after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> Iterable[str]:
|
||||
"""List all tables in this database, in sorted order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], optional
|
||||
The namespace to list tables in.
|
||||
None or empty list represents root namespace.
|
||||
start_after: str, optional
|
||||
If present, only return names that come lexicographically after the supplied
|
||||
value.
|
||||
@@ -608,7 +868,9 @@ class AsyncConnection(object):
|
||||
-------
|
||||
Iterable of str
|
||||
"""
|
||||
return await self._inner.table_names(start_after=start_after, limit=limit)
|
||||
return await self._inner.table_names(
|
||||
namespace=namespace, start_after=start_after, limit=limit
|
||||
)
|
||||
|
||||
async def create_table(
|
||||
self,
|
||||
@@ -621,6 +883,7 @@ class AsyncConnection(object):
|
||||
fill_value: Optional[float] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> AsyncTable:
|
||||
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
|
||||
@@ -629,6 +892,9 @@ class AsyncConnection(object):
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], default []
|
||||
The namespace to create the table in.
|
||||
Empty list represents root namespace.
|
||||
data: The data to initialize the table, *optional*
|
||||
User must provide at least one of `data` or `schema`.
|
||||
Acceptable types are:
|
||||
@@ -807,6 +1073,7 @@ class AsyncConnection(object):
|
||||
name,
|
||||
mode,
|
||||
schema,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
else:
|
||||
@@ -815,6 +1082,7 @@ class AsyncConnection(object):
|
||||
name,
|
||||
mode,
|
||||
data,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
|
||||
@@ -823,6 +1091,8 @@ class AsyncConnection(object):
|
||||
async def open_table(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> AsyncTable:
|
||||
@@ -832,6 +1102,9 @@ class AsyncConnection(object):
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], optional
|
||||
The namespace to open the table from.
|
||||
None or empty list represents root namespace.
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. Options already set on the
|
||||
connection will be inherited by the table, but can be overridden here.
|
||||
@@ -855,42 +1128,77 @@ class AsyncConnection(object):
|
||||
-------
|
||||
A LanceTable object representing the table.
|
||||
"""
|
||||
table = await self._inner.open_table(name, storage_options, index_cache_size)
|
||||
table = await self._inner.open_table(
|
||||
name,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
return AsyncTable(table)
|
||||
|
||||
async def rename_table(self, old_name: str, new_name: str):
|
||||
async def rename_table(
|
||||
self,
|
||||
cur_name: str,
|
||||
new_name: str,
|
||||
cur_namespace: List[str] = [],
|
||||
new_namespace: List[str] = [],
|
||||
):
|
||||
"""Rename a table in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
old_name: str
|
||||
cur_name: str
|
||||
The current name of the table.
|
||||
new_name: str
|
||||
The new name of the table.
|
||||
cur_namespace: List[str], optional
|
||||
The namespace of the current table.
|
||||
None or empty list represents root namespace.
|
||||
new_namespace: List[str], optional
|
||||
The namespace to move the table to.
|
||||
If not specified, defaults to the same as cur_namespace.
|
||||
"""
|
||||
await self._inner.rename_table(old_name, new_name)
|
||||
await self._inner.rename_table(
|
||||
cur_name, new_name, cur_namespace=cur_namespace, new_namespace=new_namespace
|
||||
)
|
||||
|
||||
async def drop_table(self, name: str, *, ignore_missing: bool = False):
|
||||
async def drop_table(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
ignore_missing: bool = False,
|
||||
):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], default []
|
||||
The namespace to drop the table from.
|
||||
Empty list represents root namespace.
|
||||
ignore_missing: bool, default False
|
||||
If True, ignore if the table does not exist.
|
||||
"""
|
||||
try:
|
||||
await self._inner.drop_table(name)
|
||||
await self._inner.drop_table(name, namespace=namespace)
|
||||
except ValueError as e:
|
||||
if not ignore_missing:
|
||||
raise e
|
||||
if f"Table '{name}' was not found" not in str(e):
|
||||
raise e
|
||||
|
||||
async def drop_all_tables(self):
|
||||
"""Drop all tables from the database."""
|
||||
await self._inner.drop_all_tables()
|
||||
async def drop_all_tables(self, namespace: List[str] = []):
|
||||
"""Drop all tables from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], optional
|
||||
The namespace to drop all tables from.
|
||||
None or empty list represents root namespace.
|
||||
"""
|
||||
await self._inner.drop_all_tables(namespace=namespace)
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.15.1",
|
||||
|
||||
401
python/python/lancedb/namespace.py
Normal file
401
python/python/lancedb/namespace.py
Normal file
@@ -0,0 +1,401 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""
|
||||
LanceDB Namespace integration module.
|
||||
|
||||
This module provides integration with lance_namespace for managing tables
|
||||
through a namespace abstraction.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
import os
|
||||
|
||||
from lancedb.db import DBConnection
|
||||
from lancedb.table import LanceTable, Table
|
||||
from lancedb.util import validate_table_name
|
||||
from lancedb.common import validate_schema
|
||||
from lancedb.table import sanitize_create_table
|
||||
from overrides import override
|
||||
|
||||
from lance_namespace import LanceNamespace, connect as namespace_connect
|
||||
from lance_namespace_urllib3_client.models import (
|
||||
ListTablesRequest,
|
||||
DescribeTableRequest,
|
||||
CreateTableRequest,
|
||||
DropTableRequest,
|
||||
ListNamespacesRequest,
|
||||
CreateNamespaceRequest,
|
||||
DropNamespaceRequest,
|
||||
JsonArrowSchema,
|
||||
JsonArrowField,
|
||||
JsonArrowDataType,
|
||||
)
|
||||
|
||||
import pyarrow as pa
|
||||
from datetime import timedelta
|
||||
from lancedb.pydantic import LanceModel
|
||||
from lancedb.common import DATA
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig
|
||||
from ._lancedb import Session
|
||||
|
||||
|
||||
def _convert_pyarrow_type_to_json(arrow_type: pa.DataType) -> JsonArrowDataType:
|
||||
"""Convert PyArrow DataType to JsonArrowDataType."""
|
||||
if pa.types.is_null(arrow_type):
|
||||
type_name = "null"
|
||||
elif pa.types.is_boolean(arrow_type):
|
||||
type_name = "bool"
|
||||
elif pa.types.is_int8(arrow_type):
|
||||
type_name = "int8"
|
||||
elif pa.types.is_uint8(arrow_type):
|
||||
type_name = "uint8"
|
||||
elif pa.types.is_int16(arrow_type):
|
||||
type_name = "int16"
|
||||
elif pa.types.is_uint16(arrow_type):
|
||||
type_name = "uint16"
|
||||
elif pa.types.is_int32(arrow_type):
|
||||
type_name = "int32"
|
||||
elif pa.types.is_uint32(arrow_type):
|
||||
type_name = "uint32"
|
||||
elif pa.types.is_int64(arrow_type):
|
||||
type_name = "int64"
|
||||
elif pa.types.is_uint64(arrow_type):
|
||||
type_name = "uint64"
|
||||
elif pa.types.is_float32(arrow_type):
|
||||
type_name = "float32"
|
||||
elif pa.types.is_float64(arrow_type):
|
||||
type_name = "float64"
|
||||
elif pa.types.is_string(arrow_type):
|
||||
type_name = "utf8"
|
||||
elif pa.types.is_binary(arrow_type):
|
||||
type_name = "binary"
|
||||
elif pa.types.is_list(arrow_type):
|
||||
# For list types, we need more complex handling
|
||||
type_name = "list"
|
||||
elif pa.types.is_fixed_size_list(arrow_type):
|
||||
type_name = "fixed_size_list"
|
||||
else:
|
||||
# Default to string representation for unsupported types
|
||||
type_name = str(arrow_type)
|
||||
|
||||
return JsonArrowDataType(type=type_name)
|
||||
|
||||
|
||||
def _convert_pyarrow_schema_to_json(schema: pa.Schema) -> JsonArrowSchema:
|
||||
"""Convert PyArrow Schema to JsonArrowSchema."""
|
||||
fields = []
|
||||
for field in schema:
|
||||
json_field = JsonArrowField(
|
||||
name=field.name,
|
||||
type=_convert_pyarrow_type_to_json(field.type),
|
||||
nullable=field.nullable,
|
||||
metadata=field.metadata,
|
||||
)
|
||||
fields.append(json_field)
|
||||
|
||||
return JsonArrowSchema(fields=fields, metadata=schema.metadata)
|
||||
|
||||
|
||||
class LanceNamespaceDBConnection(DBConnection):
|
||||
"""
|
||||
A LanceDB connection that uses a namespace for table management.
|
||||
|
||||
This connection delegates table URI resolution to a lance_namespace instance,
|
||||
while using the standard LanceTable for actual table operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
namespace: LanceNamespace,
|
||||
*,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a namespace-based LanceDB connection.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace : LanceNamespace
|
||||
The namespace instance to use for table management
|
||||
read_consistency_interval : Optional[timedelta]
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked.
|
||||
storage_options : Optional[Dict[str, str]]
|
||||
Additional options for the storage backend
|
||||
session : Optional[Session]
|
||||
A session to use for this connection
|
||||
"""
|
||||
self._ns = namespace
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
self.storage_options = storage_options or {}
|
||||
self.session = session
|
||||
|
||||
@override
|
||||
def table_names(
|
||||
self,
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
) -> Iterable[str]:
|
||||
request = ListTablesRequest(id=namespace, page_token=page_token, limit=limit)
|
||||
response = self._ns.list_tables(request)
|
||||
return response.tables if response.tables else []
|
||||
|
||||
@override
|
||||
def create_table(
|
||||
self,
|
||||
name: str,
|
||||
data: Optional[DATA] = None,
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
mode: str = "create",
|
||||
exist_ok: bool = False,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
) -> Table:
|
||||
if mode.lower() not in ["create", "overwrite"]:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
validate_table_name(name)
|
||||
|
||||
# TODO: support passing data
|
||||
if data is not None:
|
||||
raise ValueError(
|
||||
"create_table currently only supports creating empty tables (data=None)"
|
||||
)
|
||||
|
||||
# Prepare schema
|
||||
metadata = None
|
||||
if embedding_functions is not None:
|
||||
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
data, schema = sanitize_create_table(
|
||||
data, schema, metadata, on_bad_vectors, fill_value
|
||||
)
|
||||
validate_schema(schema)
|
||||
|
||||
# Convert PyArrow schema to JsonArrowSchema
|
||||
json_schema = _convert_pyarrow_schema_to_json(schema)
|
||||
|
||||
# Create table request with namespace
|
||||
table_id = namespace + [name]
|
||||
request = CreateTableRequest(id=table_id, var_schema=json_schema)
|
||||
|
||||
# Create empty Arrow IPC stream bytes
|
||||
import pyarrow.ipc as ipc
|
||||
import io
|
||||
|
||||
empty_table = pa.Table.from_arrays(
|
||||
[pa.array([], type=field.type) for field in schema], schema=schema
|
||||
)
|
||||
buffer = io.BytesIO()
|
||||
with ipc.new_stream(buffer, schema) as writer:
|
||||
writer.write_table(empty_table)
|
||||
request_data = buffer.getvalue()
|
||||
|
||||
self._ns.create_table(request, request_data)
|
||||
return self.open_table(
|
||||
name, namespace=namespace, storage_options=storage_options
|
||||
)
|
||||
|
||||
@override
|
||||
def open_table(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> Table:
|
||||
table_id = namespace + [name]
|
||||
request = DescribeTableRequest(id=table_id)
|
||||
response = self._ns.describe_table(request)
|
||||
|
||||
merged_storage_options = dict()
|
||||
if storage_options:
|
||||
merged_storage_options.update(storage_options)
|
||||
if response.storage_options:
|
||||
merged_storage_options.update(response.storage_options)
|
||||
|
||||
return self._lance_table_from_uri(
|
||||
response.location,
|
||||
storage_options=merged_storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str, namespace: List[str] = []):
|
||||
# Use namespace drop_table directly
|
||||
table_id = namespace + [name]
|
||||
request = DropTableRequest(id=table_id)
|
||||
self._ns.drop_table(request)
|
||||
|
||||
@override
|
||||
def rename_table(
|
||||
self,
|
||||
cur_name: str,
|
||||
new_name: str,
|
||||
cur_namespace: List[str] = [],
|
||||
new_namespace: List[str] = [],
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"rename_table is not supported for namespace connections"
|
||||
)
|
||||
|
||||
@override
|
||||
def drop_database(self):
|
||||
raise NotImplementedError(
|
||||
"drop_database is deprecated, use drop_all_tables instead"
|
||||
)
|
||||
|
||||
@override
|
||||
def drop_all_tables(self, namespace: List[str] = []):
|
||||
for table_name in self.table_names(namespace=namespace):
|
||||
self.drop_table(table_name, namespace=namespace)
|
||||
|
||||
@override
|
||||
def list_namespaces(
|
||||
self,
|
||||
namespace: List[str] = [],
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Iterable[str]:
|
||||
"""
|
||||
List child namespaces under the given namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace : Optional[List[str]]
|
||||
The parent namespace to list children from.
|
||||
If None, lists root-level namespaces.
|
||||
page_token : Optional[str]
|
||||
Pagination token for listing results.
|
||||
limit : int
|
||||
Maximum number of namespaces to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable[str]
|
||||
Names of child namespaces.
|
||||
"""
|
||||
request = ListNamespacesRequest(
|
||||
id=namespace, page_token=page_token, limit=limit
|
||||
)
|
||||
response = self._ns.list_namespaces(request)
|
||||
return response.namespaces if response.namespaces else []
|
||||
|
||||
@override
|
||||
def create_namespace(self, namespace: List[str]) -> None:
|
||||
"""
|
||||
Create a new namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace : List[str]
|
||||
The namespace path to create.
|
||||
"""
|
||||
request = CreateNamespaceRequest(id=namespace)
|
||||
self._ns.create_namespace(request)
|
||||
|
||||
@override
|
||||
def drop_namespace(self, namespace: List[str]) -> None:
|
||||
"""
|
||||
Drop a namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace : List[str]
|
||||
The namespace path to drop.
|
||||
"""
|
||||
request = DropNamespaceRequest(id=namespace)
|
||||
self._ns.drop_namespace(request)
|
||||
|
||||
def _lance_table_from_uri(
|
||||
self,
|
||||
table_uri: str,
|
||||
*,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> LanceTable:
|
||||
# Extract the base path and table name from the URI
|
||||
if table_uri.endswith(".lance"):
|
||||
base_path = os.path.dirname(table_uri)
|
||||
table_name = os.path.basename(table_uri)[:-6] # Remove .lance
|
||||
else:
|
||||
raise ValueError(f"Invalid table URI: {table_uri}")
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
|
||||
temp_conn = LanceDBConnection(
|
||||
base_path,
|
||||
read_consistency_interval=self.read_consistency_interval,
|
||||
storage_options={**self.storage_options, **(storage_options or {})},
|
||||
session=self.session,
|
||||
)
|
||||
|
||||
# Open the table using the temporary connection
|
||||
return LanceTable.open(
|
||||
temp_conn,
|
||||
table_name,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
|
||||
|
||||
def connect_namespace(
|
||||
impl: str,
|
||||
properties: Dict[str, str],
|
||||
*,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
session: Optional[Session] = None,
|
||||
) -> LanceNamespaceDBConnection:
|
||||
"""
|
||||
Connect to a LanceDB database through a namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
impl : str
|
||||
The namespace implementation to use. For examples:
|
||||
- "dir" for DirectoryNamespace
|
||||
- "rest" for REST-based namespace
|
||||
- Full module path for custom implementations
|
||||
properties : Dict[str, str]
|
||||
Configuration properties for the namespace implementation.
|
||||
Different namespace implemenation has different config properties.
|
||||
For example, use DirectoryNamespace with {"root": "/path/to/directory"}
|
||||
read_consistency_interval : Optional[timedelta]
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked.
|
||||
storage_options : Optional[Dict[str, str]]
|
||||
Additional options for the storage backend
|
||||
session : Optional[Session]
|
||||
A session to use for this connection
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceNamespaceDBConnection
|
||||
A namespace-based connection to LanceDB
|
||||
"""
|
||||
namespace = namespace_connect(impl, properties)
|
||||
|
||||
# Return the namespace-based connection
|
||||
return LanceNamespaceDBConnection(
|
||||
namespace,
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
storage_options=storage_options,
|
||||
session=session,
|
||||
)
|
||||
@@ -943,20 +943,22 @@ class LanceQueryBuilder(ABC):
|
||||
>>> query = [100, 100]
|
||||
>>> plan = table.search(query).analyze_plan()
|
||||
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||
AnalyzeExec verbose=true, metrics=[]
|
||||
TracedExec, metrics=[]
|
||||
ProjectionExec: expr=[...], metrics=[...]
|
||||
GlobalLimitExec: skip=0, fetch=10, metrics=[...]
|
||||
AnalyzeExec verbose=true, metrics=[], cumulative_cpu=...
|
||||
TracedExec, metrics=[], cumulative_cpu=...
|
||||
ProjectionExec: expr=[...], metrics=[...], cumulative_cpu=...
|
||||
GlobalLimitExec: skip=0, fetch=10, metrics=[...], cumulative_cpu=...
|
||||
FilterExec: _distance@2 IS NOT NULL,
|
||||
metrics=[output_rows=..., elapsed_compute=...]
|
||||
metrics=[output_rows=..., elapsed_compute=...], cumulative_cpu=...
|
||||
SortExec: TopK(fetch=10), expr=[...],
|
||||
preserve_partitioning=[...],
|
||||
metrics=[output_rows=..., elapsed_compute=..., row_replacements=...]
|
||||
metrics=[output_rows=..., elapsed_compute=..., row_replacements=...],
|
||||
cumulative_cpu=...
|
||||
KNNVectorDistance: metric=l2,
|
||||
metrics=[output_rows=..., elapsed_compute=..., output_batches=...]
|
||||
metrics=[output_rows=..., elapsed_compute=..., output_batches=...],
|
||||
cumulative_cpu=...
|
||||
LanceRead: uri=..., projection=[vector], ...
|
||||
metrics=[output_rows=..., elapsed_compute=...,
|
||||
bytes_read=..., iops=..., requests=...]
|
||||
bytes_read=..., iops=..., requests=...], cumulative_cpu=...
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@@ -8,7 +8,15 @@ from typing import List, Optional
|
||||
|
||||
from lancedb import __version__
|
||||
|
||||
__all__ = ["TimeoutConfig", "RetryConfig", "ClientConfig"]
|
||||
from .header import HeaderProvider
|
||||
|
||||
__all__ = [
|
||||
"TimeoutConfig",
|
||||
"RetryConfig",
|
||||
"TlsConfig",
|
||||
"ClientConfig",
|
||||
"HeaderProvider",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -112,15 +120,43 @@ class RetryConfig:
|
||||
statuses: Optional[List[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TlsConfig:
|
||||
"""TLS/mTLS configuration for the remote HTTP client.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
cert_file: Optional[str]
|
||||
Path to the client certificate file (PEM format) for mTLS authentication.
|
||||
key_file: Optional[str]
|
||||
Path to the client private key file (PEM format) for mTLS authentication.
|
||||
ssl_ca_cert: Optional[str]
|
||||
Path to the CA certificate file (PEM format) for server verification.
|
||||
assert_hostname: bool
|
||||
Whether to verify the hostname in the server's certificate. Default is True.
|
||||
Set to False to disable hostname verification (use with caution).
|
||||
"""
|
||||
|
||||
cert_file: Optional[str] = None
|
||||
key_file: Optional[str] = None
|
||||
ssl_ca_cert: Optional[str] = None
|
||||
assert_hostname: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
user_agent: str = f"LanceDB-Python-Client/{__version__}"
|
||||
retry_config: RetryConfig = field(default_factory=RetryConfig)
|
||||
timeout_config: Optional[TimeoutConfig] = field(default_factory=TimeoutConfig)
|
||||
extra_headers: Optional[dict] = None
|
||||
id_delimiter: Optional[str] = None
|
||||
tls_config: Optional[TlsConfig] = None
|
||||
header_provider: Optional["HeaderProvider"] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.retry_config, dict):
|
||||
self.retry_config = RetryConfig(**self.retry_config)
|
||||
if isinstance(self.timeout_config, dict):
|
||||
self.timeout_config = TimeoutConfig(**self.timeout_config)
|
||||
if isinstance(self.tls_config, dict):
|
||||
self.tls_config = TlsConfig(**self.tls_config)
|
||||
|
||||
@@ -96,14 +96,73 @@ class RemoteDBConnection(DBConnection):
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
|
||||
@override
|
||||
def list_namespaces(
|
||||
self,
|
||||
namespace: List[str] = [],
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Iterable[str]:
|
||||
"""List immediate child namespace names in the given namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], optional
|
||||
The parent namespace to list namespaces in.
|
||||
None or empty list represents root namespace.
|
||||
page_token: str, optional
|
||||
The token to use for pagination. If not present, start from the beginning.
|
||||
limit: int, default 10
|
||||
The size of the page to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable of str
|
||||
List of immediate child namespace names
|
||||
"""
|
||||
return LOOP.run(
|
||||
self._conn.list_namespaces(
|
||||
namespace=namespace, page_token=page_token, limit=limit
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def create_namespace(self, namespace: List[str]) -> None:
|
||||
"""Create a new namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str]
|
||||
The namespace identifier to create.
|
||||
"""
|
||||
LOOP.run(self._conn.create_namespace(namespace=namespace))
|
||||
|
||||
@override
|
||||
def drop_namespace(self, namespace: List[str]) -> None:
|
||||
"""Drop a namespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str]
|
||||
The namespace identifier to drop.
|
||||
"""
|
||||
return LOOP.run(self._conn.drop_namespace(namespace=namespace))
|
||||
|
||||
@override
|
||||
def table_names(
|
||||
self, page_token: Optional[str] = None, limit: int = 10
|
||||
self,
|
||||
page_token: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
) -> Iterable[str]:
|
||||
"""List the names of all tables in the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
namespace: List[str], default []
|
||||
The namespace to list tables in.
|
||||
Empty list represents root namespace.
|
||||
page_token: str
|
||||
The last token to start the new page.
|
||||
limit: int, default 10
|
||||
@@ -113,13 +172,18 @@ class RemoteDBConnection(DBConnection):
|
||||
-------
|
||||
An iterator of table names.
|
||||
"""
|
||||
return LOOP.run(self._conn.table_names(start_after=page_token, limit=limit))
|
||||
return LOOP.run(
|
||||
self._conn.table_names(
|
||||
namespace=namespace, start_after=page_token, limit=limit
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def open_table(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
) -> Table:
|
||||
@@ -129,6 +193,9 @@ class RemoteDBConnection(DBConnection):
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], optional
|
||||
The namespace to open the table from.
|
||||
None or empty list represents root namespace.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -142,7 +209,7 @@ class RemoteDBConnection(DBConnection):
|
||||
" (there is no local cache to configure)"
|
||||
)
|
||||
|
||||
table = LOOP.run(self._conn.open_table(name))
|
||||
table = LOOP.run(self._conn.open_table(name, namespace=namespace))
|
||||
return RemoteTable(table, self.db_name)
|
||||
|
||||
@override
|
||||
@@ -155,6 +222,8 @@ class RemoteDBConnection(DBConnection):
|
||||
fill_value: float = 0.0,
|
||||
mode: Optional[str] = None,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
|
||||
@@ -162,6 +231,9 @@ class RemoteDBConnection(DBConnection):
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], optional
|
||||
The namespace to create the table in.
|
||||
None or empty list represents root namespace.
|
||||
data: The data to initialize the table, *optional*
|
||||
User must provide at least one of `data` or `schema`.
|
||||
Acceptable types are:
|
||||
@@ -262,6 +334,7 @@ class RemoteDBConnection(DBConnection):
|
||||
self._conn.create_table(
|
||||
name,
|
||||
data,
|
||||
namespace=namespace,
|
||||
mode=mode,
|
||||
schema=schema,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
@@ -271,18 +344,27 @@ class RemoteDBConnection(DBConnection):
|
||||
return RemoteTable(table, self.db_name)
|
||||
|
||||
@override
|
||||
def drop_table(self, name: str):
|
||||
def drop_table(self, name: str, namespace: List[str] = []):
|
||||
"""Drop a table from the database.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
namespace: List[str], optional
|
||||
The namespace to drop the table from.
|
||||
None or empty list represents root namespace.
|
||||
"""
|
||||
LOOP.run(self._conn.drop_table(name))
|
||||
LOOP.run(self._conn.drop_table(name, namespace=namespace))
|
||||
|
||||
@override
|
||||
def rename_table(self, cur_name: str, new_name: str):
|
||||
def rename_table(
|
||||
self,
|
||||
cur_name: str,
|
||||
new_name: str,
|
||||
cur_namespace: List[str] = [],
|
||||
new_namespace: List[str] = [],
|
||||
):
|
||||
"""Rename a table in the database.
|
||||
|
||||
Parameters
|
||||
@@ -292,7 +374,14 @@ class RemoteDBConnection(DBConnection):
|
||||
new_name: str
|
||||
The new name of the table.
|
||||
"""
|
||||
LOOP.run(self._conn.rename_table(cur_name, new_name))
|
||||
LOOP.run(
|
||||
self._conn.rename_table(
|
||||
cur_name,
|
||||
new_name,
|
||||
cur_namespace=cur_namespace,
|
||||
new_namespace=new_namespace,
|
||||
)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the database."""
|
||||
|
||||
180
python/python/lancedb/remote/header.py
Normal file
180
python/python/lancedb/remote/header.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Header providers for LanceDB remote connections.
|
||||
|
||||
This module provides a flexible header management framework for LanceDB remote
|
||||
connections, allowing users to implement custom header strategies for
|
||||
authentication, request tracking, custom metadata, or any other header-based
|
||||
requirements.
|
||||
|
||||
The module includes the HeaderProvider abstract base class and example implementations
|
||||
(StaticHeaderProvider and OAuthProvider) that demonstrate common patterns.
|
||||
|
||||
The HeaderProvider interface is designed to be called before each request to the remote
|
||||
server, enabling dynamic header scenarios where values may need to be
|
||||
refreshed, rotated, or computed on-demand.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Callable, Any
|
||||
import time
|
||||
import threading
|
||||
|
||||
|
||||
class HeaderProvider(ABC):
|
||||
"""Abstract base class for providing custom headers for each request.
|
||||
|
||||
Users can implement this interface to provide dynamic headers for various purposes
|
||||
such as authentication (OAuth tokens, API keys), request tracking (correlation IDs),
|
||||
custom metadata, or any other header-based requirements. The provider is called
|
||||
before each request to ensure fresh header values are always used.
|
||||
|
||||
Error Handling
|
||||
--------------
|
||||
If get_headers() raises an exception, the request will fail. Implementations
|
||||
should handle recoverable errors internally (e.g., retry token refresh) and
|
||||
only raise exceptions for unrecoverable errors.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""Get the latest headers to be added to requests.
|
||||
|
||||
This method is called before each request to the remote LanceDB server.
|
||||
Implementations should return headers that will be merged with existing headers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, str]
|
||||
Dictionary of header names to values to add to the request.
|
||||
|
||||
Raises
|
||||
------
|
||||
Exception
|
||||
If unable to fetch headers, the exception will be propagated
|
||||
and the request will fail.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StaticHeaderProvider(HeaderProvider):
|
||||
"""Example implementation: A simple header provider that returns static headers.
|
||||
|
||||
This is an example implementation showing how to create a HeaderProvider
|
||||
for cases where headers don't change during the session. Users can use this
|
||||
as a reference for implementing their own providers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
headers : Dict[str, str]
|
||||
Static headers to return for every request.
|
||||
"""
|
||||
|
||||
def __init__(self, headers: Dict[str, str]):
|
||||
"""Initialize with static headers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
headers : Dict[str, str]
|
||||
Headers to return for every request.
|
||||
"""
|
||||
self._headers = headers.copy()
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""Return the static headers.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, str]
|
||||
Copy of the static headers.
|
||||
"""
|
||||
return self._headers.copy()
|
||||
|
||||
|
||||
class OAuthProvider(HeaderProvider):
|
||||
"""Example implementation: OAuth token provider with automatic refresh.
|
||||
|
||||
This is an example implementation showing how to manage OAuth tokens
|
||||
with automatic refresh when they expire. Users can use this as a reference
|
||||
for implementing their own OAuth or token-based authentication providers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
token_fetcher : Callable[[], Dict[str, Any]]
|
||||
Function that fetches a new token. Should return a dict with
|
||||
'access_token' and optionally 'expires_in' (seconds until expiration).
|
||||
refresh_buffer_seconds : int, optional
|
||||
Number of seconds before expiration to trigger refresh. Default is 300
|
||||
(5 minutes).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, token_fetcher: Callable[[], Any], refresh_buffer_seconds: int = 300
|
||||
):
|
||||
"""Initialize the OAuth provider.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
token_fetcher : Callable[[], Any]
|
||||
Function to fetch new tokens. Should return dict with
|
||||
'access_token' and optionally 'expires_in'.
|
||||
refresh_buffer_seconds : int, optional
|
||||
Seconds before expiry to refresh token. Default 300.
|
||||
"""
|
||||
self._token_fetcher = token_fetcher
|
||||
self._refresh_buffer = refresh_buffer_seconds
|
||||
self._current_token: Optional[str] = None
|
||||
self._token_expires_at: Optional[float] = None
|
||||
self._refresh_lock = threading.Lock()
|
||||
|
||||
def _refresh_token_if_needed(self) -> None:
|
||||
"""Refresh the token if it's expired or close to expiring."""
|
||||
with self._refresh_lock:
|
||||
# Check again inside the lock in case another thread refreshed
|
||||
if self._needs_refresh():
|
||||
token_data = self._token_fetcher()
|
||||
|
||||
self._current_token = token_data.get("access_token")
|
||||
if not self._current_token:
|
||||
raise ValueError("Token fetcher did not return 'access_token'")
|
||||
|
||||
# Set expiration if provided
|
||||
expires_in = token_data.get("expires_in")
|
||||
if expires_in:
|
||||
self._token_expires_at = time.time() + expires_in
|
||||
else:
|
||||
# Token doesn't expire or expiration unknown
|
||||
self._token_expires_at = None
|
||||
|
||||
def _needs_refresh(self) -> bool:
|
||||
"""Check if token needs refresh."""
|
||||
if self._current_token is None:
|
||||
return True
|
||||
|
||||
if self._token_expires_at is None:
|
||||
# No expiration info, assume token is valid
|
||||
return False
|
||||
|
||||
# Refresh if we're within the buffer time of expiration
|
||||
return time.time() >= (self._token_expires_at - self._refresh_buffer)
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
"""Get OAuth headers, refreshing token if needed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, str]
|
||||
Headers with Bearer token authorization.
|
||||
|
||||
Raises
|
||||
------
|
||||
Exception
|
||||
If unable to fetch or refresh token.
|
||||
"""
|
||||
self._refresh_token_if_needed()
|
||||
|
||||
if not self._current_token:
|
||||
raise RuntimeError("Failed to obtain OAuth token")
|
||||
|
||||
return {"Authorization": f"Bearer {self._current_token}"}
|
||||
@@ -115,6 +115,7 @@ class RemoteTable(Table):
|
||||
*,
|
||||
replace: bool = False,
|
||||
wait_timeout: timedelta = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Creates a scalar index
|
||||
Parameters
|
||||
@@ -139,7 +140,11 @@ class RemoteTable(Table):
|
||||
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
column, config=config, replace=replace, wait_timeout=wait_timeout
|
||||
column,
|
||||
config=config,
|
||||
replace=replace,
|
||||
wait_timeout=wait_timeout,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -161,6 +166,7 @@ class RemoteTable(Table):
|
||||
ngram_min_length: int = 3,
|
||||
ngram_max_length: int = 3,
|
||||
prefix_only: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
config = FTS(
|
||||
with_position=with_position,
|
||||
@@ -177,7 +183,11 @@ class RemoteTable(Table):
|
||||
)
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
column, config=config, replace=replace, wait_timeout=wait_timeout
|
||||
column,
|
||||
config=config,
|
||||
replace=replace,
|
||||
wait_timeout=wait_timeout,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -194,6 +204,8 @@ class RemoteTable(Table):
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
*,
|
||||
num_bits: int = 8,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
Currently, the only parameters that matter are
|
||||
@@ -270,7 +282,11 @@ class RemoteTable(Table):
|
||||
|
||||
LOOP.run(
|
||||
self._table.create_index(
|
||||
vector_column_name, config=config, wait_timeout=wait_timeout
|
||||
vector_column_name,
|
||||
config=config,
|
||||
wait_timeout=wait_timeout,
|
||||
name=name,
|
||||
train=train,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -689,6 +689,8 @@ class Table(ABC):
|
||||
sample_rate: int = 256,
|
||||
m: int = 20,
|
||||
ef_construction: int = 300,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
|
||||
@@ -721,6 +723,11 @@ class Table(ABC):
|
||||
Only 4 and 8 are supported.
|
||||
wait_timeout: timedelta, optional
|
||||
The timeout to wait if indexing is asynchronous.
|
||||
name: str, optional
|
||||
The name of the index. If not provided, a default name will be generated.
|
||||
train: bool, default True
|
||||
Whether to train the index with existing data. Vector indices always train
|
||||
with existing data.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -776,6 +783,7 @@ class Table(ABC):
|
||||
replace: bool = True,
|
||||
index_type: ScalarIndexType = "BTREE",
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Create a scalar index on a column.
|
||||
|
||||
@@ -790,6 +798,8 @@ class Table(ABC):
|
||||
The type of index to create.
|
||||
wait_timeout: timedelta, optional
|
||||
The timeout to wait if indexing is asynchronous.
|
||||
name: str, optional
|
||||
The name of the index. If not provided, a default name will be generated.
|
||||
Examples
|
||||
--------
|
||||
|
||||
@@ -852,6 +862,7 @@ class Table(ABC):
|
||||
ngram_max_length: int = 3,
|
||||
prefix_only: bool = False,
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Create a full-text search index on the table.
|
||||
|
||||
@@ -916,6 +927,8 @@ class Table(ABC):
|
||||
Whether to only index the prefix of the token for ngram tokenizer.
|
||||
wait_timeout: timedelta, optional
|
||||
The timeout to wait if indexing is asynchronous.
|
||||
name: str, optional
|
||||
The name of the index. If not provided, a default name will be generated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1106,7 +1119,9 @@ class Table(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
|
||||
def take_offsets(
|
||||
self, offsets: list[int], *, with_row_id: bool = False
|
||||
) -> LanceTakeQueryBuilder:
|
||||
"""
|
||||
Take a list of offsets from the table.
|
||||
|
||||
@@ -1132,8 +1147,60 @@ class Table(ABC):
|
||||
A record batch containing the rows at the given offsets.
|
||||
"""
|
||||
|
||||
def __getitems__(self, offsets: list[int]) -> pa.RecordBatch:
|
||||
"""
|
||||
Take a list of offsets from the table and return as a record batch.
|
||||
|
||||
This method uses the `take_offsets` method to take the rows. However, it
|
||||
aligns the offsets to the passed in offsets. This means the return type
|
||||
is a record batch (and so users should take care not to pass in too many
|
||||
offsets)
|
||||
|
||||
Note: this method is primarily intended to fulfill the Dataset contract
|
||||
for pytorch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
offsets: list[int]
|
||||
The offsets to take.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pa.RecordBatch
|
||||
A record batch containing the rows at the given offsets.
|
||||
"""
|
||||
# We don't know the order of the results at all. So we calculate a permutation
|
||||
# for ordering the given offsets. Then we load the data with the _rowoffset
|
||||
# column. Then we sort by _rowoffset and apply the inverse of the permutation
|
||||
# that we calculated.
|
||||
#
|
||||
# Note: this is potentially a lot of memory copy if we're operating on large
|
||||
# batches :(
|
||||
num_offsets = len(offsets)
|
||||
indices = list(range(num_offsets))
|
||||
permutation = sorted(indices, key=lambda idx: offsets[idx])
|
||||
permutation_inv = [0] * num_offsets
|
||||
for i in range(num_offsets):
|
||||
permutation_inv[permutation[i]] = i
|
||||
|
||||
columns = self.schema.names
|
||||
columns.append("_rowoffset")
|
||||
tbl = (
|
||||
self.take_offsets(offsets)
|
||||
.select(columns)
|
||||
.to_arrow()
|
||||
.sort_by("_rowoffset")
|
||||
.take(permutation_inv)
|
||||
.combine_chunks()
|
||||
.drop_columns(["_rowoffset"])
|
||||
)
|
||||
|
||||
return tbl
|
||||
|
||||
@abstractmethod
|
||||
def take_row_ids(self, row_ids: list[int]) -> LanceTakeQueryBuilder:
|
||||
def take_row_ids(
|
||||
self, row_ids: list[int], *, with_row_id: bool = False
|
||||
) -> LanceTakeQueryBuilder:
|
||||
"""
|
||||
Take a list of row ids from the table.
|
||||
|
||||
@@ -1639,13 +1706,16 @@ class LanceTable(Table):
|
||||
connection: "LanceDBConnection",
|
||||
name: str,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
self._conn = connection
|
||||
self._namespace = namespace
|
||||
self._table = LOOP.run(
|
||||
connection._conn.open_table(
|
||||
name,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
@@ -1656,8 +1726,8 @@ class LanceTable(Table):
|
||||
return self._table.name
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name, **kwargs):
|
||||
tbl = cls(db, name, **kwargs)
|
||||
def open(cls, db, name, *, namespace: List[str] = [], **kwargs):
|
||||
tbl = cls(db, name, namespace=namespace, **kwargs)
|
||||
|
||||
# check the dataset exists
|
||||
try:
|
||||
@@ -1929,6 +1999,9 @@ class LanceTable(Table):
|
||||
sample_rate: int = 256,
|
||||
m: int = 20,
|
||||
ef_construction: int = 300,
|
||||
*,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
):
|
||||
"""Create an index on the table."""
|
||||
if accelerator is not None:
|
||||
@@ -1992,6 +2065,8 @@ class LanceTable(Table):
|
||||
vector_column_name,
|
||||
replace=replace,
|
||||
config=config,
|
||||
name=name,
|
||||
train=train,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2036,6 +2111,7 @@ class LanceTable(Table):
|
||||
*,
|
||||
replace: bool = True,
|
||||
index_type: ScalarIndexType = "BTREE",
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
if index_type == "BTREE":
|
||||
config = BTree()
|
||||
@@ -2046,7 +2122,7 @@ class LanceTable(Table):
|
||||
else:
|
||||
raise ValueError(f"Unknown index type {index_type}")
|
||||
return LOOP.run(
|
||||
self._table.create_index(column, replace=replace, config=config)
|
||||
self._table.create_index(column, replace=replace, config=config, name=name)
|
||||
)
|
||||
|
||||
def create_fts_index(
|
||||
@@ -2070,6 +2146,7 @@ class LanceTable(Table):
|
||||
ngram_min_length: int = 3,
|
||||
ngram_max_length: int = 3,
|
||||
prefix_only: bool = False,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
if not use_tantivy:
|
||||
if not isinstance(field_names, str):
|
||||
@@ -2107,6 +2184,7 @@ class LanceTable(Table):
|
||||
field_names,
|
||||
replace=replace,
|
||||
config=config,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
return
|
||||
@@ -2473,6 +2551,7 @@ class LanceTable(Table):
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
*,
|
||||
namespace: List[str] = [],
|
||||
storage_options: Optional[Dict[str, str | bool]] = None,
|
||||
data_storage_version: Optional[str] = None,
|
||||
enable_v2_manifest_paths: Optional[bool] = None,
|
||||
@@ -2532,6 +2611,7 @@ class LanceTable(Table):
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
self._conn = db
|
||||
self._namespace = namespace
|
||||
|
||||
if data_storage_version is not None:
|
||||
warnings.warn(
|
||||
@@ -2564,6 +2644,7 @@ class LanceTable(Table):
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
embedding_functions=embedding_functions,
|
||||
namespace=namespace,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
)
|
||||
@@ -3251,6 +3332,8 @@ class AsyncTable:
|
||||
Union[IvfFlat, IvfPq, HnswPq, HnswSq, BTree, Bitmap, LabelList, FTS]
|
||||
] = None,
|
||||
wait_timeout: Optional[timedelta] = None,
|
||||
name: Optional[str] = None,
|
||||
train: bool = True,
|
||||
):
|
||||
"""Create an index to speed up queries
|
||||
|
||||
@@ -3277,6 +3360,11 @@ class AsyncTable:
|
||||
creating an index object.
|
||||
wait_timeout: timedelta, optional
|
||||
The timeout to wait if indexing is asynchronous.
|
||||
name: str, optional
|
||||
The name of the index. If not provided, a default name will be generated.
|
||||
train: bool, default True
|
||||
Whether to train the index with existing data. Vector indices always train
|
||||
with existing data.
|
||||
"""
|
||||
if config is not None:
|
||||
if not isinstance(
|
||||
@@ -3288,7 +3376,12 @@ class AsyncTable:
|
||||
)
|
||||
try:
|
||||
await self._inner.create_index(
|
||||
column, index=config, replace=replace, wait_timeout=wait_timeout
|
||||
column,
|
||||
index=config,
|
||||
replace=replace,
|
||||
wait_timeout=wait_timeout,
|
||||
name=name,
|
||||
train=train,
|
||||
)
|
||||
except ValueError as e:
|
||||
if "not support the requested language" in str(e):
|
||||
|
||||
@@ -175,6 +175,18 @@ def test_table_names(tmp_db: lancedb.DBConnection):
|
||||
tmp_db.create_table("test3", data=data)
|
||||
assert tmp_db.table_names() == ["test1", "test2", "test3"]
|
||||
|
||||
# Test that positional arguments for page_token and limit
|
||||
result = list(tmp_db.table_names("test1", 1)) # page_token="test1", limit=1
|
||||
assert result == ["test2"], f"Expected ['test2'], got {result}"
|
||||
|
||||
# Test mixed positional and keyword arguments
|
||||
result = list(tmp_db.table_names("test2", limit=2))
|
||||
assert result == ["test3"], f"Expected ['test3'], got {result}"
|
||||
|
||||
# Test that namespace parameter can be passed as keyword
|
||||
result = list(tmp_db.table_names(namespace=[]))
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_names_async(tmp_path):
|
||||
@@ -728,3 +740,93 @@ def test_bypass_vector_index_sync(tmp_db: lancedb.DBConnection):
|
||||
table.search(sample_key).bypass_vector_index().explain_plan(verbose=True)
|
||||
)
|
||||
assert "KNN" in plan_without_index
|
||||
|
||||
|
||||
def test_local_namespace_operations(tmp_path):
|
||||
"""Test that local mode namespace operations behave as expected."""
|
||||
# Create a local database connection
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
# Test list_namespaces returns empty list
|
||||
namespaces = list(db.list_namespaces())
|
||||
assert namespaces == []
|
||||
|
||||
# Test list_namespaces with parameters still returns empty list
|
||||
namespaces_with_params = list(
|
||||
db.list_namespaces(namespace=["test"], page_token="token", limit=5)
|
||||
)
|
||||
assert namespaces_with_params == []
|
||||
|
||||
|
||||
def test_local_create_namespace_not_supported(tmp_path):
|
||||
"""Test that create_namespace is not supported in local mode."""
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="Namespace operations are not supported for listing database",
|
||||
):
|
||||
db.create_namespace(["test_namespace"])
|
||||
|
||||
|
||||
def test_local_drop_namespace_not_supported(tmp_path):
|
||||
"""Test that drop_namespace is not supported in local mode."""
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="Namespace operations are not supported for listing database",
|
||||
):
|
||||
db.drop_namespace(["test_namespace"])
|
||||
|
||||
|
||||
def test_local_table_operations_with_namespace_raise_error(tmp_path):
|
||||
"""
|
||||
Test that table operations with namespace parameter
|
||||
raise ValueError in local mode.
|
||||
"""
|
||||
db = lancedb.connect(tmp_path)
|
||||
|
||||
# Create some test data
|
||||
data = [{"vector": [1.0, 2.0], "item": "test"}]
|
||||
schema = pa.schema(
|
||||
[pa.field("vector", pa.list_(pa.float32(), 2)), pa.field("item", pa.string())]
|
||||
)
|
||||
|
||||
# Test create_table with namespace - should raise ValueError
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="Namespace parameter is not supported for listing database",
|
||||
):
|
||||
db.create_table(
|
||||
"test_table_with_ns", data=data, schema=schema, namespace=["test_ns"]
|
||||
)
|
||||
|
||||
# Create table normally for other tests
|
||||
db.create_table("test_table", data=data, schema=schema)
|
||||
assert "test_table" in db.table_names()
|
||||
|
||||
# Test open_table with namespace - should raise ValueError
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="Namespace parameter is not supported for listing database",
|
||||
):
|
||||
db.open_table("test_table", namespace=["test_ns"])
|
||||
|
||||
# Test table_names with namespace - should raise ValueError
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="Namespace parameter is not supported for listing database",
|
||||
):
|
||||
list(db.table_names(namespace=["test_ns"]))
|
||||
|
||||
# Test drop_table with namespace - should raise ValueError
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="Namespace parameter is not supported for listing database",
|
||||
):
|
||||
db.drop_table("test_table", namespace=["test_ns"])
|
||||
|
||||
# Test table_names without namespace - should work normally
|
||||
tables_root = list(db.table_names())
|
||||
assert "test_table" in tables_root
|
||||
|
||||
@@ -157,7 +157,16 @@ def test_create_index_with_stemming(tmp_path, table):
|
||||
def test_create_inverted_index(table, use_tantivy, with_position):
|
||||
if use_tantivy and not with_position:
|
||||
pytest.skip("we don't support building a tantivy index without position")
|
||||
table.create_fts_index("text", use_tantivy=use_tantivy, with_position=with_position)
|
||||
table.create_fts_index(
|
||||
"text",
|
||||
use_tantivy=use_tantivy,
|
||||
with_position=with_position,
|
||||
name="custom_fts_index",
|
||||
)
|
||||
if not use_tantivy:
|
||||
indices = table.list_indices()
|
||||
fts_indices = [i for i in indices if i.index_type == "FTS"]
|
||||
assert any(i.name == "custom_fts_index" for i in fts_indices)
|
||||
|
||||
|
||||
def test_populate_index(tmp_path, table):
|
||||
|
||||
237
python/python/tests/test_header_provider.py
Normal file
237
python/python/tests/test_header_provider.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import concurrent.futures
|
||||
import pytest
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict
|
||||
|
||||
from lancedb.remote import ClientConfig, HeaderProvider
|
||||
from lancedb.remote.header import StaticHeaderProvider, OAuthProvider
|
||||
|
||||
|
||||
class TestStaticHeaderProvider:
|
||||
def test_init(self):
|
||||
"""Test StaticHeaderProvider initialization."""
|
||||
headers = {"X-API-Key": "test-key", "X-Custom": "value"}
|
||||
provider = StaticHeaderProvider(headers)
|
||||
assert provider._headers == headers
|
||||
|
||||
def test_get_headers(self):
|
||||
"""Test get_headers returns correct headers."""
|
||||
headers = {"X-API-Key": "test-key", "X-Custom": "value"}
|
||||
provider = StaticHeaderProvider(headers)
|
||||
|
||||
result = provider.get_headers()
|
||||
assert result == headers
|
||||
|
||||
# Ensure it returns a copy
|
||||
result["X-Modified"] = "modified"
|
||||
result2 = provider.get_headers()
|
||||
assert "X-Modified" not in result2
|
||||
|
||||
|
||||
class TestOAuthProvider:
|
||||
def test_init(self):
|
||||
"""Test OAuthProvider initialization."""
|
||||
|
||||
def fetcher():
|
||||
return {"access_token": "token123", "expires_in": 3600}
|
||||
|
||||
provider = OAuthProvider(fetcher)
|
||||
assert provider._token_fetcher is fetcher
|
||||
assert provider._refresh_buffer == 300
|
||||
assert provider._current_token is None
|
||||
assert provider._token_expires_at is None
|
||||
|
||||
def test_get_headers_first_time(self):
|
||||
"""Test get_headers fetches token on first call."""
|
||||
|
||||
def fetcher():
|
||||
return {"access_token": "token123", "expires_in": 3600}
|
||||
|
||||
provider = OAuthProvider(fetcher)
|
||||
headers = provider.get_headers()
|
||||
|
||||
assert headers == {"Authorization": "Bearer token123"}
|
||||
assert provider._current_token == "token123"
|
||||
assert provider._token_expires_at is not None
|
||||
|
||||
def test_token_refresh(self):
|
||||
"""Test token refresh when expired."""
|
||||
call_count = 0
|
||||
tokens = ["token1", "token2"]
|
||||
|
||||
def fetcher():
|
||||
nonlocal call_count
|
||||
token = tokens[call_count]
|
||||
call_count += 1
|
||||
return {"access_token": token, "expires_in": 1} # Expires in 1 second
|
||||
|
||||
provider = OAuthProvider(fetcher, refresh_buffer_seconds=0)
|
||||
|
||||
# First call
|
||||
headers1 = provider.get_headers()
|
||||
assert headers1 == {"Authorization": "Bearer token1"}
|
||||
|
||||
# Wait for token to expire
|
||||
time.sleep(1.1)
|
||||
|
||||
# Second call should refresh
|
||||
headers2 = provider.get_headers()
|
||||
assert headers2 == {"Authorization": "Bearer token2"}
|
||||
assert call_count == 2
|
||||
|
||||
def test_no_expiry_info(self):
|
||||
"""Test handling tokens without expiry information."""
|
||||
|
||||
def fetcher():
|
||||
return {"access_token": "permanent_token"}
|
||||
|
||||
provider = OAuthProvider(fetcher)
|
||||
headers = provider.get_headers()
|
||||
|
||||
assert headers == {"Authorization": "Bearer permanent_token"}
|
||||
assert provider._token_expires_at is None
|
||||
|
||||
# Should not refresh on second call
|
||||
headers2 = provider.get_headers()
|
||||
assert headers2 == {"Authorization": "Bearer permanent_token"}
|
||||
|
||||
def test_missing_access_token(self):
|
||||
"""Test error handling when access_token is missing."""
|
||||
|
||||
def fetcher():
|
||||
return {"expires_in": 3600} # Missing access_token
|
||||
|
||||
provider = OAuthProvider(fetcher)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Token fetcher did not return 'access_token'"
|
||||
):
|
||||
provider.get_headers()
|
||||
|
||||
def test_sync_method(self):
|
||||
"""Test synchronous get_headers method."""
|
||||
|
||||
def fetcher():
|
||||
return {"access_token": "sync_token", "expires_in": 3600}
|
||||
|
||||
provider = OAuthProvider(fetcher)
|
||||
headers = provider.get_headers()
|
||||
|
||||
assert headers == {"Authorization": "Bearer sync_token"}
|
||||
|
||||
|
||||
class TestClientConfigIntegration:
|
||||
def test_client_config_with_header_provider(self):
|
||||
"""Test ClientConfig can accept a HeaderProvider."""
|
||||
provider = StaticHeaderProvider({"X-Test": "value"})
|
||||
config = ClientConfig(header_provider=provider)
|
||||
|
||||
assert config.header_provider is provider
|
||||
|
||||
def test_client_config_without_header_provider(self):
|
||||
"""Test ClientConfig works without HeaderProvider."""
|
||||
config = ClientConfig()
|
||||
assert config.header_provider is None
|
||||
|
||||
|
||||
class CustomProvider(HeaderProvider):
|
||||
"""Custom provider for testing abstract class."""
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
return {"X-Custom": "custom-value"}
|
||||
|
||||
|
||||
class TestCustomHeaderProvider:
|
||||
def test_custom_provider(self):
|
||||
"""Test custom HeaderProvider implementation."""
|
||||
provider = CustomProvider()
|
||||
headers = provider.get_headers()
|
||||
assert headers == {"X-Custom": "custom-value"}
|
||||
|
||||
|
||||
class ErrorProvider(HeaderProvider):
|
||||
"""Provider that raises errors for testing error handling."""
|
||||
|
||||
def __init__(self, error_message: str = "Test error"):
|
||||
self.error_message = error_message
|
||||
self.call_count = 0
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
self.call_count += 1
|
||||
raise RuntimeError(self.error_message)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
def test_provider_error_propagation(self):
|
||||
"""Test that errors from header provider are properly propagated."""
|
||||
provider = ErrorProvider("Authentication failed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Authentication failed"):
|
||||
provider.get_headers()
|
||||
|
||||
assert provider.call_count == 1
|
||||
|
||||
def test_provider_error(self):
|
||||
"""Test that errors are propagated."""
|
||||
provider = ErrorProvider("Sync error")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Sync error"):
|
||||
provider.get_headers()
|
||||
|
||||
|
||||
class ConcurrentProvider(HeaderProvider):
|
||||
"""Provider for testing thread safety."""
|
||||
|
||||
def __init__(self):
|
||||
self.counter = 0
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
with self.lock:
|
||||
self.counter += 1
|
||||
# Simulate some work
|
||||
time.sleep(0.01)
|
||||
return {"X-Request-Id": str(self.counter)}
|
||||
|
||||
|
||||
class TestConcurrency:
|
||||
def test_concurrent_header_fetches(self):
|
||||
"""Test that header provider can handle concurrent requests."""
|
||||
provider = ConcurrentProvider()
|
||||
|
||||
# Create multiple concurrent requests
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(provider.get_headers) for _ in range(10)]
|
||||
results = [f.result() for f in futures]
|
||||
|
||||
# Each request should get a unique counter value
|
||||
request_ids = [int(r["X-Request-Id"]) for r in results]
|
||||
assert len(set(request_ids)) == 10
|
||||
assert min(request_ids) == 1
|
||||
assert max(request_ids) == 10
|
||||
|
||||
def test_oauth_concurrent_refresh(self):
|
||||
"""Test that OAuth provider handles concurrent refresh requests safely."""
|
||||
call_count = 0
|
||||
|
||||
def slow_token_fetch():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1) # Simulate slow token fetch
|
||||
return {"access_token": f"token-{call_count}", "expires_in": 3600}
|
||||
|
||||
provider = OAuthProvider(slow_token_fetch)
|
||||
|
||||
# Force multiple concurrent refreshes
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(provider.get_headers) for _ in range(5)]
|
||||
results = [f.result() for f in futures]
|
||||
|
||||
# All requests should get the same token (only one refresh should happen)
|
||||
tokens = [r["Authorization"] for r in results]
|
||||
assert all(t == "Bearer token-1" for t in tokens)
|
||||
assert call_count == 1 # Only one token fetch despite concurrent requests
|
||||
707
python/python/tests/test_namespace.py
Normal file
707
python/python/tests/test_namespace.py
Normal file
@@ -0,0 +1,707 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Tests for LanceDB namespace integration."""
|
||||
|
||||
import tempfile
|
||||
import shutil
|
||||
from typing import Dict, Optional
|
||||
import pytest
|
||||
import pyarrow as pa
|
||||
import lancedb
|
||||
from lance_namespace.namespace import NATIVE_IMPLS, LanceNamespace
|
||||
from lance_namespace_urllib3_client.models import (
|
||||
ListTablesRequest,
|
||||
ListTablesResponse,
|
||||
DescribeTableRequest,
|
||||
DescribeTableResponse,
|
||||
RegisterTableRequest,
|
||||
RegisterTableResponse,
|
||||
DeregisterTableRequest,
|
||||
DeregisterTableResponse,
|
||||
CreateTableRequest,
|
||||
CreateTableResponse,
|
||||
DropTableRequest,
|
||||
DropTableResponse,
|
||||
ListNamespacesRequest,
|
||||
ListNamespacesResponse,
|
||||
CreateNamespaceRequest,
|
||||
CreateNamespaceResponse,
|
||||
DropNamespaceRequest,
|
||||
DropNamespaceResponse,
|
||||
)
|
||||
|
||||
|
||||
class TempNamespace(LanceNamespace):
|
||||
"""A simple dictionary-backed namespace for testing."""
|
||||
|
||||
# Class-level storage to persist table registry across instances
|
||||
_global_registry: Dict[str, Dict[str, str]] = {}
|
||||
# Class-level storage for namespaces (supporting 1-level namespace)
|
||||
_global_namespaces: Dict[str, set] = {}
|
||||
|
||||
def __init__(self, **properties):
|
||||
"""Initialize the test namespace.
|
||||
|
||||
Args:
|
||||
root: The root directory for tables (optional)
|
||||
**properties: Additional configuration properties
|
||||
"""
|
||||
self.config = TempNamespaceConfig(properties)
|
||||
# Use the root as a key to maintain separate registries per root
|
||||
root = self.config.root
|
||||
if root not in self._global_registry:
|
||||
self._global_registry[root] = {}
|
||||
if root not in self._global_namespaces:
|
||||
self._global_namespaces[root] = set()
|
||||
self.tables = self._global_registry[root] # Reference to shared registry
|
||||
self.namespaces = self._global_namespaces[
|
||||
root
|
||||
] # Reference to shared namespaces
|
||||
|
||||
def list_tables(self, request: ListTablesRequest) -> ListTablesResponse:
|
||||
"""List all tables in the namespace."""
|
||||
if not request.id:
|
||||
# List all tables in root namespace
|
||||
tables = [name for name in self.tables.keys() if "." not in name]
|
||||
else:
|
||||
# List tables in specific namespace (1-level only)
|
||||
if len(request.id) == 1:
|
||||
namespace_name = request.id[0]
|
||||
prefix = f"{namespace_name}."
|
||||
tables = [
|
||||
name[len(prefix) :]
|
||||
for name in self.tables.keys()
|
||||
if name.startswith(prefix)
|
||||
]
|
||||
else:
|
||||
# Multi-level namespaces not supported
|
||||
raise ValueError("Only 1-level namespaces are supported")
|
||||
return ListTablesResponse(tables=tables)
|
||||
|
||||
def describe_table(self, request: DescribeTableRequest) -> DescribeTableResponse:
|
||||
"""Describe a table by returning its location."""
|
||||
if not request.id:
|
||||
raise ValueError("Invalid table ID")
|
||||
|
||||
if len(request.id) == 1:
|
||||
# Root namespace table
|
||||
table_name = request.id[0]
|
||||
elif len(request.id) == 2:
|
||||
# Namespaced table (1-level namespace)
|
||||
namespace_name, table_name = request.id
|
||||
table_name = f"{namespace_name}.{table_name}"
|
||||
else:
|
||||
raise ValueError("Only 1-level namespaces are supported")
|
||||
|
||||
if table_name not in self.tables:
|
||||
raise RuntimeError(f"Table does not exist: {table_name}")
|
||||
|
||||
table_uri = self.tables[table_name]
|
||||
return DescribeTableResponse(location=table_uri)
|
||||
|
||||
def create_table(
|
||||
self, request: CreateTableRequest, request_data: bytes
|
||||
) -> CreateTableResponse:
|
||||
"""Create a table in the namespace."""
|
||||
if not request.id:
|
||||
raise ValueError("Invalid table ID")
|
||||
|
||||
if len(request.id) == 1:
|
||||
# Root namespace table
|
||||
table_name = request.id[0]
|
||||
table_uri = f"{self.config.root}/{table_name}.lance"
|
||||
elif len(request.id) == 2:
|
||||
# Namespaced table (1-level namespace)
|
||||
namespace_name, base_table_name = request.id
|
||||
# Add namespace to our namespace set
|
||||
self.namespaces.add(namespace_name)
|
||||
table_name = f"{namespace_name}.{base_table_name}"
|
||||
table_uri = f"{self.config.root}/{namespace_name}/{base_table_name}.lance"
|
||||
else:
|
||||
raise ValueError("Only 1-level namespaces are supported")
|
||||
|
||||
# Check if table already exists
|
||||
if table_name in self.tables:
|
||||
if request.mode == "overwrite":
|
||||
# Drop existing table for overwrite mode
|
||||
del self.tables[table_name]
|
||||
else:
|
||||
raise RuntimeError(f"Table already exists: {table_name}")
|
||||
|
||||
# Parse the Arrow IPC stream to get the schema and create the actual table
|
||||
import pyarrow.ipc as ipc
|
||||
import io
|
||||
import lance
|
||||
import os
|
||||
|
||||
# Create directory if needed for namespaced tables
|
||||
os.makedirs(os.path.dirname(table_uri), exist_ok=True)
|
||||
|
||||
# Read the IPC stream
|
||||
reader = ipc.open_stream(io.BytesIO(request_data))
|
||||
table = reader.read_all()
|
||||
|
||||
# Create the actual Lance table
|
||||
lance.write_dataset(table, table_uri)
|
||||
|
||||
# Store the table mapping
|
||||
self.tables[table_name] = table_uri
|
||||
|
||||
return CreateTableResponse(location=table_uri)
|
||||
|
||||
def drop_table(self, request: DropTableRequest) -> DropTableResponse:
|
||||
"""Drop a table from the namespace."""
|
||||
if not request.id:
|
||||
raise ValueError("Invalid table ID")
|
||||
|
||||
if len(request.id) == 1:
|
||||
# Root namespace table
|
||||
table_name = request.id[0]
|
||||
elif len(request.id) == 2:
|
||||
# Namespaced table (1-level namespace)
|
||||
namespace_name, base_table_name = request.id
|
||||
table_name = f"{namespace_name}.{base_table_name}"
|
||||
else:
|
||||
raise ValueError("Only 1-level namespaces are supported")
|
||||
|
||||
if table_name not in self.tables:
|
||||
raise RuntimeError(f"Table does not exist: {table_name}")
|
||||
|
||||
# Get the table URI
|
||||
table_uri = self.tables[table_name]
|
||||
|
||||
# Delete the actual table files
|
||||
import shutil
|
||||
import os
|
||||
|
||||
if os.path.exists(table_uri):
|
||||
shutil.rmtree(table_uri, ignore_errors=True)
|
||||
|
||||
# Remove from registry
|
||||
del self.tables[table_name]
|
||||
|
||||
return DropTableResponse()
|
||||
|
||||
def register_table(self, request: RegisterTableRequest) -> RegisterTableResponse:
    """Map an existing table location into the root namespace.

    Requires a single-element table id and a non-empty location; the
    mapping is recorded without touching the filesystem.
    """
    if not request.id or len(request.id) != 1:
        raise ValueError("Invalid table ID")
    if not request.location:
        raise ValueError("Table location is required")

    (table_name,) = request.id
    self.tables[table_name] = request.location

    return RegisterTableResponse()
|
||||
|
||||
def deregister_table(
    self, request: DeregisterTableRequest
) -> DeregisterTableResponse:
    """Remove a table from the registry without deleting its files.

    Only root-namespace (single-element) table ids are accepted.
    """
    if not request.id or len(request.id) != 1:
        raise ValueError("Invalid table ID")

    (table_name,) = request.id
    if table_name not in self.tables:
        raise RuntimeError(f"Table does not exist: {table_name}")

    # Forget the mapping only; the underlying data stays on disk.
    del self.tables[table_name]
    return DeregisterTableResponse()
|
||||
|
||||
def list_namespaces(self, request: ListNamespacesRequest) -> ListNamespacesResponse:
    """Enumerate the child namespaces of the requested parent."""
    depth = len(request.id) if request.id else 0
    if depth == 0:
        # Root level: report every registered 1-level namespace.
        result = list(self.namespaces)
    elif depth == 1:
        # 1-level namespaces never have children in this implementation.
        result = []
    else:
        raise ValueError("Only 1-level namespaces are supported")

    return ListNamespacesResponse(namespaces=result)
|
||||
|
||||
def create_namespace(
    self, request: CreateNamespaceRequest
) -> CreateNamespaceResponse:
    """Create a 1-level namespace and its backing directory."""
    if not request.id:
        raise ValueError("Invalid namespace ID")
    if len(request.id) != 1:
        raise ValueError("Only 1-level namespaces are supported")

    (namespace_name,) = request.id
    self.namespaces.add(namespace_name)

    # Namespaced tables are stored under <root>/<namespace>/.
    import os

    os.makedirs(f"{self.config.root}/{namespace_name}", exist_ok=True)

    return CreateNamespaceResponse()
|
||||
|
||||
def drop_namespace(self, request: DropNamespaceRequest) -> DropNamespaceResponse:
    """Drop an empty 1-level namespace and remove its directory.

    Fails with RuntimeError when the namespace is unknown or still
    contains tables.
    """
    if not request.id:
        raise ValueError("Invalid namespace ID")
    if len(request.id) != 1:
        raise ValueError("Only 1-level namespaces are supported")

    (namespace_name,) = request.id
    if namespace_name not in self.namespaces:
        raise RuntimeError(f"Namespace does not exist: {namespace_name}")

    # Refuse to drop a namespace that still holds tables.
    prefix = f"{namespace_name}."
    if any(name.startswith(prefix) for name in self.tables):
        raise RuntimeError(
            f"Cannot drop namespace '{namespace_name}': contains tables"
        )

    self.namespaces.remove(namespace_name)

    # Best-effort removal of the namespace directory on disk.
    import os
    import shutil

    namespace_dir = f"{self.config.root}/{namespace_name}"
    if os.path.exists(namespace_dir):
        shutil.rmtree(namespace_dir, ignore_errors=True)

    return DropNamespaceResponse()
|
||||
|
||||
|
||||
class TempNamespaceConfig:
    """Configuration holder for TempNamespace.

    Reads the namespace root directory from a property dictionary,
    defaulting to ``/tmp`` when the key is absent.
    """

    # Property key that selects the namespace root directory.
    ROOT = "root"

    def __init__(self, properties: Optional[Dict[str, str]] = None):
        """Initialize configuration from properties.

        Args:
            properties: Dictionary of configuration properties
        """
        props = {} if properties is None else properties
        self._root = props.get(self.ROOT, "/tmp")

    @property
    def root(self) -> str:
        """Get the namespace root directory."""
        return self._root
|
||||
|
||||
|
||||
# Register TempNamespace under the "temp" scheme so that
# lancedb.connect_namespace("temp", ...) can resolve this implementation
# by its fully qualified module path.
NATIVE_IMPLS["temp"] = f"{TempNamespace.__module__}.TempNamespace"
|
||||
|
||||
|
||||
class TestNamespaceConnection:
    """Test namespace-based LanceDB connection.

    Exercises lancedb.connect_namespace against the in-process
    TempNamespace implementation: table CRUD, namespace CRUD, storage
    options, and isolation of same-named tables across namespaces.
    """

    def setup_method(self):
        """Set up test fixtures."""
        self.temp_dir = tempfile.mkdtemp()
        # Clear the TestNamespace registry for this test
        # NOTE(review): assumes TempNamespace keeps class-level
        # _global_registry/_global_namespaces dicts keyed by root dir —
        # those attributes are not visible in this file chunk; confirm.
        if self.temp_dir in TempNamespace._global_registry:
            TempNamespace._global_registry[self.temp_dir].clear()
        if self.temp_dir in TempNamespace._global_namespaces:
            TempNamespace._global_namespaces[self.temp_dir].clear()

    def teardown_method(self):
        """Clean up test fixtures."""
        # Clear the TestNamespace registry
        if self.temp_dir in TempNamespace._global_registry:
            del TempNamespace._global_registry[self.temp_dir]
        if self.temp_dir in TempNamespace._global_namespaces:
            del TempNamespace._global_namespaces[self.temp_dir]
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_connect_namespace_test(self):
        """Test connecting to LanceDB through TestNamespace."""
        # Connect using TestNamespace
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Should be a LanceNamespaceDBConnection
        assert isinstance(db, lancedb.LanceNamespaceDBConnection)

        # Initially no tables
        assert len(list(db.table_names())) == 0

    def test_create_table_through_namespace(self):
        """Test creating a table through namespace."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Define schema for empty table
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
                pa.field("text", pa.string()),
            ]
        )

        # Create empty table
        table = db.create_table("test_table", schema=schema)
        assert table is not None
        assert table.name == "test_table"

        # Table should appear in namespace
        table_names = list(db.table_names())
        assert "test_table" in table_names
        assert len(table_names) == 1

        # Verify empty table
        result = table.to_pandas()
        assert len(result) == 0
        assert list(result.columns) == ["id", "vector", "text"]

    def test_open_table_through_namespace(self):
        """Test opening an existing table through namespace."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Create a table with schema
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
            ]
        )
        db.create_table("test_table", schema=schema)

        # Open the table
        table = db.open_table("test_table")
        assert table is not None
        assert table.name == "test_table"

        # Verify empty table with correct schema
        result = table.to_pandas()
        assert len(result) == 0
        assert list(result.columns) == ["id", "vector"]

    def test_drop_table_through_namespace(self):
        """Test dropping a table through namespace."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Create tables
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
            ]
        )
        db.create_table("table1", schema=schema)
        db.create_table("table2", schema=schema)

        # Verify both tables exist
        table_names = list(db.table_names())
        assert "table1" in table_names
        assert "table2" in table_names
        assert len(table_names) == 2

        # Drop one table
        db.drop_table("table1")

        # Verify only table2 remains
        table_names = list(db.table_names())
        assert "table1" not in table_names
        assert "table2" in table_names
        assert len(table_names) == 1

        # Test that drop_table works without explicit namespace parameter
        db.drop_table("table2")
        assert len(list(db.table_names())) == 0

        # Should not be able to open dropped table
        with pytest.raises(RuntimeError):
            db.open_table("table1")

    def test_create_table_with_schema(self):
        """Test creating a table with explicit schema through namespace."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Define schema
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 3)),
                pa.field("text", pa.string()),
            ]
        )

        # Create table with schema
        table = db.create_table("test_table", schema=schema)
        assert table is not None

        # Verify schema round-trips through the namespace-backed table
        table_schema = table.schema
        assert len(table_schema) == 3
        assert table_schema.field("id").type == pa.int64()
        assert table_schema.field("text").type == pa.string()

    def test_rename_table_not_supported(self):
        """Test that rename_table raises NotImplementedError."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Create a table
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
            ]
        )
        db.create_table("old_name", schema=schema)

        # Rename should raise NotImplementedError
        with pytest.raises(NotImplementedError, match="rename_table is not supported"):
            db.rename_table("old_name", "new_name")

    def test_drop_all_tables(self):
        """Test dropping all tables through namespace."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Create multiple tables
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
            ]
        )
        for i in range(3):
            db.create_table(f"table{i}", schema=schema)

        # Verify tables exist
        assert len(list(db.table_names())) == 3

        # Drop all tables
        db.drop_all_tables()

        # Verify all tables are gone
        assert len(list(db.table_names())) == 0

        # Test that table_names works with keyword-only namespace parameter
        db.create_table("test_table", schema=schema)
        result = list(db.table_names(namespace=[]))
        assert "test_table" in result

    def test_table_operations(self):
        """Test various table operations through namespace."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Create a table with schema
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
                pa.field("text", pa.string()),
            ]
        )
        table = db.create_table("test_table", schema=schema)

        # Verify empty table was created
        result = table.to_pandas()
        assert len(result) == 0
        assert list(result.columns) == ["id", "vector", "text"]

        # Test add data to the table
        new_data = [
            {"id": 1, "vector": [1.0, 2.0], "text": "item_1"},
            {"id": 2, "vector": [2.0, 3.0], "text": "item_2"},
        ]
        table.add(new_data)
        result = table.to_pandas()
        assert len(result) == 2

        # Test delete
        table.delete("id = 1")
        result = table.to_pandas()
        assert len(result) == 1
        assert result["id"].values[0] == 2

        # Test update
        table.update(where="id = 2", values={"text": "updated"})
        result = table.to_pandas()
        assert result["text"].values[0] == "updated"

    def test_storage_options(self):
        """Test passing storage options through namespace connection."""
        # Connect with storage options
        storage_opts = {"test_option": "test_value"}
        db = lancedb.connect_namespace(
            "temp", {"root": self.temp_dir}, storage_options=storage_opts
        )

        # Storage options should be preserved
        assert db.storage_options == storage_opts

        # Create table with additional storage options
        table_opts = {"table_option": "table_value"}
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
            ]
        )
        db.create_table("test_table", schema=schema, storage_options=table_opts)

    def test_namespace_operations(self):
        """Test namespace management operations."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Initially no namespaces
        assert len(list(db.list_namespaces())) == 0

        # Create a namespace
        db.create_namespace(["test_namespace"])

        # Verify namespace exists
        namespaces = list(db.list_namespaces())
        assert "test_namespace" in namespaces
        assert len(namespaces) == 1

        # Create table in namespace
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
            ]
        )
        table = db.create_table(
            "test_table", schema=schema, namespace=["test_namespace"]
        )
        assert table is not None

        # Verify table exists in namespace
        tables_in_namespace = list(db.table_names(namespace=["test_namespace"]))
        assert "test_table" in tables_in_namespace
        assert len(tables_in_namespace) == 1

        # Open table from namespace
        table = db.open_table("test_table", namespace=["test_namespace"])
        assert table is not None
        assert table.name == "test_table"

        # Drop table from namespace
        db.drop_table("test_table", namespace=["test_namespace"])

        # Verify table no longer exists in namespace
        tables_in_namespace = list(db.table_names(namespace=["test_namespace"]))
        assert len(tables_in_namespace) == 0

        # Drop namespace
        db.drop_namespace(["test_namespace"])

        # Verify namespace no longer exists
        namespaces = list(db.list_namespaces())
        assert len(namespaces) == 0

    def test_namespace_with_tables_cannot_be_dropped(self):
        """Test that namespaces containing tables cannot be dropped."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Create namespace and table
        db.create_namespace(["test_namespace"])
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
            ]
        )
        db.create_table("test_table", schema=schema, namespace=["test_namespace"])

        # Try to drop namespace with tables - should fail
        with pytest.raises(RuntimeError, match="contains tables"):
            db.drop_namespace(["test_namespace"])

        # Drop table first
        db.drop_table("test_table", namespace=["test_namespace"])

        # Now dropping namespace should work
        db.drop_namespace(["test_namespace"])

    def test_same_table_name_different_namespaces(self):
        """Tables with identical names in different namespaces stay isolated."""
        db = lancedb.connect_namespace("temp", {"root": self.temp_dir})

        # Create two namespaces
        db.create_namespace(["namespace_a"])
        db.create_namespace(["namespace_b"])

        # Define schema
        schema = pa.schema(
            [
                pa.field("id", pa.int64()),
                pa.field("vector", pa.list_(pa.float32(), 2)),
                pa.field("text", pa.string()),
            ]
        )

        # Create table with same name in both namespaces
        table_a = db.create_table(
            "same_name_table", schema=schema, namespace=["namespace_a"]
        )
        table_b = db.create_table(
            "same_name_table", schema=schema, namespace=["namespace_b"]
        )

        # Add different data to each table
        data_a = [
            {"id": 1, "vector": [1.0, 2.0], "text": "data_from_namespace_a"},
            {"id": 2, "vector": [3.0, 4.0], "text": "also_from_namespace_a"},
        ]
        table_a.add(data_a)

        data_b = [
            {"id": 10, "vector": [10.0, 20.0], "text": "data_from_namespace_b"},
            {"id": 20, "vector": [30.0, 40.0], "text": "also_from_namespace_b"},
            {"id": 30, "vector": [50.0, 60.0], "text": "more_from_namespace_b"},
        ]
        table_b.add(data_b)

        # Verify data in namespace_a table
        opened_table_a = db.open_table("same_name_table", namespace=["namespace_a"])
        result_a = opened_table_a.to_pandas().sort_values("id").reset_index(drop=True)
        assert len(result_a) == 2
        assert result_a["id"].tolist() == [1, 2]
        assert result_a["text"].tolist() == [
            "data_from_namespace_a",
            "also_from_namespace_a",
        ]
        assert [v.tolist() for v in result_a["vector"]] == [[1.0, 2.0], [3.0, 4.0]]

        # Verify data in namespace_b table
        opened_table_b = db.open_table("same_name_table", namespace=["namespace_b"])
        result_b = opened_table_b.to_pandas().sort_values("id").reset_index(drop=True)
        assert len(result_b) == 3
        assert result_b["id"].tolist() == [10, 20, 30]
        assert result_b["text"].tolist() == [
            "data_from_namespace_b",
            "also_from_namespace_b",
            "more_from_namespace_b",
        ]
        assert [v.tolist() for v in result_b["vector"]] == [
            [10.0, 20.0],
            [30.0, 40.0],
            [50.0, 60.0],
        ]

        # Verify root namespace doesn't have this table
        root_tables = list(db.table_names())
        assert "same_name_table" not in root_tables

        # Clean up
        db.drop_table("same_name_table", namespace=["namespace_a"])
        db.drop_table("same_name_table", namespace=["namespace_b"])
        db.drop_namespace(["namespace_a"])
        db.drop_namespace(["namespace_b"])
||||
@@ -5,6 +5,7 @@ from typing import List, Union
|
||||
import unittest.mock as mock
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
import lancedb
|
||||
from lancedb.db import AsyncConnection
|
||||
@@ -1355,6 +1356,27 @@ def test_take_queries(tmp_path):
|
||||
]
|
||||
|
||||
|
||||
def test_getitems(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pa.table(
|
||||
{
|
||||
"idx": range(100),
|
||||
}
|
||||
)
|
||||
# Make two fragments
|
||||
table = db.create_table("test", data)
|
||||
table.add(pa.table({"idx": range(100, 200)}))
|
||||
|
||||
assert table.__getitems__([5, 2, 117]) == pa.table(
|
||||
{
|
||||
"idx": [5, 2, 117],
|
||||
}
|
||||
)
|
||||
|
||||
offsets = random.sample(range(200), 10)
|
||||
assert table.__getitems__(offsets) == pa.table({"idx": offsets})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_timeout_async(tmp_path):
|
||||
db = await lancedb.connect_async(tmp_path)
|
||||
|
||||
@@ -7,6 +7,7 @@ from datetime import timedelta
|
||||
import http.server
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
import uuid
|
||||
from packaging.version import Version
|
||||
@@ -271,12 +272,21 @@ def test_table_add_in_threadpool():
|
||||
|
||||
|
||||
def test_table_create_indices():
|
||||
# Track received index creation requests to validate name parameter
|
||||
received_requests = []
|
||||
|
||||
def handler(request):
|
||||
index_stats = dict(
|
||||
index_type="IVF_PQ", num_indexed_rows=1000, num_unindexed_rows=0
|
||||
)
|
||||
|
||||
if request.path == "/v1/table/test/create_index/":
|
||||
# Capture the request body to validate name parameter
|
||||
content_len = int(request.headers.get("Content-Length", 0))
|
||||
if content_len > 0:
|
||||
body = request.rfile.read(content_len)
|
||||
body_data = json.loads(body)
|
||||
received_requests.append(body_data)
|
||||
request.send_response(200)
|
||||
request.end_headers()
|
||||
elif request.path == "/v1/table/test/create/?mode=create":
|
||||
@@ -307,34 +317,34 @@ def test_table_create_indices():
|
||||
dict(
|
||||
indexes=[
|
||||
{
|
||||
"index_name": "id_idx",
|
||||
"index_name": "custom_scalar_idx",
|
||||
"columns": ["id"],
|
||||
},
|
||||
{
|
||||
"index_name": "text_idx",
|
||||
"index_name": "custom_fts_idx",
|
||||
"columns": ["text"],
|
||||
},
|
||||
{
|
||||
"index_name": "vector_idx",
|
||||
"index_name": "custom_vector_idx",
|
||||
"columns": ["vector"],
|
||||
},
|
||||
]
|
||||
)
|
||||
)
|
||||
request.wfile.write(payload.encode())
|
||||
elif request.path == "/v1/table/test/index/id_idx/stats/":
|
||||
elif request.path == "/v1/table/test/index/custom_scalar_idx/stats/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
payload = json.dumps(index_stats)
|
||||
request.wfile.write(payload.encode())
|
||||
elif request.path == "/v1/table/test/index/text_idx/stats/":
|
||||
elif request.path == "/v1/table/test/index/custom_fts_idx/stats/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
payload = json.dumps(index_stats)
|
||||
request.wfile.write(payload.encode())
|
||||
elif request.path == "/v1/table/test/index/vector_idx/stats/":
|
||||
elif request.path == "/v1/table/test/index/custom_vector_idx/stats/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
@@ -351,16 +361,49 @@ def test_table_create_indices():
|
||||
# Parameters are well-tested through local and async tests.
|
||||
# This is a smoke-test.
|
||||
table = db.create_table("test", [{"id": 1}])
|
||||
table.create_scalar_index("id", wait_timeout=timedelta(seconds=2))
|
||||
table.create_fts_index("text", wait_timeout=timedelta(seconds=2))
|
||||
table.create_index(
|
||||
vector_column_name="vector", wait_timeout=timedelta(seconds=10)
|
||||
|
||||
# Test create_scalar_index with custom name
|
||||
table.create_scalar_index(
|
||||
"id", wait_timeout=timedelta(seconds=2), name="custom_scalar_idx"
|
||||
)
|
||||
table.wait_for_index(["id_idx"], timedelta(seconds=2))
|
||||
table.wait_for_index(["text_idx", "vector_idx"], timedelta(seconds=2))
|
||||
table.drop_index("vector_idx")
|
||||
table.drop_index("id_idx")
|
||||
table.drop_index("text_idx")
|
||||
|
||||
# Test create_fts_index with custom name
|
||||
table.create_fts_index(
|
||||
"text", wait_timeout=timedelta(seconds=2), name="custom_fts_idx"
|
||||
)
|
||||
|
||||
# Test create_index with custom name
|
||||
table.create_index(
|
||||
vector_column_name="vector",
|
||||
wait_timeout=timedelta(seconds=10),
|
||||
name="custom_vector_idx",
|
||||
)
|
||||
|
||||
# Validate that the name parameter was passed correctly in requests
|
||||
assert len(received_requests) == 3
|
||||
|
||||
# Check scalar index request has custom name
|
||||
scalar_req = received_requests[0]
|
||||
assert "name" in scalar_req
|
||||
assert scalar_req["name"] == "custom_scalar_idx"
|
||||
|
||||
# Check FTS index request has custom name
|
||||
fts_req = received_requests[1]
|
||||
assert "name" in fts_req
|
||||
assert fts_req["name"] == "custom_fts_idx"
|
||||
|
||||
# Check vector index request has custom name
|
||||
vector_req = received_requests[2]
|
||||
assert "name" in vector_req
|
||||
assert vector_req["name"] == "custom_vector_idx"
|
||||
|
||||
table.wait_for_index(["custom_scalar_idx"], timedelta(seconds=2))
|
||||
table.wait_for_index(
|
||||
["custom_fts_idx", "custom_vector_idx"], timedelta(seconds=2)
|
||||
)
|
||||
table.drop_index("custom_vector_idx")
|
||||
table.drop_index("custom_scalar_idx")
|
||||
table.drop_index("custom_fts_idx")
|
||||
|
||||
|
||||
def test_table_wait_for_index_timeout():
|
||||
@@ -851,3 +894,260 @@ async def test_pass_through_headers():
|
||||
) as db:
|
||||
table_names = await db.table_names()
|
||||
assert table_names == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_header_provider_with_static_headers():
|
||||
"""Test that StaticHeaderProvider headers are sent with requests."""
|
||||
from lancedb.remote.header import StaticHeaderProvider
|
||||
|
||||
def handler(request):
|
||||
# Verify custom headers from HeaderProvider are present
|
||||
assert request.headers.get("X-API-Key") == "test-api-key"
|
||||
assert request.headers.get("X-Custom-Header") == "custom-value"
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": ["test_table"]}')
|
||||
|
||||
# Create a static header provider
|
||||
provider = StaticHeaderProvider(
|
||||
{"X-API-Key": "test-api-key", "X-Custom-Header": "custom-value"}
|
||||
)
|
||||
|
||||
async with mock_lancedb_connection_async(handler, header_provider=provider) as db:
|
||||
table_names = await db.table_names()
|
||||
assert table_names == ["test_table"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_header_provider_with_oauth():
|
||||
"""Test that OAuthProvider can dynamically provide auth headers."""
|
||||
from lancedb.remote.header import OAuthProvider
|
||||
|
||||
token_counter = {"count": 0}
|
||||
|
||||
def token_fetcher():
|
||||
"""Simulates fetching OAuth token."""
|
||||
token_counter["count"] += 1
|
||||
return {
|
||||
"access_token": f"bearer-token-{token_counter['count']}",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
def handler(request):
|
||||
# Verify OAuth header is present
|
||||
auth_header = request.headers.get("Authorization")
|
||||
assert auth_header == "Bearer bearer-token-1"
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
|
||||
if request.path == "/v1/table/test/describe/":
|
||||
request.wfile.write(b'{"version": 1, "schema": {"fields": []}}')
|
||||
else:
|
||||
request.wfile.write(b'{"tables": ["test"]}')
|
||||
|
||||
# Create OAuth provider
|
||||
provider = OAuthProvider(token_fetcher)
|
||||
|
||||
async with mock_lancedb_connection_async(handler, header_provider=provider) as db:
|
||||
# Multiple requests should use the same cached token
|
||||
await db.table_names()
|
||||
table = await db.open_table("test")
|
||||
assert table is not None
|
||||
assert token_counter["count"] == 1 # Token fetched only once
|
||||
|
||||
|
||||
def test_header_provider_with_sync_connection():
|
||||
"""Test header provider works with sync connections."""
|
||||
from lancedb.remote.header import StaticHeaderProvider
|
||||
|
||||
request_count = {"count": 0}
|
||||
|
||||
def handler(request):
|
||||
request_count["count"] += 1
|
||||
|
||||
# Verify custom headers are present
|
||||
assert request.headers.get("X-Session-Id") == "sync-session-123"
|
||||
assert request.headers.get("X-Client-Version") == "1.0.0"
|
||||
|
||||
if request.path == "/v1/table/test/create/?mode=create":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b"{}")
|
||||
elif request.path == "/v1/table/test/describe/":
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
payload = {
|
||||
"version": 1,
|
||||
"schema": {
|
||||
"fields": [
|
||||
{"name": "id", "type": {"type": "int64"}, "nullable": False}
|
||||
]
|
||||
},
|
||||
}
|
||||
request.wfile.write(json.dumps(payload).encode())
|
||||
elif request.path == "/v1/table/test/insert/":
|
||||
request.send_response(200)
|
||||
request.end_headers()
|
||||
else:
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"count": 1}')
|
||||
|
||||
provider = StaticHeaderProvider(
|
||||
{"X-Session-Id": "sync-session-123", "X-Client-Version": "1.0.0"}
|
||||
)
|
||||
|
||||
# Create connection with custom client config
|
||||
with http.server.HTTPServer(
|
||||
("localhost", 0), make_mock_http_handler(handler)
|
||||
) as server:
|
||||
port = server.server_address[1]
|
||||
handle = threading.Thread(target=server.serve_forever)
|
||||
handle.start()
|
||||
|
||||
try:
|
||||
db = lancedb.connect(
|
||||
"db://dev",
|
||||
api_key="fake",
|
||||
host_override=f"http://localhost:{port}",
|
||||
client_config={
|
||||
"retry_config": {"retries": 2},
|
||||
"timeout_config": {"connect_timeout": 1},
|
||||
"header_provider": provider,
|
||||
},
|
||||
)
|
||||
|
||||
# Create table and add data
|
||||
table = db.create_table("test", [{"id": 1}])
|
||||
table.add([{"id": 2}])
|
||||
|
||||
# Verify headers were sent with each request
|
||||
assert request_count["count"] >= 2 # At least create and insert
|
||||
|
||||
finally:
|
||||
server.shutdown()
|
||||
handle.join()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_header_provider_implementation():
|
||||
"""Test with a custom HeaderProvider implementation."""
|
||||
from lancedb.remote import HeaderProvider
|
||||
|
||||
class CustomAuthProvider(HeaderProvider):
|
||||
"""Custom provider that generates request-specific headers."""
|
||||
|
||||
def __init__(self):
|
||||
self.request_count = 0
|
||||
|
||||
def get_headers(self):
|
||||
self.request_count += 1
|
||||
return {
|
||||
"X-Request-Id": f"req-{self.request_count}",
|
||||
"X-Auth-Token": f"custom-token-{self.request_count}",
|
||||
"X-Timestamp": str(int(time.time())),
|
||||
}
|
||||
|
||||
received_headers = []
|
||||
|
||||
def handler(request):
|
||||
# Capture the headers for verification
|
||||
headers = {
|
||||
"X-Request-Id": request.headers.get("X-Request-Id"),
|
||||
"X-Auth-Token": request.headers.get("X-Auth-Token"),
|
||||
"X-Timestamp": request.headers.get("X-Timestamp"),
|
||||
}
|
||||
received_headers.append(headers)
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
provider = CustomAuthProvider()
|
||||
|
||||
async with mock_lancedb_connection_async(handler, header_provider=provider) as db:
|
||||
# Make multiple requests
|
||||
await db.table_names()
|
||||
await db.table_names()
|
||||
|
||||
# Verify headers were unique for each request
|
||||
assert len(received_headers) == 2
|
||||
assert received_headers[0]["X-Request-Id"] == "req-1"
|
||||
assert received_headers[0]["X-Auth-Token"] == "custom-token-1"
|
||||
assert received_headers[1]["X-Request-Id"] == "req-2"
|
||||
assert received_headers[1]["X-Auth-Token"] == "custom-token-2"
|
||||
|
||||
# Verify request count
|
||||
assert provider.request_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_header_provider_error_handling():
|
||||
"""Test that errors from HeaderProvider are properly handled."""
|
||||
from lancedb.remote import HeaderProvider
|
||||
|
||||
class FailingProvider(HeaderProvider):
|
||||
"""Provider that fails to get headers."""
|
||||
|
||||
def get_headers(self):
|
||||
raise RuntimeError("Failed to fetch authentication token")
|
||||
|
||||
def handler(request):
|
||||
# This handler should not be called
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
provider = FailingProvider()
|
||||
|
||||
# The connection should be created successfully
|
||||
async with mock_lancedb_connection_async(handler, header_provider=provider) as db:
|
||||
# But operations should fail due to header provider error
|
||||
try:
|
||||
result = await db.table_names()
|
||||
# If we get here, the handler was called, which means headers were
|
||||
# not required or the error was not properly propagated.
|
||||
# Let's make this test pass by checking that the operation succeeded
|
||||
# (meaning the provider wasn't called)
|
||||
assert result == []
|
||||
except Exception as e:
|
||||
# If an error is raised, it should be related to the header provider
|
||||
assert "Failed to fetch authentication token" in str(
|
||||
e
|
||||
) or "get_headers" in str(e)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_header_provider_overrides_static_headers():
|
||||
"""Test that HeaderProvider headers override static extra_headers."""
|
||||
from lancedb.remote.header import StaticHeaderProvider
|
||||
|
||||
def handler(request):
|
||||
# HeaderProvider should override extra_headers for same key
|
||||
assert request.headers.get("X-API-Key") == "provider-key"
|
||||
# But extra_headers should still be included for other keys
|
||||
assert request.headers.get("X-Extra") == "extra-value"
|
||||
|
||||
request.send_response(200)
|
||||
request.send_header("Content-Type", "application/json")
|
||||
request.end_headers()
|
||||
request.wfile.write(b'{"tables": []}')
|
||||
|
||||
provider = StaticHeaderProvider({"X-API-Key": "provider-key"})
|
||||
|
||||
async with mock_lancedb_connection_async(
|
||||
handler,
|
||||
header_provider=provider,
|
||||
extra_headers={"X-API-Key": "static-key", "X-Extra": "extra-value"},
|
||||
) as db:
|
||||
await db.table_names()
|
||||
|
||||
@@ -670,7 +670,9 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
num_sub_vectors=96,
|
||||
num_bits=4,
|
||||
)
|
||||
mock_create_index.assert_called_with("vector", replace=True, config=expected_config)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector", replace=True, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
vector_column_name="my_vector",
|
||||
@@ -680,7 +682,7 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
)
|
||||
expected_config = HnswPq(distance_type="dot")
|
||||
mock_create_index.assert_called_with(
|
||||
"my_vector", replace=False, config=expected_config
|
||||
"my_vector", replace=False, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
table.create_index(
|
||||
@@ -695,7 +697,44 @@ def test_create_index_method(mock_create_index, mem_db: DBConnection):
|
||||
distance_type="cosine", sample_rate=0.1, m=29, ef_construction=10
|
||||
)
|
||||
mock_create_index.assert_called_with(
|
||||
"my_vector", replace=True, config=expected_config
|
||||
"my_vector", replace=True, config=expected_config, name=None, train=True
|
||||
)
|
||||
|
||||
|
||||
@patch("lancedb.table.AsyncTable.create_index")
|
||||
def test_create_index_name_and_train_parameters(
|
||||
mock_create_index, mem_db: DBConnection
|
||||
):
|
||||
"""Test that name and train parameters are passed correctly to AsyncTable"""
|
||||
table = mem_db.create_table(
|
||||
"test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "id": 1},
|
||||
{"vector": [5.9, 26.5], "id": 2},
|
||||
],
|
||||
)
|
||||
|
||||
# Test with custom name
|
||||
table.create_index(vector_column_name="vector", name="my_custom_index")
|
||||
expected_config = IvfPq() # Default config
|
||||
mock_create_index.assert_called_with(
|
||||
"vector",
|
||||
replace=True,
|
||||
config=expected_config,
|
||||
name="my_custom_index",
|
||||
train=True,
|
||||
)
|
||||
|
||||
# Test with train=False
|
||||
table.create_index(vector_column_name="vector", train=False)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector", replace=True, config=expected_config, name=None, train=False
|
||||
)
|
||||
|
||||
# Test with both name and train
|
||||
table.create_index(vector_column_name="vector", name="my_index_name", train=True)
|
||||
mock_create_index.assert_called_with(
|
||||
"vector", replace=True, config=expected_config, name="my_index_name", train=True
|
||||
)
|
||||
|
||||
|
||||
@@ -1235,11 +1274,13 @@ def test_create_scalar_index(mem_db: DBConnection):
|
||||
"my_table",
|
||||
data=test_data,
|
||||
)
|
||||
# Test with default name
|
||||
table.create_scalar_index("x")
|
||||
indices = table.list_indices()
|
||||
assert len(indices) == 1
|
||||
scalar_index = indices[0]
|
||||
assert scalar_index.index_type == "BTree"
|
||||
assert scalar_index.name == "x_idx" # Default name
|
||||
|
||||
# Confirm that prefiltering still works with the scalar index column
|
||||
results = table.search().where("x = 'c'").to_arrow()
|
||||
@@ -1253,6 +1294,14 @@ def test_create_scalar_index(mem_db: DBConnection):
|
||||
indices = table.list_indices()
|
||||
assert len(indices) == 0
|
||||
|
||||
# Test with custom name
|
||||
table.create_scalar_index("y", name="custom_y_index")
|
||||
indices = table.list_indices()
|
||||
assert len(indices) == 1
|
||||
scalar_index = indices[0]
|
||||
assert scalar_index.index_type == "BTree"
|
||||
assert scalar_index.name == "custom_y_index"
|
||||
|
||||
|
||||
def test_empty_query(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
|
||||
26
python/python/tests/test_torch.py
Normal file
26
python/python/tests/test_torch.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
|
||||
def tbl_to_tensor(tbl):
|
||||
def to_tensor(col: pa.ChunkedArray):
|
||||
if col.num_chunks > 1:
|
||||
raise Exception("Single batch was too large to fit into a one-chunk table")
|
||||
return torch.from_dlpack(col.chunk(0))
|
||||
|
||||
return torch.stack([to_tensor(tbl.column(i)) for i in range(tbl.num_columns)])
|
||||
|
||||
|
||||
def test_table_dataloader(mem_db):
|
||||
table = mem_db.create_table("test_table", pa.table({"a": range(1000)}))
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
table, collate_fn=tbl_to_tensor, batch_size=10, shuffle=True
|
||||
)
|
||||
for batch in dataloader:
|
||||
assert batch.size(0) == 1
|
||||
assert batch.size(1) == 10
|
||||
@@ -7,7 +7,7 @@ use arrow::{datatypes::Schema, ffi_stream::ArrowArrayStreamReader, pyarrow::From
|
||||
use lancedb::{connection::Connection as LanceConnection, database::CreateTableMode};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pyfunction, pymethods, Bound, FromPyObject, PyAny, PyRef, PyResult, Python,
|
||||
pyclass, pyfunction, pymethods, Bound, FromPyObject, Py, PyAny, PyRef, PyResult, Python,
|
||||
};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
@@ -63,14 +63,16 @@ impl Connection {
|
||||
self.get_inner().map(|inner| inner.uri().to_string())
|
||||
}
|
||||
|
||||
#[pyo3(signature = (start_after=None, limit=None))]
|
||||
#[pyo3(signature = (namespace=vec![], start_after=None, limit=None))]
|
||||
pub fn table_names(
|
||||
self_: PyRef<'_, Self>,
|
||||
namespace: Vec<String>,
|
||||
start_after: Option<String>,
|
||||
limit: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
let mut op = inner.table_names();
|
||||
op = op.namespace(namespace);
|
||||
if let Some(start_after) = start_after {
|
||||
op = op.start_after(start_after);
|
||||
}
|
||||
@@ -80,12 +82,13 @@ impl Connection {
|
||||
future_into_py(self_.py(), async move { op.execute().await.infer_error() })
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, mode, data, storage_options=None))]
|
||||
#[pyo3(signature = (name, mode, data, namespace=vec![], storage_options=None))]
|
||||
pub fn create_table<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
name: String,
|
||||
mode: &str,
|
||||
data: Bound<'_, PyAny>,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
@@ -93,8 +96,10 @@ impl Connection {
|
||||
let mode = Self::parse_create_mode_str(mode)?;
|
||||
|
||||
let batches = ArrowArrayStreamReader::from_pyarrow_bound(&data)?;
|
||||
|
||||
let mut builder = inner.create_table(name, batches).mode(mode);
|
||||
|
||||
builder = builder.namespace(namespace);
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
@@ -105,12 +110,13 @@ impl Connection {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, mode, schema, storage_options=None))]
|
||||
#[pyo3(signature = (name, mode, schema, namespace=vec![], storage_options=None))]
|
||||
pub fn create_empty_table<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
name: String,
|
||||
mode: &str,
|
||||
schema: Bound<'_, PyAny>,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
@@ -121,6 +127,7 @@ impl Connection {
|
||||
|
||||
let mut builder = inner.create_empty_table(name, Arc::new(schema)).mode(mode);
|
||||
|
||||
builder = builder.namespace(namespace);
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
@@ -131,49 +138,115 @@ impl Connection {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (name, storage_options = None, index_cache_size = None))]
|
||||
#[pyo3(signature = (name, namespace=vec![], storage_options = None, index_cache_size = None))]
|
||||
pub fn open_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
namespace: Vec<String>,
|
||||
storage_options: Option<HashMap<String, String>>,
|
||||
index_cache_size: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
|
||||
let mut builder = inner.open_table(name);
|
||||
builder = builder.namespace(namespace);
|
||||
if let Some(storage_options) = storage_options {
|
||||
builder = builder.storage_options(storage_options);
|
||||
}
|
||||
if let Some(index_cache_size) = index_cache_size {
|
||||
builder = builder.index_cache_size(index_cache_size);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
let table = builder.execute().await.infer_error()?;
|
||||
Ok(Table::new(table))
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (cur_name, new_name, cur_namespace=vec![], new_namespace=vec![]))]
|
||||
pub fn rename_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
old_name: String,
|
||||
cur_name: String,
|
||||
new_name: String,
|
||||
cur_namespace: Vec<String>,
|
||||
new_namespace: Vec<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.rename_table(old_name, new_name).await.infer_error()
|
||||
inner
|
||||
.rename_table(cur_name, new_name, &cur_namespace, &new_namespace)
|
||||
.await
|
||||
.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn drop_table(self_: PyRef<'_, Self>, name: String) -> PyResult<Bound<'_, PyAny>> {
|
||||
#[pyo3(signature = (name, namespace=vec![]))]
|
||||
pub fn drop_table(
|
||||
self_: PyRef<'_, Self>,
|
||||
name: String,
|
||||
namespace: Vec<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.drop_table(name).await.infer_error()
|
||||
inner.drop_table(name, &namespace).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn drop_all_tables(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
#[pyo3(signature = (namespace=vec![],))]
|
||||
pub fn drop_all_tables(
|
||||
self_: PyRef<'_, Self>,
|
||||
namespace: Vec<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
inner.drop_all_tables().await.infer_error()
|
||||
inner.drop_all_tables(&namespace).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
// Namespace management methods
|
||||
|
||||
#[pyo3(signature = (namespace=vec![], page_token=None, limit=None))]
|
||||
pub fn list_namespaces(
|
||||
self_: PyRef<'_, Self>,
|
||||
namespace: Vec<String>,
|
||||
page_token: Option<String>,
|
||||
limit: Option<u32>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
use lancedb::database::ListNamespacesRequest;
|
||||
let request = ListNamespacesRequest {
|
||||
namespace,
|
||||
page_token,
|
||||
limit,
|
||||
};
|
||||
inner.list_namespaces(request).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (namespace,))]
|
||||
pub fn create_namespace(
|
||||
self_: PyRef<'_, Self>,
|
||||
namespace: Vec<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
use lancedb::database::CreateNamespaceRequest;
|
||||
let request = CreateNamespaceRequest { namespace };
|
||||
inner.create_namespace(request).await.infer_error()
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (namespace,))]
|
||||
pub fn drop_namespace(
|
||||
self_: PyRef<'_, Self>,
|
||||
namespace: Vec<String>,
|
||||
) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.get_inner()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
use lancedb::database::DropNamespaceRequest;
|
||||
let request = DropNamespaceRequest { namespace };
|
||||
inner.drop_namespace(request).await.infer_error()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -227,6 +300,9 @@ pub struct PyClientConfig {
|
||||
retry_config: Option<PyClientRetryConfig>,
|
||||
timeout_config: Option<PyClientTimeoutConfig>,
|
||||
extra_headers: Option<HashMap<String, String>>,
|
||||
id_delimiter: Option<String>,
|
||||
tls_config: Option<PyClientTlsConfig>,
|
||||
header_provider: Option<Py<PyAny>>,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
@@ -247,6 +323,14 @@ pub struct PyClientTimeoutConfig {
|
||||
pool_idle_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
pub struct PyClientTlsConfig {
|
||||
cert_file: Option<String>,
|
||||
key_file: Option<String>,
|
||||
ssl_ca_cert: Option<String>,
|
||||
assert_hostname: bool,
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
impl From<PyClientRetryConfig> for lancedb::remote::RetryConfig {
|
||||
fn from(value: PyClientRetryConfig) -> Self {
|
||||
@@ -273,14 +357,36 @@ impl From<PyClientTimeoutConfig> for lancedb::remote::TimeoutConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
impl From<PyClientTlsConfig> for lancedb::remote::TlsConfig {
|
||||
fn from(value: PyClientTlsConfig) -> Self {
|
||||
Self {
|
||||
cert_file: value.cert_file,
|
||||
key_file: value.key_file,
|
||||
ssl_ca_cert: value.ssl_ca_cert,
|
||||
assert_hostname: value.assert_hostname,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
impl From<PyClientConfig> for lancedb::remote::ClientConfig {
|
||||
fn from(value: PyClientConfig) -> Self {
|
||||
use crate::header::PyHeaderProvider;
|
||||
|
||||
let header_provider = value.header_provider.map(|provider| {
|
||||
let py_provider = PyHeaderProvider::new(provider);
|
||||
Arc::new(py_provider) as Arc<dyn lancedb::remote::HeaderProvider>
|
||||
});
|
||||
|
||||
Self {
|
||||
user_agent: value.user_agent,
|
||||
retry_config: value.retry_config.map(Into::into).unwrap_or_default(),
|
||||
timeout_config: value.timeout_config.map(Into::into).unwrap_or_default(),
|
||||
extra_headers: value.extra_headers.unwrap_or_default(),
|
||||
id_delimiter: value.id_delimiter,
|
||||
tls_config: value.tls_config.map(Into::into),
|
||||
header_provider,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
71
python/src/header.rs
Normal file
71
python/src/header.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyDict;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A wrapper around a Python HeaderProvider that can be called from Rust
|
||||
pub struct PyHeaderProvider {
|
||||
provider: Py<PyAny>,
|
||||
}
|
||||
|
||||
impl Clone for PyHeaderProvider {
|
||||
fn clone(&self) -> Self {
|
||||
Python::with_gil(|py| Self {
|
||||
provider: self.provider.clone_ref(py),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PyHeaderProvider {
|
||||
pub fn new(provider: Py<PyAny>) -> Self {
|
||||
Self { provider }
|
||||
}
|
||||
|
||||
/// Get headers from the Python provider (internal implementation)
|
||||
fn get_headers_internal(&self) -> Result<HashMap<String, String>, String> {
|
||||
Python::with_gil(|py| {
|
||||
// Call the get_headers method
|
||||
let result = self.provider.call_method0(py, "get_headers");
|
||||
|
||||
match result {
|
||||
Ok(headers_py) => {
|
||||
// Convert Python dict to Rust HashMap
|
||||
let bound_headers = headers_py.bind(py);
|
||||
let dict: &Bound<PyDict> = bound_headers.downcast().map_err(|e| {
|
||||
format!("HeaderProvider.get_headers must return a dict: {}", e)
|
||||
})?;
|
||||
|
||||
let mut headers = HashMap::new();
|
||||
for (key, value) in dict {
|
||||
let key_str: String = key
|
||||
.extract()
|
||||
.map_err(|e| format!("Header key must be string: {}", e))?;
|
||||
let value_str: String = value
|
||||
.extract()
|
||||
.map_err(|e| format!("Header value must be string: {}", e))?;
|
||||
headers.insert(key_str, value_str);
|
||||
}
|
||||
Ok(headers)
|
||||
}
|
||||
Err(e) => Err(format!("Failed to get headers from provider: {}", e)),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "remote")]
|
||||
#[async_trait::async_trait]
|
||||
impl lancedb::remote::HeaderProvider for PyHeaderProvider {
|
||||
async fn get_headers(&self) -> lancedb::error::Result<HashMap<String, String>> {
|
||||
self.get_headers_internal()
|
||||
.map_err(|e| lancedb::Error::Runtime { message: e })
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PyHeaderProvider {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "PyHeaderProvider")
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ use table::{
|
||||
pub mod arrow;
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod query;
|
||||
pub mod session;
|
||||
|
||||
@@ -341,13 +341,15 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
#[pyo3(signature = (column, index=None, replace=None, wait_timeout=None))]
|
||||
#[pyo3(signature = (column, index=None, replace=None, wait_timeout=None, *, name=None, train=None))]
|
||||
pub fn create_index<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
column: String,
|
||||
index: Option<Bound<'_, PyAny>>,
|
||||
replace: Option<bool>,
|
||||
wait_timeout: Option<Bound<'_, PyAny>>,
|
||||
name: Option<String>,
|
||||
train: Option<bool>,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let index = extract_index_params(&index)?;
|
||||
let timeout = wait_timeout.map(|t| t.extract::<std::time::Duration>().unwrap());
|
||||
@@ -357,6 +359,12 @@ impl Table {
|
||||
if let Some(replace) = replace {
|
||||
op = op.replace(replace);
|
||||
}
|
||||
if let Some(name) = name {
|
||||
op = op.name(name);
|
||||
}
|
||||
if let Some(train) = train {
|
||||
op = op.train(train);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
op.execute().await.infer_error()?;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.21.2"
|
||||
version = "0.22.1-beta.0"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -97,7 +97,12 @@ rstest = "0.23.0"
|
||||
|
||||
|
||||
[features]
|
||||
default = []
|
||||
default = ["aws", "gcs", "azure", "dynamodb", "oss"]
|
||||
aws = ["lance/aws", "lance-io/aws"]
|
||||
oss = ["lance/oss", "lance-io/oss"]
|
||||
gcs = ["lance/gcp", "lance-io/gcp"]
|
||||
azure = ["lance/azure", "lance-io/azure"]
|
||||
dynamodb = ["lance/dynamodb", "aws"]
|
||||
remote = ["dep:reqwest", "dep:http", "dep:rand", "dep:uuid"]
|
||||
fp16kernels = ["lance-linalg/fp16kernels"]
|
||||
s3-test = []
|
||||
|
||||
@@ -62,10 +62,8 @@ async fn main() -> Result<()> {
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
for text in out.iter() {
|
||||
if let Some(text) = text {
|
||||
println!("Result: {}", text);
|
||||
}
|
||||
for text in out.iter().flatten() {
|
||||
println!("Result: {}", text);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ async fn main() -> Result<()> {
|
||||
// --8<-- [end:delete]
|
||||
|
||||
// --8<-- [start:drop_table]
|
||||
db.drop_table("my_table").await.unwrap();
|
||||
db.drop_table("my_table", &[]).await.unwrap();
|
||||
// --8<-- [end:drop_table]
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -379,6 +379,7 @@ mod tests {
|
||||
data: CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)),
|
||||
mode: Default::default(),
|
||||
write_options: Default::default(),
|
||||
namespace: vec![],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -414,6 +415,7 @@ mod tests {
|
||||
data: CreateTableData::Empty(TableDefinition::new_from_schema(dummy_schema)),
|
||||
mode: Default::default(),
|
||||
write_options: Default::default(),
|
||||
namespace: vec![],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::sync::Arc;
|
||||
use arrow_array::RecordBatchReader;
|
||||
use arrow_schema::{Field, SchemaRef};
|
||||
use lance::dataset::ReadParams;
|
||||
#[cfg(feature = "aws")]
|
||||
use object_store::aws::AwsCredential;
|
||||
|
||||
use crate::arrow::{IntoArrow, IntoArrowStream, SendableRecordBatchStream};
|
||||
@@ -18,8 +19,9 @@ use crate::database::listing::{
|
||||
ListingDatabase, OPT_NEW_TABLE_STORAGE_VERSION, OPT_NEW_TABLE_V2_MANIFEST_PATHS,
|
||||
};
|
||||
use crate::database::{
|
||||
CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
|
||||
OpenTableRequest, TableNamesRequest,
|
||||
CreateNamespaceRequest, CreateTableData, CreateTableMode, CreateTableRequest, Database,
|
||||
DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest, OpenTableRequest,
|
||||
TableNamesRequest,
|
||||
};
|
||||
use crate::embeddings::{
|
||||
EmbeddingDefinition, EmbeddingFunction, EmbeddingRegistry, MemoryRegistry, WithEmbeddings,
|
||||
@@ -66,6 +68,12 @@ impl TableNamesBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the namespace to list tables from
|
||||
pub fn namespace(mut self, namespace: Vec<String>) -> Self {
|
||||
self.request.namespace = namespace;
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute the table names operation
|
||||
pub async fn execute(self) -> Result<Vec<String>> {
|
||||
self.parent.clone().table_names(self.request).await
|
||||
@@ -347,6 +355,12 @@ impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
|
||||
);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the namespace for the table
|
||||
pub fn namespace(mut self, namespace: Vec<String>) -> Self {
|
||||
self.request.namespace = namespace;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -366,6 +380,7 @@ impl OpenTableBuilder {
|
||||
parent,
|
||||
request: OpenTableRequest {
|
||||
name,
|
||||
namespace: vec![],
|
||||
index_cache_size: None,
|
||||
lance_read_params: None,
|
||||
},
|
||||
@@ -441,6 +456,12 @@ impl OpenTableBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the namespace for the table
|
||||
pub fn namespace(mut self, namespace: Vec<String>) -> Self {
|
||||
self.request.namespace = namespace;
|
||||
self
|
||||
}
|
||||
|
||||
/// Open the table
|
||||
pub async fn execute(self) -> Result<Table> {
|
||||
Ok(Table::new_with_embedding_registry(
|
||||
@@ -563,9 +584,16 @@ impl Connection {
|
||||
&self,
|
||||
old_name: impl AsRef<str>,
|
||||
new_name: impl AsRef<str>,
|
||||
cur_namespace: &[String],
|
||||
new_namespace: &[String],
|
||||
) -> Result<()> {
|
||||
self.internal
|
||||
.rename_table(old_name.as_ref(), new_name.as_ref())
|
||||
.rename_table(
|
||||
old_name.as_ref(),
|
||||
new_name.as_ref(),
|
||||
cur_namespace,
|
||||
new_namespace,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -573,8 +601,9 @@ impl Connection {
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `name` - The name of the table to drop
|
||||
pub async fn drop_table(&self, name: impl AsRef<str>) -> Result<()> {
|
||||
self.internal.drop_table(name.as_ref()).await
|
||||
/// * `namespace` - The namespace to drop the table from
|
||||
pub async fn drop_table(&self, name: impl AsRef<str>, namespace: &[String]) -> Result<()> {
|
||||
self.internal.drop_table(name.as_ref(), namespace).await
|
||||
}
|
||||
|
||||
/// Drop the database
|
||||
@@ -582,12 +611,30 @@ impl Connection {
|
||||
/// This is the same as dropping all of the tables
|
||||
#[deprecated(since = "0.15.1", note = "Use `drop_all_tables` instead")]
|
||||
pub async fn drop_db(&self) -> Result<()> {
|
||||
self.internal.drop_all_tables().await
|
||||
self.internal.drop_all_tables(&[]).await
|
||||
}
|
||||
|
||||
/// Drops all tables in the database
|
||||
pub async fn drop_all_tables(&self) -> Result<()> {
|
||||
self.internal.drop_all_tables().await
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `namespace` - The namespace to drop all tables from. Empty slice represents root namespace.
|
||||
pub async fn drop_all_tables(&self, namespace: &[String]) -> Result<()> {
|
||||
self.internal.drop_all_tables(namespace).await
|
||||
}
|
||||
|
||||
/// List immediate child namespace names in the given namespace
|
||||
pub async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>> {
|
||||
self.internal.list_namespaces(request).await
|
||||
}
|
||||
|
||||
/// Create a new namespace
|
||||
pub async fn create_namespace(&self, request: CreateNamespaceRequest) -> Result<()> {
|
||||
self.internal.create_namespace(request).await
|
||||
}
|
||||
|
||||
/// Drop a namespace
|
||||
pub async fn drop_namespace(&self, request: DropNamespaceRequest) -> Result<()> {
|
||||
self.internal.drop_namespace(request).await
|
||||
}
|
||||
|
||||
/// Get the in-memory embedding registry.
|
||||
@@ -749,6 +796,7 @@ impl ConnectBuilder {
|
||||
}
|
||||
|
||||
/// [`AwsCredential`] to use when connecting to S3.
|
||||
#[cfg(feature = "aws")]
|
||||
#[deprecated(note = "Pass through storage_options instead")]
|
||||
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
|
||||
self.request
|
||||
@@ -950,6 +998,23 @@ mod test_utils {
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_handler_and_config<T>(
|
||||
handler: impl Fn(reqwest::Request) -> http::Response<T> + Clone + Send + Sync + 'static,
|
||||
config: crate::remote::ClientConfig,
|
||||
) -> Self
|
||||
where
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let internal = Arc::new(crate::remote::db::RemoteDatabase::new_mock_with_config(
|
||||
handler, config,
|
||||
));
|
||||
Self {
|
||||
internal,
|
||||
uri: "db://test".to_string(),
|
||||
embedding_registry: Arc::new(MemoryRegistry::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1218,12 +1283,12 @@ mod tests {
|
||||
|
||||
// drop non-exist table
|
||||
assert!(matches!(
|
||||
db.drop_table("invalid_table").await,
|
||||
db.drop_table("invalid_table", &[]).await,
|
||||
Err(crate::Error::TableNotFound { .. }),
|
||||
));
|
||||
|
||||
create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
|
||||
db.drop_table("table1").await.unwrap();
|
||||
db.drop_table("table1", &[]).await.unwrap();
|
||||
|
||||
let tables = db.table_names().execute().await.unwrap();
|
||||
assert_eq!(tables.len(), 0);
|
||||
|
||||
@@ -34,9 +34,36 @@ pub trait DatabaseOptions {
|
||||
fn serialize_into_map(&self, map: &mut HashMap<String, String>);
|
||||
}
|
||||
|
||||
/// A request to list namespaces in the database
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ListNamespacesRequest {
|
||||
/// The parent namespace to list namespaces in. Empty list represents root namespace.
|
||||
pub namespace: Vec<String>,
|
||||
/// If present, only return names that come lexicographically after the supplied value.
|
||||
pub page_token: Option<String>,
|
||||
/// The maximum number of namespace names to return
|
||||
pub limit: Option<u32>,
|
||||
}
|
||||
|
||||
/// A request to create a namespace
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CreateNamespaceRequest {
|
||||
/// The namespace identifier to create
|
||||
pub namespace: Vec<String>,
|
||||
}
|
||||
|
||||
/// A request to drop a namespace
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DropNamespaceRequest {
|
||||
/// The namespace identifier to drop
|
||||
pub namespace: Vec<String>,
|
||||
}
|
||||
|
||||
/// A request to list names of tables in the database
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct TableNamesRequest {
|
||||
/// The namespace to list tables in. Empty list represents root namespace.
|
||||
pub namespace: Vec<String>,
|
||||
/// If present, only return names that come lexicographically after the supplied
|
||||
/// value.
|
||||
///
|
||||
@@ -51,6 +78,8 @@ pub struct TableNamesRequest {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OpenTableRequest {
|
||||
pub name: String,
|
||||
/// The namespace to open the table from. Empty list represents root namespace.
|
||||
pub namespace: Vec<String>,
|
||||
pub index_cache_size: Option<u32>,
|
||||
pub lance_read_params: Option<ReadParams>,
|
||||
}
|
||||
@@ -125,6 +154,8 @@ impl StreamingWriteSource for CreateTableData {
|
||||
pub struct CreateTableRequest {
|
||||
/// The name of the new table
|
||||
pub name: String,
|
||||
/// The namespace to create the table in. Empty list represents root namespace.
|
||||
pub namespace: Vec<String>,
|
||||
/// Initial data to write to the table, can be None to create an empty table
|
||||
pub data: CreateTableData,
|
||||
/// The mode to use when creating the table
|
||||
@@ -137,6 +168,7 @@ impl CreateTableRequest {
|
||||
pub fn new(name: String, data: CreateTableData) -> Self {
|
||||
Self {
|
||||
name,
|
||||
namespace: vec![],
|
||||
data,
|
||||
mode: CreateTableMode::default(),
|
||||
write_options: WriteOptions::default(),
|
||||
@@ -151,6 +183,12 @@ impl CreateTableRequest {
|
||||
pub trait Database:
|
||||
Send + Sync + std::any::Any + std::fmt::Debug + std::fmt::Display + 'static
|
||||
{
|
||||
/// List immediate child namespace names in the given namespace
|
||||
async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>>;
|
||||
/// Create a new namespace
|
||||
async fn create_namespace(&self, request: CreateNamespaceRequest) -> Result<()>;
|
||||
/// Drop a namespace
|
||||
async fn drop_namespace(&self, request: DropNamespaceRequest) -> Result<()>;
|
||||
/// List the names of tables in the database
|
||||
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>>;
|
||||
/// Create a table in the database
|
||||
@@ -158,10 +196,16 @@ pub trait Database:
|
||||
/// Open a table in the database
|
||||
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>>;
|
||||
/// Rename a table in the database
|
||||
async fn rename_table(&self, old_name: &str, new_name: &str) -> Result<()>;
|
||||
async fn rename_table(
|
||||
&self,
|
||||
cur_name: &str,
|
||||
new_name: &str,
|
||||
cur_namespace: &[String],
|
||||
new_namespace: &[String],
|
||||
) -> Result<()>;
|
||||
/// Drop a table in the database
|
||||
async fn drop_table(&self, name: &str) -> Result<()>;
|
||||
async fn drop_table(&self, name: &str, namespace: &[String]) -> Result<()>;
|
||||
/// Drop all tables in the database
|
||||
async fn drop_all_tables(&self) -> Result<()>;
|
||||
async fn drop_all_tables(&self, namespace: &[String]) -> Result<()>;
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
}
|
||||
|
||||
@@ -22,7 +22,8 @@ use crate::table::NativeTable;
|
||||
use crate::utils::validate_table_name;
|
||||
|
||||
use super::{
|
||||
BaseTable, CreateTableMode, CreateTableRequest, Database, DatabaseOptions, OpenTableRequest,
|
||||
BaseTable, CreateNamespaceRequest, CreateTableMode, CreateTableRequest, Database,
|
||||
DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest, OpenTableRequest,
|
||||
TableNamesRequest,
|
||||
};
|
||||
|
||||
@@ -551,6 +552,7 @@ impl ListingDatabase {
|
||||
async fn handle_table_exists(
|
||||
&self,
|
||||
table_name: &str,
|
||||
namespace: Vec<String>,
|
||||
mode: CreateTableMode,
|
||||
data_schema: &arrow_schema::Schema,
|
||||
) -> Result<Arc<dyn BaseTable>> {
|
||||
@@ -561,6 +563,7 @@ impl ListingDatabase {
|
||||
CreateTableMode::ExistOk(callback) => {
|
||||
let req = OpenTableRequest {
|
||||
name: table_name.to_string(),
|
||||
namespace: namespace.clone(),
|
||||
index_cache_size: None,
|
||||
lance_read_params: None,
|
||||
};
|
||||
@@ -584,7 +587,28 @@ impl ListingDatabase {
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Database for ListingDatabase {
|
||||
async fn list_namespaces(&self, _request: ListNamespacesRequest) -> Result<Vec<String>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
async fn create_namespace(&self, _request: CreateNamespaceRequest) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "Namespace operations are not supported for listing database".into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn drop_namespace(&self, _request: DropNamespaceRequest) -> Result<()> {
|
||||
Err(Error::NotSupported {
|
||||
message: "Namespace operations are not supported for listing database".into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
|
||||
if !request.namespace.is_empty() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Namespace parameter is not supported for listing database. Only root namespace is supported.".into(),
|
||||
});
|
||||
}
|
||||
let mut f = self
|
||||
.object_store
|
||||
.read_dir(self.base_path.clone())
|
||||
@@ -615,6 +639,11 @@ impl Database for ListingDatabase {
|
||||
}
|
||||
|
||||
async fn create_table(&self, request: CreateTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
if !request.namespace.is_empty() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Namespace parameter is not supported for listing database. Only root namespace is supported.".into(),
|
||||
});
|
||||
}
|
||||
let table_uri = self.table_uri(&request.name)?;
|
||||
|
||||
let (storage_version_override, v2_manifest_override) =
|
||||
@@ -637,14 +666,24 @@ impl Database for ListingDatabase {
|
||||
{
|
||||
Ok(table) => Ok(Arc::new(table)),
|
||||
Err(Error::TableAlreadyExists { .. }) => {
|
||||
self.handle_table_exists(&request.name, request.mode, &data_schema)
|
||||
.await
|
||||
self.handle_table_exists(
|
||||
&request.name,
|
||||
request.namespace.clone(),
|
||||
request.mode,
|
||||
&data_schema,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn open_table(&self, mut request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
if !request.namespace.is_empty() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Namespace parameter is not supported for listing database. Only root namespace is supported.".into(),
|
||||
});
|
||||
}
|
||||
let table_uri = self.table_uri(&request.name)?;
|
||||
|
||||
// Only modify the storage options if we actually have something to
|
||||
@@ -694,17 +733,44 @@ impl Database for ListingDatabase {
|
||||
Ok(native_table)
|
||||
}
|
||||
|
||||
async fn rename_table(&self, _old_name: &str, _new_name: &str) -> Result<()> {
|
||||
async fn rename_table(
|
||||
&self,
|
||||
_cur_name: &str,
|
||||
_new_name: &str,
|
||||
cur_namespace: &[String],
|
||||
new_namespace: &[String],
|
||||
) -> Result<()> {
|
||||
if !cur_namespace.is_empty() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Namespace parameter is not supported for listing database.".into(),
|
||||
});
|
||||
}
|
||||
if !new_namespace.is_empty() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Namespace parameter is not supported for listing database.".into(),
|
||||
});
|
||||
}
|
||||
Err(Error::NotSupported {
|
||||
message: "rename_table is not supported in LanceDB OSS".to_string(),
|
||||
message: "rename_table is not supported in LanceDB OSS".into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn drop_table(&self, name: &str) -> Result<()> {
|
||||
async fn drop_table(&self, name: &str, namespace: &[String]) -> Result<()> {
|
||||
if !namespace.is_empty() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Namespace parameter is not supported for listing database.".into(),
|
||||
});
|
||||
}
|
||||
self.drop_tables(vec![name.to_string()]).await
|
||||
}
|
||||
|
||||
async fn drop_all_tables(&self) -> Result<()> {
|
||||
async fn drop_all_tables(&self, namespace: &[String]) -> Result<()> {
|
||||
// Check if namespace parameter is provided
|
||||
if !namespace.is_empty() {
|
||||
return Err(Error::NotSupported {
|
||||
message: "Namespace parameter is not supported for listing database.".into(),
|
||||
});
|
||||
}
|
||||
let tables = self.table_names(TableNamesRequest::default()).await?;
|
||||
self.drop_tables(tables).await
|
||||
}
|
||||
|
||||
@@ -65,12 +65,94 @@ pub enum Index {
|
||||
/// Builder for the create_index operation
|
||||
///
|
||||
/// The methods on this builder are used to specify options common to all indices.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// Creating a basic vector index:
|
||||
///
|
||||
/// ```
|
||||
/// use lancedb::{connect, index::{Index, vector::IvfPqIndexBuilder}};
|
||||
///
|
||||
/// # async fn create_basic_vector_index() -> lancedb::Result<()> {
|
||||
/// let db = connect("data/sample-lancedb").execute().await?;
|
||||
/// let table = db.open_table("my_table").execute().await?;
|
||||
///
|
||||
/// // Create a vector index with default settings
|
||||
/// table
|
||||
/// .create_index(&["vector"], Index::IvfPq(IvfPqIndexBuilder::default()))
|
||||
/// .execute()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Creating an index with a custom name:
|
||||
///
|
||||
/// ```
|
||||
/// use lancedb::{connect, index::{Index, vector::IvfPqIndexBuilder}};
|
||||
///
|
||||
/// # async fn create_named_index() -> lancedb::Result<()> {
|
||||
/// let db = connect("data/sample-lancedb").execute().await?;
|
||||
/// let table = db.open_table("my_table").execute().await?;
|
||||
///
|
||||
/// // Create a vector index with a custom name
|
||||
/// table
|
||||
/// .create_index(&["embeddings"], Index::IvfPq(IvfPqIndexBuilder::default()))
|
||||
/// .name("my_embeddings_index".to_string())
|
||||
/// .execute()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Creating an untrained index (for scalar indices only):
|
||||
///
|
||||
/// ```
|
||||
/// use lancedb::{connect, index::{Index, scalar::BTreeIndexBuilder}};
|
||||
///
|
||||
/// # async fn create_untrained_index() -> lancedb::Result<()> {
|
||||
/// let db = connect("data/sample-lancedb").execute().await?;
|
||||
/// let table = db.open_table("my_table").execute().await?;
|
||||
///
|
||||
/// // Create a BTree index without training (creates empty index)
|
||||
/// table
|
||||
/// .create_index(&["category"], Index::BTree(BTreeIndexBuilder::default()))
|
||||
/// .train(false)
|
||||
/// .name("category_index".to_string())
|
||||
/// .execute()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Creating a scalar index with all options:
|
||||
///
|
||||
/// ```
|
||||
/// use lancedb::{connect, index::{Index, scalar::BitmapIndexBuilder}};
|
||||
///
|
||||
/// # async fn create_full_options_index() -> lancedb::Result<()> {
|
||||
/// let db = connect("data/sample-lancedb").execute().await?;
|
||||
/// let table = db.open_table("my_table").execute().await?;
|
||||
///
|
||||
/// // Create a bitmap index with full configuration
|
||||
/// table
|
||||
/// .create_index(&["status"], Index::Bitmap(BitmapIndexBuilder::default()))
|
||||
/// .name("status_bitmap_index".to_string())
|
||||
/// .train(true) // Train the index with existing data
|
||||
/// .replace(false) // Don't replace if index already exists
|
||||
/// .execute()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub struct IndexBuilder {
|
||||
parent: Arc<dyn BaseTable>,
|
||||
pub(crate) index: Index,
|
||||
pub(crate) columns: Vec<String>,
|
||||
pub(crate) replace: bool,
|
||||
pub(crate) wait_timeout: Option<Duration>,
|
||||
pub(crate) train: bool,
|
||||
pub(crate) name: Option<String>,
|
||||
}
|
||||
|
||||
impl IndexBuilder {
|
||||
@@ -80,7 +162,9 @@ impl IndexBuilder {
|
||||
index,
|
||||
columns,
|
||||
replace: true,
|
||||
train: true,
|
||||
wait_timeout: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +178,82 @@ impl IndexBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// The name of the index. If not set, a default name will be generated.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use lancedb::{connect, index::{Index, scalar::BTreeIndexBuilder}};
|
||||
///
|
||||
/// # async fn name_example() -> lancedb::Result<()> {
|
||||
/// let db = connect("data/sample-lancedb").execute().await?;
|
||||
/// let table = db.open_table("my_table").execute().await?;
|
||||
///
|
||||
/// // Create an index with a custom name
|
||||
/// table
|
||||
/// .create_index(&["user_id"], Index::BTree(BTreeIndexBuilder::default()))
|
||||
/// .name("user_id_btree_index".to_string())
|
||||
/// .execute()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn name(mut self, v: String) -> Self {
|
||||
self.name = Some(v);
|
||||
self
|
||||
}
|
||||
|
||||
/// Whether to train the index, the default is `true`.
|
||||
///
|
||||
/// If this is false, the index will not be trained and just created empty.
|
||||
///
|
||||
/// This is not supported for vector indices yet.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// Creating an empty index that will be populated later:
|
||||
///
|
||||
/// ```
|
||||
/// use lancedb::{connect, index::{Index, scalar::BitmapIndexBuilder}};
|
||||
///
|
||||
/// # async fn train_false_example() -> lancedb::Result<()> {
|
||||
/// let db = connect("data/sample-lancedb").execute().await?;
|
||||
/// let table = db.open_table("my_table").execute().await?;
|
||||
///
|
||||
/// // Create an empty bitmap index (not trained with existing data)
|
||||
/// table
|
||||
/// .create_index(&["category"], Index::Bitmap(BitmapIndexBuilder::default()))
|
||||
/// .train(false) // Create empty index
|
||||
/// .name("category_bitmap".to_string())
|
||||
/// .execute()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Creating a trained index (default behavior):
|
||||
///
|
||||
/// ```
|
||||
/// use lancedb::{connect, index::{Index, scalar::BTreeIndexBuilder}};
|
||||
///
|
||||
/// # async fn train_true_example() -> lancedb::Result<()> {
|
||||
/// let db = connect("data/sample-lancedb").execute().await?;
|
||||
/// let table = db.open_table("my_table").execute().await?;
|
||||
///
|
||||
/// // Create a trained BTree index (includes existing data)
|
||||
/// table
|
||||
/// .create_index(&["timestamp"], Index::BTree(BTreeIndexBuilder::default()))
|
||||
/// .train(true) // Train with existing data (this is the default)
|
||||
/// .execute()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn train(mut self, v: bool) -> Self {
|
||||
self.train = v;
|
||||
self
|
||||
}
|
||||
|
||||
/// Duration of time to wait for asynchronous indexing to complete. If not set,
|
||||
/// `create_index()` will not wait.
|
||||
///
|
||||
|
||||
@@ -9,7 +9,7 @@ use futures::{stream::BoxStream, TryFutureExt};
|
||||
use lance::io::WrappingObjectStore;
|
||||
use object_store::{
|
||||
path::Path, Error, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
|
||||
PutMultipartOpts, PutOptions, PutPayload, PutResult, Result, UploadPart,
|
||||
PutMultipartOptions, PutOptions, PutPayload, PutResult, Result, UploadPart,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -73,7 +73,7 @@ impl ObjectStore for MirroringObjectStore {
|
||||
async fn put_multipart_opts(
|
||||
&self,
|
||||
location: &Path,
|
||||
opts: PutMultipartOpts,
|
||||
opts: PutMultipartOptions,
|
||||
) -> Result<Box<dyn MultipartUpload>> {
|
||||
if location.primary_only() {
|
||||
return self.primary.put_multipart_opts(location, opts).await;
|
||||
@@ -170,7 +170,11 @@ impl MirroringObjectStoreWrapper {
|
||||
}
|
||||
|
||||
impl WrappingObjectStore for MirroringObjectStoreWrapper {
|
||||
fn wrap(&self, primary: Arc<dyn ObjectStore>) -> Arc<dyn ObjectStore> {
|
||||
fn wrap(
|
||||
&self,
|
||||
primary: Arc<dyn ObjectStore>,
|
||||
_storage_options: Option<&std::collections::HashMap<String, String>>,
|
||||
) -> Arc<dyn ObjectStore> {
|
||||
Arc::new(MirroringObjectStore {
|
||||
primary,
|
||||
secondary: self.secondary.clone(),
|
||||
|
||||
@@ -11,7 +11,7 @@ use futures::stream::BoxStream;
|
||||
use lance::io::WrappingObjectStore;
|
||||
use object_store::{
|
||||
path::Path, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
|
||||
PutMultipartOpts, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
|
||||
PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart,
|
||||
};
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -50,7 +50,11 @@ impl IoStatsHolder {
|
||||
}
|
||||
|
||||
impl WrappingObjectStore for IoStatsHolder {
|
||||
fn wrap(&self, target: Arc<dyn ObjectStore>) -> Arc<dyn ObjectStore> {
|
||||
fn wrap(
|
||||
&self,
|
||||
target: Arc<dyn ObjectStore>,
|
||||
_storage_options: Option<&std::collections::HashMap<String, String>>,
|
||||
) -> Arc<dyn ObjectStore> {
|
||||
Arc::new(IoTrackingStore {
|
||||
target,
|
||||
stats: self.0.clone(),
|
||||
@@ -106,7 +110,7 @@ impl ObjectStore for IoTrackingStore {
|
||||
async fn put_multipart_opts(
|
||||
&self,
|
||||
location: &Path,
|
||||
opts: PutMultipartOpts,
|
||||
opts: PutMultipartOptions,
|
||||
) -> OSResult<Box<dyn MultipartUpload>> {
|
||||
let target = self.target.put_multipart_opts(location, opts).await?;
|
||||
Ok(Box::new(IoTrackingMultipartUpload {
|
||||
|
||||
@@ -18,5 +18,5 @@ const ARROW_FILE_CONTENT_TYPE: &str = "application/vnd.apache.arrow.file";
|
||||
#[cfg(test)]
|
||||
const JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
pub use client::{ClientConfig, RetryConfig, TimeoutConfig};
|
||||
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
|
||||
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
|
||||
|
||||
@@ -7,7 +7,7 @@ use reqwest::{
|
||||
header::{HeaderMap, HeaderValue},
|
||||
Body, Request, RequestBuilder, Response,
|
||||
};
|
||||
use std::{collections::HashMap, future::Future, str::FromStr, time::Duration};
|
||||
use std::{collections::HashMap, future::Future, str::FromStr, sync::Arc, time::Duration};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::remote::db::RemoteOptions;
|
||||
@@ -15,8 +15,28 @@ use crate::remote::retry::{ResolvedRetryConfig, RetryCounter};
|
||||
|
||||
const REQUEST_ID_HEADER: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
/// Configuration for TLS/mTLS settings.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct TlsConfig {
|
||||
/// Path to the client certificate file (PEM format)
|
||||
pub cert_file: Option<String>,
|
||||
/// Path to the client private key file (PEM format)
|
||||
pub key_file: Option<String>,
|
||||
/// Path to the CA certificate file for server verification (PEM format)
|
||||
pub ssl_ca_cert: Option<String>,
|
||||
/// Whether to verify the hostname in the server's certificate
|
||||
pub assert_hostname: bool,
|
||||
}
|
||||
|
||||
/// Trait for providing custom headers for each request
|
||||
#[async_trait::async_trait]
|
||||
pub trait HeaderProvider: Send + Sync + std::fmt::Debug {
|
||||
/// Get the latest headers to be added to the request
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>>;
|
||||
}
|
||||
|
||||
/// Configuration for the LanceDB Cloud HTTP client.
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone)]
|
||||
pub struct ClientConfig {
|
||||
pub timeout_config: TimeoutConfig,
|
||||
pub retry_config: RetryConfig,
|
||||
@@ -25,6 +45,30 @@ pub struct ClientConfig {
|
||||
pub user_agent: String,
|
||||
// TODO: how to configure request ids?
|
||||
pub extra_headers: HashMap<String, String>,
|
||||
/// The delimiter to use when constructing object identifiers.
|
||||
/// If not default, passes as query parameter.
|
||||
pub id_delimiter: Option<String>,
|
||||
/// TLS configuration for mTLS support
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
/// Provider for custom headers to be added to each request
|
||||
pub header_provider: Option<Arc<dyn HeaderProvider>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ClientConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ClientConfig")
|
||||
.field("timeout_config", &self.timeout_config)
|
||||
.field("retry_config", &self.retry_config)
|
||||
.field("user_agent", &self.user_agent)
|
||||
.field("extra_headers", &self.extra_headers)
|
||||
.field("id_delimiter", &self.id_delimiter)
|
||||
.field("tls_config", &self.tls_config)
|
||||
.field(
|
||||
"header_provider",
|
||||
&self.header_provider.as_ref().map(|_| "Some(...)"),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
@@ -34,6 +78,9 @@ impl Default for ClientConfig {
|
||||
retry_config: RetryConfig::default(),
|
||||
user_agent: concat!("LanceDB-Rust-Client/", env!("CARGO_PKG_VERSION")).into(),
|
||||
extra_headers: HashMap::new(),
|
||||
id_delimiter: None,
|
||||
tls_config: None,
|
||||
header_provider: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -139,12 +186,29 @@ pub struct RetryConfig {
|
||||
// We use the `HttpSend` trait to abstract over the `reqwest::Client` so that
|
||||
// we can mock responses in tests. Based on the patterns from this blog post:
|
||||
// https://write.as/balrogboogie/testing-reqwest-based-clients
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone)]
|
||||
pub struct RestfulLanceDbClient<S: HttpSend = Sender> {
|
||||
client: reqwest::Client,
|
||||
host: String,
|
||||
pub(crate) retry_config: ResolvedRetryConfig,
|
||||
pub(crate) sender: S,
|
||||
pub(crate) id_delimiter: String,
|
||||
pub(crate) header_provider: Option<Arc<dyn HeaderProvider>>,
|
||||
}
|
||||
|
||||
impl<S: HttpSend> std::fmt::Debug for RestfulLanceDbClient<S> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RestfulLanceDbClient")
|
||||
.field("host", &self.host)
|
||||
.field("retry_config", &self.retry_config)
|
||||
.field("sender", &self.sender)
|
||||
.field("id_delimiter", &self.id_delimiter)
|
||||
.field(
|
||||
"header_provider",
|
||||
&self.header_provider.as_ref().map(|_| "Some(...)"),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait HttpSend: Clone + Send + Sync + std::fmt::Debug + 'static {
|
||||
@@ -240,6 +304,49 @@ impl RestfulLanceDbClient<Sender> {
|
||||
if let Some(timeout) = timeout {
|
||||
client_builder = client_builder.timeout(timeout);
|
||||
}
|
||||
|
||||
// Configure mTLS if TlsConfig is provided
|
||||
if let Some(tls_config) = &client_config.tls_config {
|
||||
// Load client certificate and key for mTLS
|
||||
if let (Some(cert_file), Some(key_file)) = (&tls_config.cert_file, &tls_config.key_file)
|
||||
{
|
||||
let cert = std::fs::read(cert_file).map_err(|err| Error::Other {
|
||||
message: format!("Failed to read certificate file: {}", cert_file),
|
||||
source: Some(Box::new(err)),
|
||||
})?;
|
||||
let key = std::fs::read(key_file).map_err(|err| Error::Other {
|
||||
message: format!("Failed to read key file: {}", key_file),
|
||||
source: Some(Box::new(err)),
|
||||
})?;
|
||||
|
||||
let identity = reqwest::Identity::from_pem(&[&cert[..], &key[..]].concat())
|
||||
.map_err(|err| Error::Other {
|
||||
message: "Failed to create client identity from certificate and key".into(),
|
||||
source: Some(Box::new(err)),
|
||||
})?;
|
||||
client_builder = client_builder.identity(identity);
|
||||
}
|
||||
|
||||
// Load CA certificate for server verification
|
||||
if let Some(ca_cert_file) = &tls_config.ssl_ca_cert {
|
||||
let ca_cert = std::fs::read(ca_cert_file).map_err(|err| Error::Other {
|
||||
message: format!("Failed to read CA certificate file: {}", ca_cert_file),
|
||||
source: Some(Box::new(err)),
|
||||
})?;
|
||||
|
||||
let ca_cert =
|
||||
reqwest::Certificate::from_pem(&ca_cert).map_err(|err| Error::Other {
|
||||
message: "Failed to create CA certificate from PEM".into(),
|
||||
source: Some(Box::new(err)),
|
||||
})?;
|
||||
client_builder = client_builder.add_root_certificate(ca_cert);
|
||||
}
|
||||
|
||||
// Configure hostname verification
|
||||
client_builder =
|
||||
client_builder.danger_accept_invalid_hostnames(!tls_config.assert_hostname);
|
||||
}
|
||||
|
||||
let client = client_builder
|
||||
.default_headers(Self::default_headers(
|
||||
api_key,
|
||||
@@ -262,12 +369,17 @@ impl RestfulLanceDbClient<Sender> {
|
||||
None => format!("https://{}.{}.api.lancedb.com", db_name, region),
|
||||
};
|
||||
debug!("Created client for host: {}", host);
|
||||
let retry_config = client_config.retry_config.try_into()?;
|
||||
let retry_config = client_config.retry_config.clone().try_into()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
host,
|
||||
retry_config,
|
||||
sender: Sender,
|
||||
id_delimiter: client_config
|
||||
.id_delimiter
|
||||
.clone()
|
||||
.unwrap_or("$".to_string()),
|
||||
header_provider: client_config.header_provider,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -356,18 +468,52 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
|
||||
pub fn get(&self, uri: &str) -> RequestBuilder {
|
||||
let full_uri = format!("{}{}", self.host, uri);
|
||||
self.client.get(full_uri)
|
||||
let builder = self.client.get(full_uri);
|
||||
self.add_id_delimiter_query_param(builder)
|
||||
}
|
||||
|
||||
pub fn post(&self, uri: &str) -> RequestBuilder {
|
||||
let full_uri = format!("{}{}", self.host, uri);
|
||||
self.client.post(full_uri)
|
||||
let builder = self.client.post(full_uri);
|
||||
self.add_id_delimiter_query_param(builder)
|
||||
}
|
||||
|
||||
fn add_id_delimiter_query_param(&self, req: RequestBuilder) -> RequestBuilder {
|
||||
if self.id_delimiter != "$" {
|
||||
req.query(&[("delimiter", self.id_delimiter.clone())])
|
||||
} else {
|
||||
req
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply dynamic headers from the header provider if configured
|
||||
async fn apply_dynamic_headers(&self, mut request: Request) -> Result<Request> {
|
||||
if let Some(ref provider) = self.header_provider {
|
||||
let headers = provider.get_headers().await?;
|
||||
let request_headers = request.headers_mut();
|
||||
for (key, value) in headers {
|
||||
if let Ok(header_name) = HeaderName::from_str(&key) {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&value) {
|
||||
request_headers.insert(header_name, header_value);
|
||||
} else {
|
||||
debug!("Invalid header value for key {}: {}", key, value);
|
||||
}
|
||||
} else {
|
||||
debug!("Invalid header name: {}", key);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(request)
|
||||
}
|
||||
|
||||
pub async fn send(&self, req: RequestBuilder) -> Result<(String, Response)> {
|
||||
let (client, request) = req.build_split();
|
||||
let mut request = request.unwrap();
|
||||
let request_id = self.extract_request_id(&mut request);
|
||||
|
||||
// Apply dynamic headers before sending
|
||||
request = self.apply_dynamic_headers(request).await?;
|
||||
|
||||
self.log_request(&request, &request_id);
|
||||
|
||||
let response = self
|
||||
@@ -423,6 +569,10 @@ impl<S: HttpSend> RestfulLanceDbClient<S> {
|
||||
let (c, request) = req_builder.build_split();
|
||||
let mut request = request.unwrap();
|
||||
self.set_request_id(&mut request, &request_id.clone());
|
||||
|
||||
// Apply dynamic headers before each retry attempt
|
||||
request = self.apply_dynamic_headers(request).await?;
|
||||
|
||||
self.log_request(&request, &request_id);
|
||||
|
||||
let response = self.sender.send(&c, request).await.map(|r| (r.status(), r));
|
||||
@@ -550,6 +700,7 @@ impl<T> RequestResultExt for reqwest::Result<T> {
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test_utils {
|
||||
use std::convert::TryInto;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
@@ -594,6 +745,32 @@ pub mod test_utils {
|
||||
sender: MockSender {
|
||||
f: Arc::new(wrapper),
|
||||
},
|
||||
id_delimiter: "$".to_string(),
|
||||
header_provider: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn client_with_handler_and_config<T>(
|
||||
handler: impl Fn(reqwest::Request) -> http::response::Response<T> + Send + Sync + 'static,
|
||||
config: ClientConfig,
|
||||
) -> RestfulLanceDbClient<MockSender>
|
||||
where
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let wrapper = move |req: reqwest::Request| {
|
||||
let response = handler(req);
|
||||
response.into()
|
||||
};
|
||||
|
||||
RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "http://localhost".to_string(),
|
||||
retry_config: config.retry_config.try_into().unwrap(),
|
||||
sender: MockSender {
|
||||
f: Arc::new(wrapper),
|
||||
},
|
||||
id_delimiter: config.id_delimiter.unwrap_or_else(|| "$".to_string()),
|
||||
header_provider: config.header_provider,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -644,4 +821,205 @@ mod tests {
|
||||
Some(Duration::from_secs(120))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tls_config_default() {
|
||||
let config = TlsConfig::default();
|
||||
assert!(config.cert_file.is_none());
|
||||
assert!(config.key_file.is_none());
|
||||
assert!(config.ssl_ca_cert.is_none());
|
||||
assert!(!config.assert_hostname);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tls_config_with_mtls() {
|
||||
let tls_config = TlsConfig {
|
||||
cert_file: Some("/path/to/cert.pem".to_string()),
|
||||
key_file: Some("/path/to/key.pem".to_string()),
|
||||
ssl_ca_cert: Some("/path/to/ca.pem".to_string()),
|
||||
assert_hostname: true,
|
||||
};
|
||||
|
||||
assert_eq!(tls_config.cert_file, Some("/path/to/cert.pem".to_string()));
|
||||
assert_eq!(tls_config.key_file, Some("/path/to/key.pem".to_string()));
|
||||
assert_eq!(tls_config.ssl_ca_cert, Some("/path/to/ca.pem".to_string()));
|
||||
assert!(tls_config.assert_hostname);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_config_with_tls() {
|
||||
let tls_config = TlsConfig {
|
||||
cert_file: Some("/path/to/cert.pem".to_string()),
|
||||
key_file: Some("/path/to/key.pem".to_string()),
|
||||
ssl_ca_cert: None,
|
||||
assert_hostname: false,
|
||||
};
|
||||
|
||||
let client_config = ClientConfig {
|
||||
tls_config: Some(tls_config.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert!(client_config.tls_config.is_some());
|
||||
let config_tls = client_config.tls_config.unwrap();
|
||||
assert_eq!(config_tls.cert_file, Some("/path/to/cert.pem".to_string()));
|
||||
assert_eq!(config_tls.key_file, Some("/path/to/key.pem".to_string()));
|
||||
assert!(config_tls.ssl_ca_cert.is_none());
|
||||
assert!(!config_tls.assert_hostname);
|
||||
}
|
||||
|
||||
// Test implementation of HeaderProvider
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestHeaderProvider {
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl TestHeaderProvider {
|
||||
fn new(headers: HashMap<String, String>) -> Self {
|
||||
Self { headers }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
Ok(self.headers.clone())
|
||||
}
|
||||
}
|
||||
|
||||
// Test implementation that returns an error
|
||||
#[derive(Debug)]
|
||||
struct ErrorHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for ErrorHeaderProvider {
|
||||
async fn get_headers(&self) -> Result<HashMap<String, String>> {
|
||||
Err(Error::Runtime {
|
||||
message: "Failed to get headers".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_config_with_header_provider() {
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-API-Key".to_string(), "secret-key".to_string());
|
||||
|
||||
let provider = TestHeaderProvider::new(headers);
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert!(client_config.header_provider.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_dynamic_headers() {
|
||||
// Create a mock client with header provider
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-Dynamic".to_string(), "dynamic-value".to_string());
|
||||
|
||||
let provider = TestHeaderProvider::new(headers);
|
||||
|
||||
// Create a simple request
|
||||
let request = reqwest::Request::new(
|
||||
reqwest::Method::GET,
|
||||
"https://example.com/test".parse().unwrap(),
|
||||
);
|
||||
|
||||
// Create client with header provider
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
};
|
||||
|
||||
// Apply dynamic headers
|
||||
let updated_request = client.apply_dynamic_headers(request).await.unwrap();
|
||||
|
||||
// Check that the header was added
|
||||
assert_eq!(
|
||||
updated_request.headers().get("X-Dynamic").unwrap(),
|
||||
"dynamic-value"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_dynamic_headers_merge() {
|
||||
// Test that dynamic headers override existing headers
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("Authorization".to_string(), "Bearer new-token".to_string());
|
||||
headers.insert("X-Custom".to_string(), "custom-value".to_string());
|
||||
|
||||
let provider = TestHeaderProvider::new(headers);
|
||||
|
||||
// Create request with existing Authorization header
|
||||
let mut request_builder = reqwest::Client::new().get("https://example.com/test");
|
||||
request_builder = request_builder.header("Authorization", "Bearer old-token");
|
||||
request_builder = request_builder.header("X-Existing", "existing-value");
|
||||
let request = request_builder.build().unwrap();
|
||||
|
||||
// Create client with header provider
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
};
|
||||
|
||||
// Apply dynamic headers
|
||||
let updated_request = client.apply_dynamic_headers(request).await.unwrap();
|
||||
|
||||
// Check that dynamic headers override existing ones
|
||||
assert_eq!(
|
||||
updated_request.headers().get("Authorization").unwrap(),
|
||||
"Bearer new-token"
|
||||
);
|
||||
assert_eq!(
|
||||
updated_request.headers().get("X-Custom").unwrap(),
|
||||
"custom-value"
|
||||
);
|
||||
// Existing headers should still be present
|
||||
assert_eq!(
|
||||
updated_request.headers().get("X-Existing").unwrap(),
|
||||
"existing-value"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_dynamic_headers_with_error_provider() {
|
||||
let provider = ErrorHeaderProvider;
|
||||
|
||||
let request = reqwest::Request::new(
|
||||
reqwest::Method::GET,
|
||||
"https://example.com/test".parse().unwrap(),
|
||||
);
|
||||
|
||||
let client = RestfulLanceDbClient {
|
||||
client: reqwest::Client::new(),
|
||||
host: "https://example.com".to_string(),
|
||||
retry_config: RetryConfig::default().try_into().unwrap(),
|
||||
sender: Sender,
|
||||
id_delimiter: "+".to_string(),
|
||||
header_provider: Some(Arc::new(provider) as Arc<dyn HeaderProvider>),
|
||||
};
|
||||
|
||||
// Header provider errors should fail the request
|
||||
// This is important for security - if auth headers can't be fetched, don't proceed
|
||||
let result = client.apply_dynamic_headers(request).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result.unwrap_err() {
|
||||
Error::Runtime { message } => {
|
||||
assert_eq!(message, "Failed to get headers");
|
||||
}
|
||||
_ => panic!("Expected Runtime error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,8 +14,9 @@ use serde::Deserialize;
|
||||
use tokio::task::spawn_blocking;
|
||||
|
||||
use crate::database::{
|
||||
CreateTableData, CreateTableMode, CreateTableRequest, Database, DatabaseOptions,
|
||||
OpenTableRequest, TableNamesRequest,
|
||||
CreateNamespaceRequest, CreateTableData, CreateTableMode, CreateTableRequest, Database,
|
||||
DatabaseOptions, DropNamespaceRequest, ListNamespacesRequest, OpenTableRequest,
|
||||
TableNamesRequest,
|
||||
};
|
||||
use crate::error::Result;
|
||||
use crate::table::BaseTable;
|
||||
@@ -211,8 +212,9 @@ impl RemoteDatabase {
|
||||
#[cfg(all(test, feature = "remote"))]
|
||||
mod test_utils {
|
||||
use super::*;
|
||||
use crate::remote::client::test_utils::client_with_handler;
|
||||
use crate::remote::client::test_utils::MockSender;
|
||||
use crate::remote::client::test_utils::{client_with_handler, client_with_handler_and_config};
|
||||
use crate::remote::ClientConfig;
|
||||
|
||||
impl RemoteDatabase<MockSender> {
|
||||
pub fn new_mock<F, T>(handler: F) -> Self
|
||||
@@ -226,6 +228,18 @@ mod test_utils {
|
||||
table_cache: Cache::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_mock_with_config<F, T>(handler: F, config: ClientConfig) -> Self
|
||||
where
|
||||
F: Fn(reqwest::Request) -> http::Response<T> + Send + Sync + 'static,
|
||||
T: Into<reqwest::Body>,
|
||||
{
|
||||
let client = client_with_handler_and_config(handler, config);
|
||||
Self {
|
||||
client,
|
||||
table_cache: Cache::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,10 +259,61 @@ impl From<&CreateTableMode> for &'static str {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_table_identifier(name: &str, namespace: &[String], delimiter: &str) -> String {
|
||||
if !namespace.is_empty() {
|
||||
let mut parts = namespace.to_vec();
|
||||
parts.push(name.to_string());
|
||||
parts.join(delimiter)
|
||||
} else {
|
||||
name.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn build_namespace_identifier(namespace: &[String], delimiter: &str) -> String {
|
||||
if namespace.is_empty() {
|
||||
// According to the namespace spec, use delimiter to represent root namespace
|
||||
delimiter.to_string()
|
||||
} else {
|
||||
namespace.join(delimiter)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a secure cache key using length prefixes.
|
||||
/// This format is completely unambiguous regardless of delimiter or content.
|
||||
/// Format: [u32_len][namespace1][u32_len][namespace2]...[u32_len][table_name]
|
||||
/// Returns a hex-encoded string for use as a cache key.
|
||||
fn build_cache_key(name: &str, namespace: &[String]) -> String {
|
||||
let mut key = Vec::new();
|
||||
|
||||
// Add each namespace component with length prefix
|
||||
for ns in namespace {
|
||||
let bytes = ns.as_bytes();
|
||||
key.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
|
||||
key.extend_from_slice(bytes);
|
||||
}
|
||||
|
||||
// Add table name with length prefix
|
||||
let name_bytes = name.as_bytes();
|
||||
key.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
|
||||
key.extend_from_slice(name_bytes);
|
||||
|
||||
// Convert to hex string for use as a cache key
|
||||
key.iter().map(|b| format!("{:02x}", b)).collect()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
async fn table_names(&self, request: TableNamesRequest) -> Result<Vec<String>> {
|
||||
let mut req = self.client.get("/v1/table/");
|
||||
let mut req = if !request.namespace.is_empty() {
|
||||
let namespace_id =
|
||||
build_namespace_identifier(&request.namespace, &self.client.id_delimiter);
|
||||
self.client
|
||||
.get(&format!("/v1/namespace/{}/table/list", namespace_id))
|
||||
} else {
|
||||
// TODO: use new API for all listing operations once stable
|
||||
self.client.get("/v1/table/")
|
||||
};
|
||||
|
||||
if let Some(limit) = request.limit {
|
||||
req = req.query(&[("limit", limit)]);
|
||||
}
|
||||
@@ -264,12 +329,17 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
.err_to_http(request_id)?
|
||||
.tables;
|
||||
for table in &tables {
|
||||
let table_identifier =
|
||||
build_table_identifier(table, &request.namespace, &self.client.id_delimiter);
|
||||
let cache_key = build_cache_key(table, &request.namespace);
|
||||
let remote_table = Arc::new(RemoteTable::new(
|
||||
self.client.clone(),
|
||||
table.clone(),
|
||||
request.namespace.clone(),
|
||||
table_identifier.clone(),
|
||||
version.clone(),
|
||||
));
|
||||
self.table_cache.insert(table.clone(), remote_table).await;
|
||||
self.table_cache.insert(cache_key, remote_table).await;
|
||||
}
|
||||
Ok(tables)
|
||||
}
|
||||
@@ -295,9 +365,11 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
.await
|
||||
.unwrap()?;
|
||||
|
||||
let identifier =
|
||||
build_table_identifier(&request.name, &request.namespace, &self.client.id_delimiter);
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/create/", request.name))
|
||||
.post(&format!("/v1/table/{}/create/", identifier))
|
||||
.query(&[("mode", Into::<&str>::into(&request.mode))])
|
||||
.body(data_buffer)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
|
||||
@@ -314,6 +386,7 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
CreateTableMode::ExistOk(callback) => {
|
||||
let req = OpenTableRequest {
|
||||
name: request.name.clone(),
|
||||
namespace: request.namespace.clone(),
|
||||
index_cache_size: None,
|
||||
lance_read_params: None,
|
||||
};
|
||||
@@ -342,70 +415,160 @@ impl<S: HttpSend> Database for RemoteDatabase<S> {
|
||||
}
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let version = parse_server_version(&request_id, &rsp)?;
|
||||
let table_identifier =
|
||||
build_table_identifier(&request.name, &request.namespace, &self.client.id_delimiter);
|
||||
let cache_key = build_cache_key(&request.name, &request.namespace);
|
||||
let table = Arc::new(RemoteTable::new(
|
||||
self.client.clone(),
|
||||
request.name.clone(),
|
||||
request.namespace.clone(),
|
||||
table_identifier,
|
||||
version,
|
||||
));
|
||||
self.table_cache
|
||||
.insert(request.name.clone(), table.clone())
|
||||
.await;
|
||||
self.table_cache.insert(cache_key, table.clone()).await;
|
||||
|
||||
Ok(table)
|
||||
}
|
||||
|
||||
async fn open_table(&self, request: OpenTableRequest) -> Result<Arc<dyn BaseTable>> {
|
||||
let identifier =
|
||||
build_table_identifier(&request.name, &request.namespace, &self.client.id_delimiter);
|
||||
let cache_key = build_cache_key(&request.name, &request.namespace);
|
||||
|
||||
// We describe the table to confirm it exists before moving on.
|
||||
if let Some(table) = self.table_cache.get(&request.name).await {
|
||||
if let Some(table) = self.table_cache.get(&cache_key).await {
|
||||
Ok(table.clone())
|
||||
} else {
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/describe/", request.name));
|
||||
.post(&format!("/v1/table/{}/describe/", identifier));
|
||||
let (request_id, rsp) = self.client.send_with_retry(req, None, true).await?;
|
||||
if rsp.status() == StatusCode::NOT_FOUND {
|
||||
return Err(crate::Error::TableNotFound { name: request.name });
|
||||
return Err(crate::Error::TableNotFound {
|
||||
name: identifier.clone(),
|
||||
});
|
||||
}
|
||||
let rsp = self.client.check_response(&request_id, rsp).await?;
|
||||
let version = parse_server_version(&request_id, &rsp)?;
|
||||
let table_identifier = build_table_identifier(
|
||||
&request.name,
|
||||
&request.namespace,
|
||||
&self.client.id_delimiter,
|
||||
);
|
||||
let table = Arc::new(RemoteTable::new(
|
||||
self.client.clone(),
|
||||
request.name.clone(),
|
||||
request.namespace.clone(),
|
||||
table_identifier,
|
||||
version,
|
||||
));
|
||||
self.table_cache.insert(request.name, table.clone()).await;
|
||||
let cache_key = build_cache_key(&request.name, &request.namespace);
|
||||
self.table_cache.insert(cache_key, table.clone()).await;
|
||||
Ok(table)
|
||||
}
|
||||
}
|
||||
|
||||
async fn rename_table(&self, current_name: &str, new_name: &str) -> Result<()> {
|
||||
async fn rename_table(
|
||||
&self,
|
||||
current_name: &str,
|
||||
new_name: &str,
|
||||
cur_namespace: &[String],
|
||||
new_namespace: &[String],
|
||||
) -> Result<()> {
|
||||
let current_identifier =
|
||||
build_table_identifier(current_name, cur_namespace, &self.client.id_delimiter);
|
||||
let current_cache_key = build_cache_key(current_name, cur_namespace);
|
||||
let new_cache_key = build_cache_key(new_name, new_namespace);
|
||||
|
||||
let mut body = serde_json::json!({ "new_table_name": new_name });
|
||||
if !new_namespace.is_empty() {
|
||||
body["new_namespace"] = serde_json::Value::Array(
|
||||
new_namespace
|
||||
.iter()
|
||||
.map(|s| serde_json::Value::String(s.clone()))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/rename/", current_name));
|
||||
let req = req.json(&serde_json::json!({ "new_table_name": new_name }));
|
||||
.post(&format!("/v1/table/{}/rename/", current_identifier))
|
||||
.json(&body);
|
||||
let (request_id, resp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, resp).await?;
|
||||
let table = self.table_cache.remove(current_name).await;
|
||||
let table = self.table_cache.remove(¤t_cache_key).await;
|
||||
if let Some(table) = table {
|
||||
self.table_cache.insert(new_name.into(), table).await;
|
||||
self.table_cache.insert(new_cache_key, table).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn drop_table(&self, name: &str) -> Result<()> {
|
||||
let req = self.client.post(&format!("/v1/table/{}/drop/", name));
|
||||
async fn drop_table(&self, name: &str, namespace: &[String]) -> Result<()> {
|
||||
let identifier = build_table_identifier(name, namespace, &self.client.id_delimiter);
|
||||
let cache_key = build_cache_key(name, namespace);
|
||||
let req = self.client.post(&format!("/v1/table/{}/drop/", identifier));
|
||||
let (request_id, resp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, resp).await?;
|
||||
self.table_cache.remove(name).await;
|
||||
self.table_cache.remove(&cache_key).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn drop_all_tables(&self) -> Result<()> {
|
||||
async fn drop_all_tables(&self, namespace: &[String]) -> Result<()> {
|
||||
// TODO: Implement namespace-aware drop_all_tables
|
||||
let _namespace = namespace; // Suppress unused warning for now
|
||||
Err(crate::Error::NotSupported {
|
||||
message: "Dropping databases is not supported in the remote API".to_string(),
|
||||
message: "Dropping all tables is not currently supported in the remote API".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn list_namespaces(&self, request: ListNamespacesRequest) -> Result<Vec<String>> {
|
||||
let namespace_id =
|
||||
build_namespace_identifier(request.namespace.as_slice(), &self.client.id_delimiter);
|
||||
let mut req = self
|
||||
.client
|
||||
.get(&format!("/v1/namespace/{}/list", namespace_id));
|
||||
if let Some(limit) = request.limit {
|
||||
req = req.query(&[("limit", limit)]);
|
||||
}
|
||||
if let Some(page_token) = request.page_token {
|
||||
req = req.query(&[("page_token", page_token)]);
|
||||
}
|
||||
|
||||
let (request_id, resp) = self.client.send(req).await?;
|
||||
let resp = self.client.check_response(&request_id, resp).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ListNamespacesResponse {
|
||||
namespaces: Vec<String>,
|
||||
}
|
||||
|
||||
let parsed: ListNamespacesResponse = resp.json().await.map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to parse namespace response: {}", e),
|
||||
})?;
|
||||
Ok(parsed.namespaces)
|
||||
}
|
||||
|
||||
async fn create_namespace(&self, request: CreateNamespaceRequest) -> Result<()> {
|
||||
let namespace_id =
|
||||
build_namespace_identifier(request.namespace.as_slice(), &self.client.id_delimiter);
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/namespace/{}/create", namespace_id));
|
||||
let (request_id, resp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, resp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn drop_namespace(&self, request: DropNamespaceRequest) -> Result<()> {
|
||||
let namespace_id =
|
||||
build_namespace_identifier(request.namespace.as_slice(), &self.client.id_delimiter);
|
||||
let req = self
|
||||
.client
|
||||
.post(&format!("/v1/namespace/{}/drop", namespace_id));
|
||||
let (request_id, resp) = self.client.send(req).await?;
|
||||
self.client.check_response(&request_id, resp).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
@@ -436,6 +599,8 @@ impl From<StorageOptions> for RemoteOptions {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::build_cache_key;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
|
||||
@@ -444,10 +609,42 @@ mod tests {
|
||||
use crate::connection::ConnectBuilder;
|
||||
use crate::{
|
||||
database::CreateTableMode,
|
||||
remote::{ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
|
||||
remote::{ClientConfig, HeaderProvider, ARROW_STREAM_CONTENT_TYPE, JSON_CONTENT_TYPE},
|
||||
Connection, Error,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_security() {
|
||||
// Test that cache keys are unique regardless of delimiter manipulation
|
||||
|
||||
// Case 1: Different delimiters should not affect cache key
|
||||
let key1 = build_cache_key("table1", &["ns1".to_string(), "ns2".to_string()]);
|
||||
let key2 = build_cache_key("table1", &["ns1$ns2".to_string()]);
|
||||
assert_ne!(
|
||||
key1, key2,
|
||||
"Cache keys should differ for different namespace structures"
|
||||
);
|
||||
|
||||
// Case 2: Table name containing delimiter should not cause collision
|
||||
let key3 = build_cache_key("ns2$table1", &["ns1".to_string()]);
|
||||
assert_ne!(
|
||||
key1, key3,
|
||||
"Cache key should be different when table name contains delimiter"
|
||||
);
|
||||
|
||||
// Case 3: Empty namespace vs namespace with empty string
|
||||
let key4 = build_cache_key("table1", &[]);
|
||||
let key5 = build_cache_key("table1", &["".to_string()]);
|
||||
assert_ne!(
|
||||
key4, key5,
|
||||
"Empty namespace should differ from namespace with empty string"
|
||||
);
|
||||
|
||||
// Case 4: Verify same inputs produce same key (consistency)
|
||||
let key6 = build_cache_key("table1", &["ns1".to_string(), "ns2".to_string()]);
|
||||
assert_eq!(key1, key6, "Same inputs should produce same cache key");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retries() {
|
||||
// We'll record the request_id here, to check it matches the one in the error.
|
||||
@@ -711,7 +908,7 @@ mod tests {
|
||||
|
||||
http::Response::builder().status(200).body("").unwrap()
|
||||
});
|
||||
conn.drop_table("table1").await.unwrap();
|
||||
conn.drop_table("table1", &[]).await.unwrap();
|
||||
// NOTE: the API will return 200 even if the table does not exist. So we shouldn't expect 404.
|
||||
}
|
||||
|
||||
@@ -731,7 +928,9 @@ mod tests {
|
||||
|
||||
http::Response::builder().status(200).body("").unwrap()
|
||||
});
|
||||
conn.rename_table("table1", "table2").await.unwrap();
|
||||
conn.rename_table("table1", "table2", &[], &[])
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -745,4 +944,281 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_names_with_root_namespace() {
|
||||
// When namespace is empty (root namespace), should use /v1/table/ for backwards compatibility
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/table/");
|
||||
assert_eq!(request.url().query(), None);
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": ["table1", "table2"]}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let names = conn
|
||||
.table_names()
|
||||
.namespace(vec![])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(names, vec!["table1", "table2"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_names_with_namespace() {
|
||||
// When namespace is non-empty, should use /v1/namespace/{id}/table/list
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/namespace/test/table/list");
|
||||
assert_eq!(request.url().query(), None);
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": ["table1", "table2"]}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let names = conn
|
||||
.table_names()
|
||||
.namespace(vec!["test".to_string()])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(names, vec!["table1", "table2"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_names_with_nested_namespace() {
|
||||
// When namespace is vec!["ns1", "ns2"], should use /v1/namespace/ns1$ns2/table/list
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/namespace/ns1$ns2/table/list");
|
||||
assert_eq!(request.url().query(), None);
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": ["ns1$ns2$table1", "ns1$ns2$table2"]}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let names = conn
|
||||
.table_names()
|
||||
.namespace(vec!["ns1".to_string(), "ns2".to_string()])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(names, vec!["ns1$ns2$table1", "ns1$ns2$table2"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_open_table_with_namespace() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/table/ns1$ns2$table1/describe/");
|
||||
assert_eq!(request.url().query(), None);
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"table": "table1"}"#)
|
||||
.unwrap()
|
||||
});
|
||||
let table = conn
|
||||
.open_table("table1")
|
||||
.namespace(vec!["ns1".to_string(), "ns2".to_string()])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.name(), "table1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_table_with_namespace() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/table/ns1$table1/create/");
|
||||
assert_eq!(
|
||||
request
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.unwrap(),
|
||||
ARROW_STREAM_CONTENT_TYPE.as_bytes()
|
||||
);
|
||||
|
||||
http::Response::builder().status(200).body("").unwrap()
|
||||
});
|
||||
let data = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let reader = RecordBatchIterator::new([Ok(data.clone())], data.schema());
|
||||
let table = conn
|
||||
.create_table("table1", reader)
|
||||
.namespace(vec!["ns1".to_string()])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.name(), "table1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drop_table_with_namespace() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/table/ns1$ns2$table1/drop/");
|
||||
assert_eq!(request.url().query(), None);
|
||||
assert!(request.body().is_none());
|
||||
|
||||
http::Response::builder().status(200).body("").unwrap()
|
||||
});
|
||||
conn.drop_table("table1", &["ns1".to_string(), "ns2".to_string()])
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rename_table_with_namespace() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/table/ns1$table1/rename/");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
let body = request.body().unwrap().as_bytes().unwrap();
|
||||
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
|
||||
assert_eq!(body["new_table_name"], "table2");
|
||||
assert_eq!(body["new_namespace"], serde_json::json!(["ns2"]));
|
||||
|
||||
http::Response::builder().status(200).body("").unwrap()
|
||||
});
|
||||
conn.rename_table(
|
||||
"table1",
|
||||
"table2",
|
||||
&["ns1".to_string()],
|
||||
&["ns2".to_string()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_empty_table_with_namespace() {
|
||||
let conn = Connection::new_with_handler(|request| {
|
||||
assert_eq!(request.method(), &reqwest::Method::POST);
|
||||
assert_eq!(request.url().path(), "/v1/table/prod$data$metrics/create/");
|
||||
assert_eq!(
|
||||
request
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.unwrap(),
|
||||
ARROW_STREAM_CONTENT_TYPE.as_bytes()
|
||||
);
|
||||
|
||||
http::Response::builder().status(200).body("").unwrap()
|
||||
});
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
|
||||
conn.create_empty_table("metrics", schema)
|
||||
.namespace(vec!["prod".to_string(), "data".to_string()])
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_header_provider_in_request() {
|
||||
// Test HeaderProvider implementation that adds custom headers
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestHeaderProvider {
|
||||
headers: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for TestHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Ok(self.headers.clone())
|
||||
}
|
||||
}
|
||||
|
||||
// Create a test header provider with custom headers
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-Custom-Auth".to_string(), "test-token".to_string());
|
||||
headers.insert("X-Request-Id".to_string(), "test-123".to_string());
|
||||
let provider = Arc::new(TestHeaderProvider { headers }) as Arc<dyn HeaderProvider>;
|
||||
|
||||
// Create client config with the header provider
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(provider),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create connection with handler that verifies the headers are present
|
||||
let conn = Connection::new_with_handler_and_config(
|
||||
move |request| {
|
||||
// Verify that our custom headers are present
|
||||
assert_eq!(
|
||||
request.headers().get("X-Custom-Auth").unwrap(),
|
||||
"test-token"
|
||||
);
|
||||
assert_eq!(request.headers().get("X-Request-Id").unwrap(), "test-123");
|
||||
|
||||
// Also check standard headers are still there
|
||||
assert_eq!(request.method(), &reqwest::Method::GET);
|
||||
assert_eq!(request.url().path(), "/v1/table/");
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"tables": ["table1", "table2"]}"#)
|
||||
.unwrap()
|
||||
},
|
||||
client_config,
|
||||
);
|
||||
|
||||
// Make a request that should include the custom headers
|
||||
let names = conn.table_names().execute().await.unwrap();
|
||||
assert_eq!(names, vec!["table1", "table2"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_header_provider_error_handling() {
|
||||
// Test HeaderProvider that returns an error
|
||||
#[derive(Debug)]
|
||||
struct ErrorHeaderProvider;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl HeaderProvider for ErrorHeaderProvider {
|
||||
async fn get_headers(&self) -> crate::Result<HashMap<String, String>> {
|
||||
Err(crate::Error::Runtime {
|
||||
message: "Failed to fetch auth token".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let provider = Arc::new(ErrorHeaderProvider) as Arc<dyn HeaderProvider>;
|
||||
let client_config = ClientConfig {
|
||||
header_provider: Some(provider),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create connection - handler won't be called because header provider fails
|
||||
let conn = Connection::new_with_handler_and_config(
|
||||
move |_request| -> http::Response<&'static str> {
|
||||
panic!("Handler should not be called when header provider fails");
|
||||
},
|
||||
client_config,
|
||||
);
|
||||
|
||||
// Request should fail due to header provider error
|
||||
let result = conn.table_names().execute().await;
|
||||
assert!(result.is_err());
|
||||
|
||||
match result.unwrap_err() {
|
||||
crate::Error::Runtime { message } => {
|
||||
assert_eq!(message, "Failed to fetch auth token");
|
||||
}
|
||||
_ => panic!("Expected Runtime error from header provider"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ impl<S: HttpSend + 'static> Tags for RemoteTags<'_, S> {
|
||||
let request = self
|
||||
.inner
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/tags/list/", self.inner.name));
|
||||
.post(&format!("/v1/table/{}/tags/list/", self.inner.identifier));
|
||||
let (request_id, response) = self.inner.send(request, true).await?;
|
||||
let response = self
|
||||
.inner
|
||||
@@ -104,7 +104,10 @@ impl<S: HttpSend + 'static> Tags for RemoteTags<'_, S> {
|
||||
let request = self
|
||||
.inner
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/tags/version/", self.inner.name))
|
||||
.post(&format!(
|
||||
"/v1/table/{}/tags/version/",
|
||||
self.inner.identifier
|
||||
))
|
||||
.json(&serde_json::json!({ "tag": tag }));
|
||||
|
||||
let (request_id, response) = self.inner.send(request, true).await?;
|
||||
@@ -146,7 +149,7 @@ impl<S: HttpSend + 'static> Tags for RemoteTags<'_, S> {
|
||||
let request = self
|
||||
.inner
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/tags/create/", self.inner.name))
|
||||
.post(&format!("/v1/table/{}/tags/create/", self.inner.identifier))
|
||||
.json(&serde_json::json!({
|
||||
"tag": tag,
|
||||
"version": version
|
||||
@@ -163,7 +166,7 @@ impl<S: HttpSend + 'static> Tags for RemoteTags<'_, S> {
|
||||
let request = self
|
||||
.inner
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/tags/delete/", self.inner.name))
|
||||
.post(&format!("/v1/table/{}/tags/delete/", self.inner.identifier))
|
||||
.json(&serde_json::json!({ "tag": tag }));
|
||||
|
||||
let (request_id, response) = self.inner.send(request, true).await?;
|
||||
@@ -177,7 +180,7 @@ impl<S: HttpSend + 'static> Tags for RemoteTags<'_, S> {
|
||||
let request = self
|
||||
.inner
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/tags/update/", self.inner.name))
|
||||
.post(&format!("/v1/table/{}/tags/update/", self.inner.identifier))
|
||||
.json(&serde_json::json!({
|
||||
"tag": tag,
|
||||
"version": version
|
||||
@@ -196,6 +199,8 @@ pub struct RemoteTable<S: HttpSend = Sender> {
|
||||
#[allow(dead_code)]
|
||||
client: RestfulLanceDbClient<S>,
|
||||
name: String,
|
||||
namespace: Vec<String>,
|
||||
identifier: String,
|
||||
server_version: ServerVersion,
|
||||
|
||||
version: RwLock<Option<u64>>,
|
||||
@@ -205,11 +210,15 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
pub fn new(
|
||||
client: RestfulLanceDbClient<S>,
|
||||
name: String,
|
||||
namespace: Vec<String>,
|
||||
identifier: String,
|
||||
server_version: ServerVersion,
|
||||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
name,
|
||||
namespace,
|
||||
identifier,
|
||||
server_version,
|
||||
version: RwLock::new(None),
|
||||
}
|
||||
@@ -223,7 +232,7 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
async fn describe_version(&self, version: Option<u64>) -> Result<TableDescription> {
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/describe/", self.name));
|
||||
.post(&format!("/v1/table/{}/describe/", self.identifier));
|
||||
|
||||
let body = serde_json::json!({ "version": version });
|
||||
request = request.json(&body);
|
||||
@@ -334,7 +343,7 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
) -> Result<reqwest::Response> {
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
return Err(Error::TableNotFound {
|
||||
name: self.name.clone(),
|
||||
name: self.identifier.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -548,7 +557,9 @@ impl<S: HttpSend> RemoteTable<S> {
|
||||
query: &AnyQuery,
|
||||
options: &QueryExecutionOptions,
|
||||
) -> Result<Vec<Pin<Box<dyn RecordBatchStream + Send>>>> {
|
||||
let mut request = self.client.post(&format!("/v1/table/{}/query/", self.name));
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/query/", self.identifier));
|
||||
|
||||
if let Some(timeout) = options.timeout {
|
||||
// Also send to server, so it can abort the query if it takes too long.
|
||||
@@ -615,7 +626,7 @@ struct TableDescription {
|
||||
|
||||
impl<S: HttpSend> std::fmt::Display for RemoteTable<S> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "RemoteTable({})", self.name)
|
||||
write!(f, "RemoteTable({})", self.identifier)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -634,7 +645,9 @@ mod test_utils {
|
||||
let client = client_with_handler(handler);
|
||||
Self {
|
||||
client,
|
||||
name,
|
||||
name: name.clone(),
|
||||
namespace: vec![],
|
||||
identifier: name,
|
||||
server_version: version.map(ServerVersion).unwrap_or_default(),
|
||||
version: RwLock::new(None),
|
||||
}
|
||||
@@ -650,6 +663,14 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn namespace(&self) -> &[String] {
|
||||
&self.namespace
|
||||
}
|
||||
|
||||
fn id(&self) -> &str {
|
||||
&self.identifier
|
||||
}
|
||||
async fn version(&self) -> Result<u64> {
|
||||
self.describe().await.map(|desc| desc.version)
|
||||
}
|
||||
@@ -678,7 +699,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
async fn restore(&self) -> Result<()> {
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/restore/", self.name));
|
||||
.post(&format!("/v1/table/{}/restore/", self.identifier));
|
||||
let version = self.current_version().await;
|
||||
let body = serde_json::json!({ "version": version });
|
||||
request = request.json(&body);
|
||||
@@ -692,7 +713,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
async fn list_versions(&self) -> Result<Vec<Version>> {
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/version/list/", self.name));
|
||||
.post(&format!("/v1/table/{}/version/list/", self.identifier));
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
|
||||
@@ -723,7 +744,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
async fn count_rows(&self, filter: Option<Filter>) -> Result<usize> {
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/count_rows/", self.name));
|
||||
.post(&format!("/v1/table/{}/count_rows/", self.identifier));
|
||||
|
||||
let version = self.current_version().await;
|
||||
|
||||
@@ -759,7 +780,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
self.check_mutable().await?;
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/insert/", self.name))
|
||||
.post(&format!("/v1/table/{}/insert/", self.identifier))
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
|
||||
|
||||
match add.mode {
|
||||
@@ -831,7 +852,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
async fn explain_plan(&self, query: &AnyQuery, verbose: bool) -> Result<String> {
|
||||
let base_request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/explain_plan/", self.name));
|
||||
.post(&format!("/v1/table/{}/explain_plan/", self.identifier));
|
||||
|
||||
let query_bodies = self.prepare_query_bodies(query).await?;
|
||||
let requests: Vec<reqwest::RequestBuilder> = query_bodies
|
||||
@@ -880,7 +901,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
) -> Result<String> {
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/analyze_plan/", self.name));
|
||||
.post(&format!("/v1/table/{}/analyze_plan/", self.identifier));
|
||||
|
||||
let query_bodies = self.prepare_query_bodies(query).await?;
|
||||
let requests: Vec<reqwest::RequestBuilder> = query_bodies
|
||||
@@ -919,7 +940,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
self.check_mutable().await?;
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/update/", self.name));
|
||||
.post(&format!("/v1/table/{}/update/", self.identifier));
|
||||
|
||||
let mut updates = Vec::new();
|
||||
for (column, expression) in update.columns {
|
||||
@@ -958,7 +979,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
let body = serde_json::json!({ "predicate": predicate });
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/delete/", self.name))
|
||||
.post(&format!("/v1/table/{}/delete/", self.identifier))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
@@ -980,7 +1001,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
self.check_mutable().await?;
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/create_index/", self.name));
|
||||
.post(&format!("/v1/table/{}/create_index/", self.identifier));
|
||||
|
||||
let column = match index.columns.len() {
|
||||
0 => {
|
||||
@@ -999,6 +1020,18 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
"column": column
|
||||
});
|
||||
|
||||
// Add name parameter if provided (for backwards compatibility, only include if Some)
|
||||
if let Some(ref name) = index.name {
|
||||
body["name"] = serde_json::Value::String(name.clone());
|
||||
}
|
||||
|
||||
// Warn if train=false is specified since it's not meaningful
|
||||
if !index.train {
|
||||
log::warn!(
|
||||
"train=false has no effect remote tables. The index will be created empty and automatically populated in the background."
|
||||
);
|
||||
}
|
||||
|
||||
match index.index {
|
||||
// TODO: Should we pass the actual index parameters? SaaS does not
|
||||
// yet support them.
|
||||
@@ -1084,8 +1117,8 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
self.check_table_response(&request_id, response).await?;
|
||||
|
||||
if let Some(wait_timeout) = index.wait_timeout {
|
||||
let name = format!("{}_idx", column);
|
||||
self.wait_for_index(&[&name], wait_timeout).await?;
|
||||
let index_name = index.name.unwrap_or_else(|| format!("{}_idx", column));
|
||||
self.wait_for_index(&[&index_name], wait_timeout).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -1109,7 +1142,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
let query = MergeInsertRequest::try_from(params)?;
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/merge_insert/", self.name))
|
||||
.post(&format!("/v1/table/{}/merge_insert/", self.identifier))
|
||||
.query(&query)
|
||||
.header(CONTENT_TYPE, ARROW_STREAM_CONTENT_TYPE);
|
||||
|
||||
@@ -1181,7 +1214,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
let body = serde_json::json!({ "new_columns": body });
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/add_columns/", self.name))
|
||||
.post(&format!("/v1/table/{}/add_columns/", self.identifier))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
@@ -1234,7 +1267,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
let body = serde_json::json!({ "alterations": body });
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/alter_columns/", self.name))
|
||||
.post(&format!("/v1/table/{}/alter_columns/", self.identifier))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
@@ -1259,7 +1292,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
let body = serde_json::json!({ "columns": columns });
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/drop_columns/", self.name))
|
||||
.post(&format!("/v1/table/{}/drop_columns/", self.identifier))
|
||||
.json(&body);
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
@@ -1283,7 +1316,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
// Make request to list the indices
|
||||
let mut request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/index/list/", self.name));
|
||||
.post(&format!("/v1/table/{}/index/list/", self.identifier));
|
||||
let version = self.current_version().await;
|
||||
let body = serde_json::json!({ "version": version });
|
||||
request = request.json(&body);
|
||||
@@ -1339,7 +1372,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
async fn index_stats(&self, index_name: &str) -> Result<Option<IndexStatistics>> {
|
||||
let mut request = self.client.post(&format!(
|
||||
"/v1/table/{}/index/{}/stats/",
|
||||
self.name, index_name
|
||||
self.identifier, index_name
|
||||
));
|
||||
let version = self.current_version().await;
|
||||
let body = serde_json::json!({ "version": version });
|
||||
@@ -1367,7 +1400,7 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
async fn drop_index(&self, index_name: &str) -> Result<()> {
|
||||
let request = self.client.post(&format!(
|
||||
"/v1/table/{}/index/{}/drop/",
|
||||
self.name, index_name
|
||||
self.identifier, index_name
|
||||
));
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
if response.status() == StatusCode::NOT_FOUND {
|
||||
@@ -1395,7 +1428,9 @@ impl<S: HttpSend> BaseTable for RemoteTable<S> {
|
||||
}
|
||||
|
||||
async fn stats(&self) -> Result<TableStatistics> {
|
||||
let request = self.client.post(&format!("/v1/table/{}/stats/", self.name));
|
||||
let request = self
|
||||
.client
|
||||
.post(&format!("/v1/table/{}/stats/", self.identifier));
|
||||
let (request_id, response) = self.send(request, true).await?;
|
||||
let response = self.check_table_response(&request_id, response).await?;
|
||||
let body = response.text().await.err_to_http(request_id.clone())?;
|
||||
@@ -3070,4 +3105,174 @@ mod tests {
|
||||
});
|
||||
table
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_with_namespace_identifier() {
|
||||
// Test that a table created with namespace uses the correct identifier in API calls
|
||||
let table = Table::new_with_handler("ns1$ns2$table1", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
// All API calls should use the full identifier in the path
|
||||
assert_eq!(request.url().path(), "/v1/table/ns1$ns2$table1/describe/");
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"version": 1, "schema": { "fields": [] }}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
// The name() method should return just the base name, not the full identifier
|
||||
assert_eq!(table.name(), "ns1$ns2$table1");
|
||||
|
||||
// API operations should work correctly
|
||||
let version = table.version().await.unwrap();
|
||||
assert_eq!(version, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_query_with_namespace() {
|
||||
let table = Table::new_with_handler("analytics$events", |request| {
|
||||
match request.url().path() {
|
||||
"/v1/table/analytics$events/query/" => {
|
||||
assert_eq!(request.method(), "POST");
|
||||
|
||||
// Return empty arrow stream
|
||||
let data = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
let body = write_ipc_file(&data);
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.header("Content-Type", ARROW_FILE_CONTENT_TYPE)
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
_ => {
|
||||
panic!("Unexpected path: {}", request.url().path());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let results = table.query().execute().await.unwrap();
|
||||
let batches = results.try_collect::<Vec<_>>().await.unwrap();
|
||||
assert_eq!(batches.len(), 1);
|
||||
assert_eq!(batches[0].num_rows(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_data_with_namespace() {
|
||||
let data = RecordBatch::try_new(
|
||||
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
|
||||
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
let table = Table::new_with_handler("prod$metrics", move |mut request| {
|
||||
if request.url().path() == "/v1/table/prod$metrics/insert/" {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
ARROW_STREAM_CONTENT_TYPE
|
||||
);
|
||||
let mut body_out = reqwest::Body::from(Vec::new());
|
||||
std::mem::swap(request.body_mut().as_mut().unwrap(), &mut body_out);
|
||||
sender.send(body_out).unwrap();
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"version": 2}"#)
|
||||
.unwrap()
|
||||
} else {
|
||||
panic!("Unexpected request path: {}", request.url().path());
|
||||
}
|
||||
});
|
||||
|
||||
let result = table
|
||||
.add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.version, 2);
|
||||
|
||||
let body = receiver.recv().unwrap();
|
||||
let body = collect_body(body).await;
|
||||
let expected_body = write_ipc_stream(&data);
|
||||
assert_eq!(&body, &expected_body);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_index_with_namespace() {
|
||||
let table = Table::new_with_handler("dev$users", |request| {
|
||||
match request.url().path() {
|
||||
"/v1/table/dev$users/create_index/" => {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
// Verify the request body contains the column name
|
||||
if let Some(body) = request.body().unwrap().as_bytes() {
|
||||
let body = std::str::from_utf8(body).unwrap();
|
||||
let value: serde_json::Value = serde_json::from_str(body).unwrap();
|
||||
assert_eq!(value["column"], "embedding");
|
||||
assert_eq!(value["index_type"], "IVF_PQ");
|
||||
}
|
||||
|
||||
http::Response::builder().status(200).body("").unwrap()
|
||||
}
|
||||
"/v1/table/dev$users/describe/" => {
|
||||
// Needed for schema check in Auto index type
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"version": 1, "schema": {"fields": [{"name": "embedding", "type": {"type": "list", "item": {"type": "float32"}}, "nullable": false}]}}"#)
|
||||
.unwrap()
|
||||
}
|
||||
_ => {
|
||||
panic!("Unexpected path: {}", request.url().path());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
table
|
||||
.create_index(&["embedding"], Index::IvfPq(IvfPqIndexBuilder::default()))
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_drop_columns_with_namespace() {
|
||||
let table = Table::new_with_handler("test$schema_ops", |request| {
|
||||
assert_eq!(request.method(), "POST");
|
||||
assert_eq!(
|
||||
request.url().path(),
|
||||
"/v1/table/test$schema_ops/drop_columns/"
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers().get("Content-Type").unwrap(),
|
||||
JSON_CONTENT_TYPE
|
||||
);
|
||||
|
||||
if let Some(body) = request.body().unwrap().as_bytes() {
|
||||
let body = std::str::from_utf8(body).unwrap();
|
||||
let value: serde_json::Value = serde_json::from_str(body).unwrap();
|
||||
let columns = value["columns"].as_array().unwrap();
|
||||
assert_eq!(columns.len(), 2);
|
||||
assert_eq!(columns[0], "old_col1");
|
||||
assert_eq!(columns[1], "old_col2");
|
||||
}
|
||||
|
||||
http::Response::builder()
|
||||
.status(200)
|
||||
.body(r#"{"version": 5}"#)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let result = table.drop_columns(&["old_col1", "old_col2"]).await.unwrap();
|
||||
assert_eq!(result.version, 5);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,9 +28,11 @@ use lance::dataset::{
|
||||
};
|
||||
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
|
||||
use lance::index::vector::utils::infer_vector_dim;
|
||||
use lance::index::vector::VectorIndexParams;
|
||||
use lance::io::WrappingObjectStore;
|
||||
use lance_datafusion::exec::{analyze_plan as lance_analyze_plan, execute_plan};
|
||||
use lance_datafusion::utils::StreamingWriteSource;
|
||||
use lance_index::scalar::{BuiltinIndexType, ScalarIndexParams};
|
||||
use lance_index::vector::hnsw::builder::HnswBuildParams;
|
||||
use lance_index::vector::ivf::IvfBuildParams;
|
||||
use lance_index::vector::pq::PQBuildParams;
|
||||
@@ -50,11 +52,7 @@ use crate::arrow::IntoArrow;
|
||||
use crate::connection::NoData;
|
||||
use crate::embeddings::{EmbeddingDefinition, EmbeddingRegistry, MaybeEmbedded, MemoryRegistry};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::index::scalar::FtsIndexBuilder;
|
||||
use crate::index::vector::{
|
||||
suggested_num_partitions_for_hnsw, IvfFlatIndexBuilder, IvfHnswPqIndexBuilder,
|
||||
IvfHnswSqIndexBuilder, IvfPqIndexBuilder, VectorIndex,
|
||||
};
|
||||
use crate::index::vector::{suggested_num_partitions_for_hnsw, VectorIndex};
|
||||
use crate::index::IndexStatistics;
|
||||
use crate::index::{
|
||||
vector::{suggested_num_partitions, suggested_num_sub_vectors},
|
||||
@@ -511,6 +509,10 @@ pub trait BaseTable: std::fmt::Display + std::fmt::Debug + Send + Sync {
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
/// Get the name of the table.
|
||||
fn name(&self) -> &str;
|
||||
/// Get the namespace of the table.
|
||||
fn namespace(&self) -> &[String];
|
||||
/// Get the id of the table
|
||||
fn id(&self) -> &str;
|
||||
/// Get the arrow [Schema] of the table.
|
||||
async fn schema(&self) -> Result<SchemaRef>;
|
||||
/// Count the number of rows in this table.
|
||||
@@ -1698,345 +1700,219 @@ impl NativeTable {
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn create_ivf_flat_index(
|
||||
&self,
|
||||
index: IvfFlatIndexBuilder,
|
||||
// Helper to validate index type compatibility with field data type
|
||||
fn validate_index_type(
|
||||
field: &Field,
|
||||
replace: bool,
|
||||
index_name: &str,
|
||||
supported_fn: impl Fn(&DataType) -> bool,
|
||||
) -> Result<()> {
|
||||
if !supported_vector_data_type(field.data_type()) {
|
||||
return Err(Error::InvalidInput {
|
||||
if !supported_fn(field.data_type()) {
|
||||
return Err(Error::Schema {
|
||||
message: format!(
|
||||
"An IVF Flat index cannot be created on the column `{}` which has data type {}",
|
||||
"A {} index cannot be created on the field `{}` which has data type {}",
|
||||
index_name,
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let num_partitions = if let Some(n) = index.num_partitions {
|
||||
n
|
||||
} else {
|
||||
suggested_num_partitions(self.count_rows(None).await?)
|
||||
};
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_flat(
|
||||
num_partitions as usize,
|
||||
index.distance_type.into(),
|
||||
);
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::Vector,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_ivf_pq_index(
|
||||
// Helper to get num_partitions with default calculation
|
||||
async fn get_num_partitions(
|
||||
&self,
|
||||
index: IvfPqIndexBuilder,
|
||||
field: &Field,
|
||||
replace: bool,
|
||||
) -> Result<()> {
|
||||
if !supported_vector_data_type(field.data_type()) {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"An IVF PQ index cannot be created on the column `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
provided: Option<u32>,
|
||||
for_hnsw: bool,
|
||||
dim: Option<u32>,
|
||||
) -> Result<u32> {
|
||||
if let Some(n) = provided {
|
||||
Ok(n)
|
||||
} else {
|
||||
let row_count = self.count_rows(None).await?;
|
||||
if for_hnsw {
|
||||
Ok(suggested_num_partitions_for_hnsw(
|
||||
row_count,
|
||||
dim.ok_or_else(|| Error::InvalidInput {
|
||||
message: "Vector dimension required for HNSW partitioning".to_string(),
|
||||
})?,
|
||||
))
|
||||
} else {
|
||||
Ok(suggested_num_partitions(row_count))
|
||||
}
|
||||
}
|
||||
|
||||
let num_partitions = if let Some(n) = index.num_partitions {
|
||||
n
|
||||
} else {
|
||||
suggested_num_partitions(self.count_rows(None).await?)
|
||||
};
|
||||
let num_sub_vectors: u32 = if let Some(n) = index.num_sub_vectors {
|
||||
n
|
||||
} else {
|
||||
let dim = infer_vector_dim(field.data_type())?;
|
||||
suggested_num_sub_vectors(dim as u32)
|
||||
};
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq(
|
||||
num_partitions as usize,
|
||||
/*num_bits=*/ 8,
|
||||
num_sub_vectors as usize,
|
||||
index.distance_type.into(),
|
||||
index.max_iterations as usize,
|
||||
);
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::Vector,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_ivf_hnsw_pq_index(
|
||||
&self,
|
||||
index: IvfHnswPqIndexBuilder,
|
||||
field: &Field,
|
||||
replace: bool,
|
||||
) -> Result<()> {
|
||||
if !supported_vector_data_type(field.data_type()) {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"An IVF HNSW PQ index cannot be created on the column `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
// Helper to get num_sub_vectors with default calculation
|
||||
fn get_num_sub_vectors(provided: Option<u32>, dim: u32) -> u32 {
|
||||
provided.unwrap_or_else(|| suggested_num_sub_vectors(dim))
|
||||
}
|
||||
|
||||
// Helper to extract vector dimension from field
|
||||
fn get_vector_dimension(field: &Field) -> Result<u32> {
|
||||
match field.data_type() {
|
||||
arrow_schema::DataType::FixedSizeList(_, n) => Ok(*n as u32),
|
||||
_ => Ok(infer_vector_dim(field.data_type())? as u32),
|
||||
}
|
||||
}
|
||||
|
||||
let num_partitions: u32 = if let Some(n) = index.num_partitions {
|
||||
n
|
||||
} else {
|
||||
match field.data_type() {
|
||||
arrow_schema::DataType::FixedSizeList(_, n) => Ok::<u32, Error>(
|
||||
suggested_num_partitions_for_hnsw(self.count_rows(None).await?, *n as u32),
|
||||
),
|
||||
_ => Err(Error::Schema {
|
||||
message: format!("Column '{}' is not a FixedSizeList", field.name()),
|
||||
}),
|
||||
}?
|
||||
};
|
||||
|
||||
let num_sub_vectors: u32 = if let Some(n) = index.num_sub_vectors {
|
||||
n
|
||||
} else {
|
||||
match field.data_type() {
|
||||
arrow_schema::DataType::FixedSizeList(_, n) => {
|
||||
Ok::<u32, Error>(suggested_num_sub_vectors(*n as u32))
|
||||
// Convert LanceDB Index to Lance IndexParams
|
||||
async fn make_index_params(
|
||||
&self,
|
||||
field: &Field,
|
||||
index_opts: Index,
|
||||
) -> Result<Box<dyn lance::index::IndexParams>> {
|
||||
match index_opts {
|
||||
Index::Auto => {
|
||||
if supported_vector_data_type(field.data_type()) {
|
||||
// Use IvfPq as the default for auto vector indices
|
||||
let dim = Self::get_vector_dimension(field)?;
|
||||
let num_partitions = self.get_num_partitions(None, false, None).await?;
|
||||
let num_sub_vectors = Self::get_num_sub_vectors(None, dim);
|
||||
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq(
|
||||
num_partitions as usize,
|
||||
/*num_bits=*/ 8,
|
||||
num_sub_vectors as usize,
|
||||
lance_linalg::distance::MetricType::L2,
|
||||
/*max_iterations=*/ 50,
|
||||
);
|
||||
Ok(Box::new(lance_idx_params))
|
||||
} else if supported_btree_data_type(field.data_type()) {
|
||||
Ok(Box::new(ScalarIndexParams::for_builtin(
|
||||
BuiltinIndexType::BTree,
|
||||
)))
|
||||
} else {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"there are no indices supported for the field `{}` with the data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
_ => Err(Error::Schema {
|
||||
message: format!("Column '{}' is not a FixedSizeList", field.name()),
|
||||
}),
|
||||
}?
|
||||
};
|
||||
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
|
||||
ivf_params.sample_rate = index.sample_rate as usize;
|
||||
ivf_params.max_iters = index.max_iterations as usize;
|
||||
let hnsw_params = HnswBuildParams::default()
|
||||
.num_edges(index.m as usize)
|
||||
.ef_construction(index.ef_construction as usize);
|
||||
let pq_params = PQBuildParams {
|
||||
num_sub_vectors: num_sub_vectors as usize,
|
||||
..Default::default()
|
||||
};
|
||||
let lance_idx_params = lance::index::vector::VectorIndexParams::with_ivf_hnsw_pq_params(
|
||||
index.distance_type.into(),
|
||||
ivf_params,
|
||||
hnsw_params,
|
||||
pq_params,
|
||||
);
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::Vector,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_ivf_hnsw_sq_index(
|
||||
&self,
|
||||
index: IvfHnswSqIndexBuilder,
|
||||
field: &Field,
|
||||
replace: bool,
|
||||
) -> Result<()> {
|
||||
if !supported_vector_data_type(field.data_type()) {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"An IVF HNSW SQ index cannot be created on the column `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let num_partitions: u32 = if let Some(n) = index.num_partitions {
|
||||
n
|
||||
} else {
|
||||
match field.data_type() {
|
||||
arrow_schema::DataType::FixedSizeList(_, n) => Ok::<u32, Error>(
|
||||
suggested_num_partitions_for_hnsw(self.count_rows(None).await?, *n as u32),
|
||||
),
|
||||
_ => Err(Error::Schema {
|
||||
message: format!("Column '{}' is not a FixedSizeList", field.name()),
|
||||
}),
|
||||
}?
|
||||
};
|
||||
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
|
||||
ivf_params.sample_rate = index.sample_rate as usize;
|
||||
ivf_params.max_iters = index.max_iterations as usize;
|
||||
let hnsw_params = HnswBuildParams::default()
|
||||
.num_edges(index.m as usize)
|
||||
.ef_construction(index.ef_construction as usize);
|
||||
let sq_params = SQBuildParams {
|
||||
sample_rate: index.sample_rate as usize,
|
||||
..Default::default()
|
||||
};
|
||||
let lance_idx_params = lance::index::vector::VectorIndexParams::with_ivf_hnsw_sq_params(
|
||||
index.distance_type.into(),
|
||||
ivf_params,
|
||||
hnsw_params,
|
||||
sq_params,
|
||||
);
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::Vector,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_auto_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
|
||||
if supported_vector_data_type(field.data_type()) {
|
||||
self.create_ivf_pq_index(IvfPqIndexBuilder::default(), field, opts.replace)
|
||||
.await
|
||||
} else if supported_btree_data_type(field.data_type()) {
|
||||
self.create_btree_index(field, opts).await
|
||||
} else {
|
||||
Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"there are no indices supported for the field `{}` with the data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
})
|
||||
}
|
||||
Index::BTree(_) => {
|
||||
Self::validate_index_type(field, "BTree", supported_btree_data_type)?;
|
||||
Ok(Box::new(ScalarIndexParams::for_builtin(
|
||||
BuiltinIndexType::BTree,
|
||||
)))
|
||||
}
|
||||
Index::Bitmap(_) => {
|
||||
Self::validate_index_type(field, "Bitmap", supported_bitmap_data_type)?;
|
||||
Ok(Box::new(ScalarIndexParams::for_builtin(
|
||||
BuiltinIndexType::Bitmap,
|
||||
)))
|
||||
}
|
||||
Index::LabelList(_) => {
|
||||
Self::validate_index_type(field, "LabelList", supported_label_list_data_type)?;
|
||||
Ok(Box::new(ScalarIndexParams::for_builtin(
|
||||
BuiltinIndexType::LabelList,
|
||||
)))
|
||||
}
|
||||
Index::FTS(fts_opts) => {
|
||||
Self::validate_index_type(field, "FTS", supported_fts_data_type)?;
|
||||
Ok(Box::new(fts_opts))
|
||||
}
|
||||
Index::IvfFlat(index) => {
|
||||
Self::validate_index_type(field, "IVF Flat", supported_vector_data_type)?;
|
||||
let num_partitions = self
|
||||
.get_num_partitions(index.num_partitions, false, None)
|
||||
.await?;
|
||||
let lance_idx_params = VectorIndexParams::ivf_flat(
|
||||
num_partitions as usize,
|
||||
index.distance_type.into(),
|
||||
);
|
||||
Ok(Box::new(lance_idx_params))
|
||||
}
|
||||
Index::IvfPq(index) => {
|
||||
Self::validate_index_type(field, "IVF PQ", supported_vector_data_type)?;
|
||||
let dim = Self::get_vector_dimension(field)?;
|
||||
let num_partitions = self
|
||||
.get_num_partitions(index.num_partitions, false, None)
|
||||
.await?;
|
||||
let num_sub_vectors = Self::get_num_sub_vectors(index.num_sub_vectors, dim);
|
||||
let lance_idx_params = VectorIndexParams::ivf_pq(
|
||||
num_partitions as usize,
|
||||
/*num_bits=*/ 8,
|
||||
num_sub_vectors as usize,
|
||||
index.distance_type.into(),
|
||||
index.max_iterations as usize,
|
||||
);
|
||||
Ok(Box::new(lance_idx_params))
|
||||
}
|
||||
Index::IvfHnswPq(index) => {
|
||||
Self::validate_index_type(field, "IVF HNSW PQ", supported_vector_data_type)?;
|
||||
let dim = Self::get_vector_dimension(field)?;
|
||||
let num_partitions = self
|
||||
.get_num_partitions(index.num_partitions, true, Some(dim))
|
||||
.await?;
|
||||
let num_sub_vectors = Self::get_num_sub_vectors(index.num_sub_vectors, dim);
|
||||
let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
|
||||
ivf_params.sample_rate = index.sample_rate as usize;
|
||||
ivf_params.max_iters = index.max_iterations as usize;
|
||||
let hnsw_params = HnswBuildParams::default()
|
||||
.num_edges(index.m as usize)
|
||||
.ef_construction(index.ef_construction as usize);
|
||||
let pq_params = PQBuildParams {
|
||||
num_sub_vectors: num_sub_vectors as usize,
|
||||
..Default::default()
|
||||
};
|
||||
let lance_idx_params = VectorIndexParams::with_ivf_hnsw_pq_params(
|
||||
index.distance_type.into(),
|
||||
ivf_params,
|
||||
hnsw_params,
|
||||
pq_params,
|
||||
);
|
||||
Ok(Box::new(lance_idx_params))
|
||||
}
|
||||
Index::IvfHnswSq(index) => {
|
||||
Self::validate_index_type(field, "IVF HNSW SQ", supported_vector_data_type)?;
|
||||
let dim = Self::get_vector_dimension(field)?;
|
||||
let num_partitions = self
|
||||
.get_num_partitions(index.num_partitions, true, Some(dim))
|
||||
.await?;
|
||||
let mut ivf_params = IvfBuildParams::new(num_partitions as usize);
|
||||
ivf_params.sample_rate = index.sample_rate as usize;
|
||||
ivf_params.max_iters = index.max_iterations as usize;
|
||||
let hnsw_params = HnswBuildParams::default()
|
||||
.num_edges(index.m as usize)
|
||||
.ef_construction(index.ef_construction as usize);
|
||||
let sq_params = SQBuildParams {
|
||||
sample_rate: index.sample_rate as usize,
|
||||
..Default::default()
|
||||
};
|
||||
let lance_idx_params = VectorIndexParams::with_ivf_hnsw_sq_params(
|
||||
index.distance_type.into(),
|
||||
ivf_params,
|
||||
hnsw_params,
|
||||
sq_params,
|
||||
);
|
||||
Ok(Box::new(lance_idx_params))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_btree_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
|
||||
if !supported_btree_data_type(field.data_type()) {
|
||||
return Err(Error::Schema {
|
||||
message: format!(
|
||||
"A BTree index cannot be created on the field `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
// Helper method to get the correct IndexType based on the Index variant and field data type
|
||||
fn get_index_type_for_field(&self, field: &Field, index: &Index) -> IndexType {
|
||||
match index {
|
||||
Index::Auto => {
|
||||
if supported_vector_data_type(field.data_type()) {
|
||||
IndexType::Vector
|
||||
} else if supported_btree_data_type(field.data_type()) {
|
||||
IndexType::BTree
|
||||
} else {
|
||||
// This should not happen since make_index_params would have failed
|
||||
IndexType::BTree
|
||||
}
|
||||
}
|
||||
Index::BTree(_) => IndexType::BTree,
|
||||
Index::Bitmap(_) => IndexType::Bitmap,
|
||||
Index::LabelList(_) => IndexType::LabelList,
|
||||
Index::FTS(_) => IndexType::Inverted,
|
||||
Index::IvfFlat(_) | Index::IvfPq(_) | Index::IvfHnswPq(_) | Index::IvfHnswSq(_) => {
|
||||
IndexType::Vector
|
||||
}
|
||||
}
|
||||
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let lance_idx_params = lance_index::scalar::ScalarIndexParams {
|
||||
force_index_type: Some(lance_index::scalar::ScalarIndexType::BTree),
|
||||
};
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::BTree,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
opts.replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_bitmap_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
|
||||
if !supported_bitmap_data_type(field.data_type()) {
|
||||
return Err(Error::Schema {
|
||||
message: format!(
|
||||
"A Bitmap index cannot be created on the field `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let lance_idx_params = lance_index::scalar::ScalarIndexParams {
|
||||
force_index_type: Some(lance_index::scalar::ScalarIndexType::Bitmap),
|
||||
};
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::Bitmap,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
opts.replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_label_list_index(&self, field: &Field, opts: IndexBuilder) -> Result<()> {
|
||||
if !supported_label_list_data_type(field.data_type()) {
|
||||
return Err(Error::Schema {
|
||||
message: format!(
|
||||
"A LabelList index cannot be created on the field `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let lance_idx_params = lance_index::scalar::ScalarIndexParams {
|
||||
force_index_type: Some(lance_index::scalar::ScalarIndexType::LabelList),
|
||||
};
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::LabelList,
|
||||
None,
|
||||
&lance_idx_params,
|
||||
opts.replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_fts_index(
|
||||
&self,
|
||||
field: &Field,
|
||||
fts_opts: FtsIndexBuilder,
|
||||
replace: bool,
|
||||
) -> Result<()> {
|
||||
if !supported_fts_data_type(field.data_type()) {
|
||||
return Err(Error::Schema {
|
||||
message: format!(
|
||||
"A FTS index cannot be created on the field `{}` which has data type {}",
|
||||
field.name(),
|
||||
field.data_type()
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
dataset
|
||||
.create_index(
|
||||
&[field.name()],
|
||||
IndexType::Inverted,
|
||||
None,
|
||||
&fts_opts,
|
||||
replace,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn generic_query(
|
||||
@@ -2143,6 +2019,16 @@ impl BaseTable for NativeTable {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
/// Namespace path of this table.
fn namespace(&self) -> &[String] {
    // Native tables don't support namespaces yet, return empty slice for root namespace
    &[]
}
|
||||
|
||||
/// Unique identifier of this table.
fn id(&self) -> &str {
    // For native tables, id is same as name since no namespace support
    self.name.as_str()
}
|
||||
|
||||
/// Current version number of the underlying Lance dataset.
async fn version(&self) -> Result<u64> {
    let dataset = self.dataset.get().await?;
    Ok(dataset.version().version)
}
|
||||
@@ -2251,26 +2137,20 @@ impl BaseTable for NativeTable {
|
||||
|
||||
let field = schema.field_with_name(&opts.columns[0])?;
|
||||
|
||||
match opts.index {
|
||||
Index::Auto => self.create_auto_index(field, opts).await,
|
||||
Index::BTree(_) => self.create_btree_index(field, opts).await,
|
||||
Index::Bitmap(_) => self.create_bitmap_index(field, opts).await,
|
||||
Index::LabelList(_) => self.create_label_list_index(field, opts).await,
|
||||
Index::FTS(fts_opts) => self.create_fts_index(field, fts_opts, opts.replace).await,
|
||||
Index::IvfFlat(ivf_flat) => {
|
||||
self.create_ivf_flat_index(ivf_flat, field, opts.replace)
|
||||
.await
|
||||
}
|
||||
Index::IvfPq(ivf_pq) => self.create_ivf_pq_index(ivf_pq, field, opts.replace).await,
|
||||
Index::IvfHnswPq(ivf_hnsw_pq) => {
|
||||
self.create_ivf_hnsw_pq_index(ivf_hnsw_pq, field, opts.replace)
|
||||
.await
|
||||
}
|
||||
Index::IvfHnswSq(ivf_hnsw_sq) => {
|
||||
self.create_ivf_hnsw_sq_index(ivf_hnsw_sq, field, opts.replace)
|
||||
.await
|
||||
}
|
||||
let lance_idx_params = self.make_index_params(field, opts.index.clone()).await?;
|
||||
let index_type = self.get_index_type_for_field(field, &opts.index);
|
||||
let columns = [field.name().as_str()];
|
||||
let mut dataset = self.dataset.get_mut().await?;
|
||||
let mut builder = dataset
|
||||
.create_index_builder(&columns, index_type, lance_idx_params.as_ref())
|
||||
.train(opts.train)
|
||||
.replace(opts.replace);
|
||||
|
||||
if let Some(name) = opts.name {
|
||||
builder = builder.name(name);
|
||||
}
|
||||
builder.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn drop_index(&self, index_name: &str) -> Result<()> {
|
||||
@@ -2890,6 +2770,7 @@ mod tests {
|
||||
use crate::connect;
|
||||
use crate::connection::ConnectBuilder;
|
||||
use crate::index::scalar::{BTreeIndexBuilder, BitmapIndexBuilder};
|
||||
use crate::index::vector::{IvfHnswPqIndexBuilder, IvfHnswSqIndexBuilder};
|
||||
use crate::query::{ExecutableQuery, QueryBase};
|
||||
|
||||
#[tokio::test]
|
||||
@@ -3391,6 +3272,7 @@ mod tests {
|
||||
fn wrap(
|
||||
&self,
|
||||
original: Arc<dyn object_store::ObjectStore>,
|
||||
_storage_options: Option<&std::collections::HashMap<String, String>>,
|
||||
) -> Arc<dyn object_store::ObjectStore> {
|
||||
self.called.store(true, Ordering::Relaxed);
|
||||
original
|
||||
|
||||
@@ -121,6 +121,10 @@ impl ExecutionPlan for MetadataEraserExec {
|
||||
as SendableRecordBatchStream,
|
||||
)
|
||||
}
|
||||
|
||||
/// Forward per-partition statistics unchanged from the wrapped input;
/// erasing schema metadata does not affect row counts or column stats.
fn partition_statistics(&self, partition: Option<usize>) -> DataFusionResult<Statistics> {
    self.input.partition_statistics(partition)
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -227,6 +231,7 @@ pub mod tests {
|
||||
prelude::{SessionConfig, SessionContext},
|
||||
};
|
||||
use datafusion_catalog::TableProvider;
|
||||
use datafusion_common::stats::Precision;
|
||||
use datafusion_execution::SendableRecordBatchStream;
|
||||
use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
@@ -495,6 +500,7 @@ pub mod tests {
|
||||
plan,
|
||||
"MetadataEraserExec
|
||||
ProjectionExec:...
|
||||
CooperativeExec...
|
||||
LanceRead:...",
|
||||
)
|
||||
.await;
|
||||
@@ -509,4 +515,24 @@ pub mod tests {
|
||||
|
||||
TestFixture::check_plan(plan, "").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_metadata_eraser_propagates_statistics() {
|
||||
let fixture = TestFixture::new().await;
|
||||
|
||||
let plan =
|
||||
LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter.clone()), None)
|
||||
.unwrap()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let ctx = SessionContext::new();
|
||||
let physical_plan = ctx.state().create_physical_plan(&plan).await.unwrap();
|
||||
|
||||
assert_eq!(physical_plan.name(), "MetadataEraserExec");
|
||||
|
||||
let partition_stats = physical_plan.partition_statistics(None).unwrap();
|
||||
|
||||
assert!(matches!(partition_stats.num_rows, Precision::Exact(10)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ async fn test_minio_lifecycle() -> Result<()> {
|
||||
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||
table.add(data).execute().await?;
|
||||
|
||||
db.drop_table("test_table").await?;
|
||||
db.drop_table("test_table", &[]).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user