feat: add native OAuth/OIDC authentication support

Add OAuthConfig and OAuthHeaderProvider to the Rust core with support
for five OAuth flows: ClientCredentials, AuthorizationCodePKCE,
DeviceCode, AzureManagedIdentity, and WorkloadIdentity. Token
acquisition and auto-refresh happen entirely in Rust.

Python and TypeScript expose OAuthConfig as a plain config object that
maps to the Rust header provider via FFI — no dynamic callbacks cross
the language boundary.

ConnectBuilder gains an oauth_config() method that replaces the API key
requirement when OAuth is configured.
This commit is contained in:
Jack Ye
2026-05-12 12:53:19 -07:00
parent 650f173236
commit b4f2300f80
58 changed files with 1859 additions and 449 deletions

View File

@@ -44,7 +44,7 @@ jobs:
lfs: true
- uses: actions/setup-node@v4
with:
node-version: 20
node-version: 22
cache: 'npm'
cache-dependency-path: nodejs/package-lock.json
- uses: actions-rust-lang/setup-rust-toolchain@v1
@@ -71,7 +71,7 @@ jobs:
timeout-minutes: 30
strategy:
matrix:
node-version: [ "18", "20" ]
node-version: [ "20", "22" ]
runs-on: "ubuntu-22.04"
defaults:
run:
@@ -83,11 +83,9 @@ jobs:
fetch-depth: 0
lfs: true
- uses: actions/setup-node@v4
name: Setup Node.js 20 for build
name: Setup Node.js 22 for build
with:
# @napi-rs/cli v3 requires Node >= 20.12 (via @inquirer/prompts@8).
# Build always on Node 20; tests run on the matrix version below.
node-version: 20
node-version: 22
cache: 'npm'
cache-dependency-path: nodejs/package-lock.json
- uses: Swatinem/rust-cache@v2
@@ -150,7 +148,7 @@ jobs:
lfs: true
- uses: actions/setup-node@v4
with:
node-version: 20
node-version: 22
cache: 'npm'
cache-dependency-path: nodejs/package-lock.json
- uses: Swatinem/rust-cache@v2

80
Cargo.lock generated
View File

@@ -108,7 +108,7 @@ version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -119,7 +119,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -1648,7 +1648,7 @@ version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -2826,7 +2826,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -3087,7 +3087,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -3996,7 +3996,7 @@ dependencies = [
"libc",
"percent-encoding",
"pin-project-lite",
"socket2 0.5.10",
"socket2 0.6.3",
"system-configuration",
"tokio",
"tower-service",
@@ -4233,6 +4233,25 @@ dependencies = [
"serde",
]
[[package]]
name = "is-docker"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3"
dependencies = [
"once_cell",
]
[[package]]
name = "is-wsl"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5"
dependencies = [
"is-docker",
"once_cell",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.2"
@@ -4323,7 +4342,7 @@ dependencies = [
"portable-atomic-util",
"serde_core",
"wasm-bindgen",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -5025,6 +5044,7 @@ dependencies = [
"aws-sdk-kms",
"aws-sdk-s3",
"aws-smithy-runtime",
"base64 0.22.1",
"bytes",
"candle-core",
"candle-nn",
@@ -5063,6 +5083,7 @@ dependencies = [
"moka",
"num-traits",
"object_store",
"open",
"pin-project",
"polars",
"polars-arrow",
@@ -5075,12 +5096,14 @@ dependencies = [
"serde",
"serde_json",
"serde_with",
"sha2",
"snafu 0.8.9",
"tempfile",
"test-log",
"tokenizers",
"tokio",
"url",
"urlencoding",
"uuid",
"walkdir",
]
@@ -5804,7 +5827,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -6029,6 +6052,17 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "open"
version = "5.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fbaa89d2ddc8473c78a3adf69eea8cffa28c483b8e02a971ef31527cd0fc92c"
dependencies = [
"is-wsl",
"libc",
"pathdiff",
]
[[package]]
name = "opendal"
version = "0.56.0"
@@ -6365,6 +6399,12 @@ dependencies = [
"stfu8",
]
[[package]]
name = "pathdiff"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
name = "pbkdf2"
version = "0.12.2"
@@ -7046,8 +7086,8 @@ version = "0.14.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7"
dependencies = [
"heck 0.4.1",
"itertools 0.11.0",
"heck 0.5.0",
"itertools 0.14.0",
"log",
"multimap",
"petgraph",
@@ -7066,7 +7106,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
dependencies = [
"anyhow",
"itertools 0.11.0",
"itertools 0.14.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -7231,7 +7271,7 @@ dependencies = [
"quinn-udp",
"rustc-hash",
"rustls 0.23.37",
"socket2 0.5.10",
"socket2 0.6.3",
"thiserror 2.0.18",
"tokio",
"tracing",
@@ -7269,7 +7309,7 @@ dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2 0.5.10",
"socket2 0.6.3",
"tracing",
"windows-sys 0.60.2",
]
@@ -7970,7 +8010,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -8041,7 +8081,7 @@ dependencies = [
"security-framework",
"security-framework-sys",
"webpki-root-certs",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -8549,7 +8589,7 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -8561,7 +8601,7 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54254b8531cafa275c5e096f62d48c81435d1015405a91198ddb11e967301d40"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -8584,7 +8624,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e"
dependencies = [
"libc",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
@@ -8964,7 +9004,7 @@ dependencies = [
"getrandom 0.4.2",
"once_cell",
"rustix",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
@@ -9893,7 +9933,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]

View File

@@ -25,8 +25,7 @@ new BooleanQuery(queries): BooleanQuery
Creates an instance of BooleanQuery.
#### Parameters
* **queries**: [[`Occur`](../enumerations/Occur.md), [`FullTextQuery`](../interfaces/FullTextQuery.md)][]
* **queries**: [[`Occur`](../enumerations/Occur.md), [`FullTextQuery`](../interfaces/FullTextQuery.md)][]
An array of (Occur, FullTextQuery) pairs to combine.
Occur specifies whether the query must match, or should match.

View File

@@ -31,18 +31,14 @@ but penalizes those that match the negative query.
the penalty is controlled by the `negativeBoost` parameter.
#### Parameters
* **positive**: [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **positive**: [`FullTextQuery`](../interfaces/FullTextQuery.md)
The positive query that boosts the relevance score.
* **negative**: [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **negative**: [`FullTextQuery`](../interfaces/FullTextQuery.md)
The negative query that reduces the relevance score.
* **options?**
* **options?**
Optional parameters for the boost query.
- `negativeBoost`: The boost factor for the negative query (default is 0.0).
* **options.negativeBoost?**: `number`
* **options.negativeBoost?**: `number`
#### Returns

View File

@@ -42,26 +42,19 @@ both the source and cloned tables to evolve independently while initially
sharing the same data, deletion, and index files.
#### Parameters
* **targetTableName**: `string`
* **targetTableName**: `string`
The name of the target table to create.
* **sourceUri**: `string`
* **sourceUri**: `string`
The URI of the source table to clone from.
* **options?**
* **options?**
Clone options.
* **options.isShallow?**: `boolean`
* **options.isShallow?**: `boolean`
Whether to perform a shallow clone (defaults to true).
* **options.sourceTag?**: `string`
* **options.sourceTag?**: `string`
The tag of the source table to clone.
* **options.sourceVersion?**: `number`
* **options.sourceVersion?**: `number`
The version of the source table to clone.
* **options.targetNamespacePath?**: `string`[]
* **options.targetNamespacePath?**: `string`[]
The namespace path for the target table (defaults to root namespace).
#### Returns
@@ -102,14 +95,11 @@ abstract createEmptyTable(
Creates a new empty Table
##### Parameters
* **name**: `string`
* **name**: `string`
The name of the table.
* **schema**: [`SchemaLike`](../type-aliases/SchemaLike.md)
* **schema**: [`SchemaLike`](../type-aliases/SchemaLike.md)
The schema of the table
* **options?**: `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
* **options?**: `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
Additional options (backwards compatibility)
##### Returns
@@ -129,17 +119,13 @@ abstract createEmptyTable(
Creates a new empty Table
##### Parameters
* **name**: `string`
* **name**: `string`
The name of the table.
* **schema**: [`SchemaLike`](../type-aliases/SchemaLike.md)
* **schema**: [`SchemaLike`](../type-aliases/SchemaLike.md)
The schema of the table
* **namespacePath?**: `string`[]
* **namespacePath?**: `string`[]
The namespace path to create the table in (defaults to root namespace)
* **options?**: `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
* **options?**: `Partial`<[`CreateTableOptions`](../interfaces/CreateTableOptions.md)>
Additional options
##### Returns
@@ -159,11 +145,9 @@ abstract createTable(options, namespacePath?): Promise<Table>
Creates a new Table and initializes it with new data.
##### Parameters
* **options**: `object` & `Partial`&lt;[`CreateTableOptions`](../interfaces/CreateTableOptions.md)&gt;
* **options**: `object` & `Partial`&lt;[`CreateTableOptions`](../interfaces/CreateTableOptions.md)&gt;
The options object.
* **namespacePath?**: `string`[]
* **namespacePath?**: `string`[]
The namespace path to create the table in (defaults to root namespace)
##### Returns
@@ -182,15 +166,12 @@ abstract createTable(
Creates a new Table and initializes it with new data.
##### Parameters
* **name**: `string`
* **name**: `string`
The name of the table.
* **data**: [`TableLike`](../type-aliases/TableLike.md) \| `Record`&lt;`string`, `unknown`&gt;[]
* **data**: [`TableLike`](../type-aliases/TableLike.md) \| `Record`&lt;`string`, `unknown`&gt;[]
Non-empty Array of Records
to be inserted into the table
* **options?**: `Partial`&lt;[`CreateTableOptions`](../interfaces/CreateTableOptions.md)&gt;
* **options?**: `Partial`&lt;[`CreateTableOptions`](../interfaces/CreateTableOptions.md)&gt;
Additional options (backwards compatibility)
##### Returns
@@ -210,18 +191,14 @@ abstract createTable(
Creates a new Table and initializes it with new data.
##### Parameters
* **name**: `string`
* **name**: `string`
The name of the table.
* **data**: [`TableLike`](../type-aliases/TableLike.md) \| `Record`&lt;`string`, `unknown`&gt;[]
* **data**: [`TableLike`](../type-aliases/TableLike.md) \| `Record`&lt;`string`, `unknown`&gt;[]
Non-empty Array of Records
to be inserted into the table
* **namespacePath?**: `string`[]
* **namespacePath?**: `string`[]
The namespace path to create the table in (defaults to root namespace)
* **options?**: `Partial`&lt;[`CreateTableOptions`](../interfaces/CreateTableOptions.md)&gt;
* **options?**: `Partial`&lt;[`CreateTableOptions`](../interfaces/CreateTableOptions.md)&gt;
Additional options
##### Returns
@@ -253,8 +230,7 @@ abstract dropAllTables(namespacePath?): Promise<void>
Drop all tables in the database.
#### Parameters
* **namespacePath?**: `string`[]
* **namespacePath?**: `string`[]
The namespace path to drop tables from (defaults to root namespace).
#### Returns
@@ -272,11 +248,9 @@ abstract dropTable(name, namespacePath?): Promise<void>
Drop an existing table.
#### Parameters
* **name**: `string`
* **name**: `string`
The name of the table to drop.
* **namespacePath?**: `string`[]
* **namespacePath?**: `string`[]
The namespace path of the table (defaults to root namespace).
#### Returns
@@ -311,14 +285,11 @@ abstract openTable(
Open a table in the database.
#### Parameters
* **name**: `string`
* **name**: `string`
The name of the table
* **namespacePath?**: `string`[]
* **namespacePath?**: `string`[]
The namespace path of the table (defaults to root namespace)
* **options?**: `Partial`&lt;[`OpenTableOptions`](../interfaces/OpenTableOptions.md)&gt;
* **options?**: `Partial`&lt;[`OpenTableOptions`](../interfaces/OpenTableOptions.md)&gt;
Additional options
#### Returns
@@ -340,8 +311,7 @@ List all the table names in this database.
Tables will be returned in lexicographical order.
##### Parameters
* **options?**: `Partial`&lt;[`TableNamesOptions`](../interfaces/TableNamesOptions.md)&gt;
* **options?**: `Partial`&lt;[`TableNamesOptions`](../interfaces/TableNamesOptions.md)&gt;
options to control the
paging / start point (backwards compatibility)
@@ -360,11 +330,9 @@ List all the table names in this database.
Tables will be returned in lexicographical order.
##### Parameters
* **namespacePath?**: `string`[]
* **namespacePath?**: `string`[]
The namespace path to list tables from (defaults to root namespace)
* **options?**: `Partial`&lt;[`TableNamesOptions`](../interfaces/TableNamesOptions.md)&gt;
* **options?**: `Partial`&lt;[`TableNamesOptions`](../interfaces/TableNamesOptions.md)&gt;
options to control the
paging / start point

View File

@@ -73,8 +73,7 @@ The results of a full text search are ordered by relevance measured by BM25.
You can combine filters with full text search.
#### Parameters
* **options?**: `Partial`&lt;[`FtsOptions`](../interfaces/FtsOptions.md)&gt;
* **options?**: `Partial`&lt;[`FtsOptions`](../interfaces/FtsOptions.md)&gt;
#### Returns
@@ -95,8 +94,7 @@ It is a variant of the HNSW algorithm that uses product quantization to compress
the vectors.
#### Parameters
* **options?**: `Partial`&lt;[`HnswPqOptions`](../interfaces/HnswPqOptions.md)&gt;
* **options?**: `Partial`&lt;[`HnswPqOptions`](../interfaces/HnswPqOptions.md)&gt;
#### Returns
@@ -117,8 +115,7 @@ It is a variant of the HNSW algorithm that uses scalar quantization to compress
the vectors.
#### Parameters
* **options?**: `Partial`&lt;[`HnswSqOptions`](../interfaces/HnswSqOptions.md)&gt;
* **options?**: `Partial`&lt;[`HnswSqOptions`](../interfaces/HnswSqOptions.md)&gt;
#### Returns
@@ -148,8 +145,7 @@ Note that training an IVF FLAT index on a large dataset is a slow operation and
currently is also a memory intensive operation.
#### Parameters
* **options?**: `Partial`&lt;[`IvfFlatOptions`](../interfaces/IvfFlatOptions.md)&gt;
* **options?**: `Partial`&lt;[`IvfFlatOptions`](../interfaces/IvfFlatOptions.md)&gt;
#### Returns
@@ -185,8 +181,7 @@ Note that training an IVF PQ index on a large dataset is a slow operation and
currently is also a memory intensive operation.
#### Parameters
* **options?**: `Partial`&lt;[`IvfPqOptions`](../interfaces/IvfPqOptions.md)&gt;
* **options?**: `Partial`&lt;[`IvfPqOptions`](../interfaces/IvfPqOptions.md)&gt;
#### Returns
@@ -216,8 +211,7 @@ Note that training an IVF RQ index on a large dataset is a slow operation and
currently is also a memory intensive operation.
#### Parameters
* **options?**: `Partial`&lt;[`IvfRqOptions`](../interfaces/IvfRqOptions.md)&gt;
* **options?**: `Partial`&lt;[`IvfRqOptions`](../interfaces/IvfRqOptions.md)&gt;
#### Returns

View File

@@ -17,8 +17,7 @@ new MakeArrowTableOptions(values?): MakeArrowTableOptions
```
#### Parameters
* **values?**: `Partial`&lt;[`MakeArrowTableOptions`](MakeArrowTableOptions.md)&gt;
* **values?**: `Partial`&lt;[`MakeArrowTableOptions`](MakeArrowTableOptions.md)&gt;
#### Returns

View File

@@ -28,30 +28,22 @@ new MatchQuery(
Creates an instance of MatchQuery.
#### Parameters
* **query**: `string`
* **query**: `string`
The text query to search for.
* **column**: `string`
* **column**: `string`
The name of the column to search within.
* **options?**
* **options?**
Optional parameters for the match query.
- `boost`: The boost factor for the query (default is 1.0).
- `fuzziness`: The fuzziness level for the query (default is 0).
- `maxExpansions`: The maximum number of terms to consider for fuzzy matching (default is 50).
- `operator`: The logical operator to use for combining terms in the query (default is "OR").
- `prefixLength`: The number of leading characters that must remain unchanged for fuzzy matching.
* **options.boost?**: `number`
* **options.fuzziness?**: `number`
* **options.maxExpansions?**: `number`
* **options.operator?**: [`Operator`](../enumerations/Operator.md)
* **options.prefixLength?**: `number`
* **options.boost?**: `number`
* **options.fuzziness?**: `number`
* **options.maxExpansions?**: `number`
* **options.operator?**: [`Operator`](../enumerations/Operator.md)
* **options.prefixLength?**: `number`
#### Returns

View File

@@ -19,10 +19,8 @@ new MergeInsertBuilder(native, schema): MergeInsertBuilder
Construct a MergeInsertBuilder. __Internal use only.__
#### Parameters
* **native**: `NativeMergeInsertBuilder`
* **schema**: `Schema`&lt;`any`&gt; \| `Promise`&lt;`Schema`&lt;`any`&gt;&gt;
* **native**: `NativeMergeInsertBuilder`
* **schema**: `Schema`&lt;`any`&gt; \| `Promise`&lt;`Schema`&lt;`any`&gt;&gt;
#### Returns
@@ -39,10 +37,8 @@ execute(data, execOptions?): Promise<MergeResult>
Executes the merge insert operation
#### Parameters
* **data**: [`Data`](../type-aliases/Data.md)
* **execOptions?**: `Partial`&lt;[`WriteExecutionOptions`](../interfaces/WriteExecutionOptions.md)&gt;
* **data**: [`Data`](../type-aliases/Data.md)
* **execOptions?**: `Partial`&lt;[`WriteExecutionOptions`](../interfaces/WriteExecutionOptions.md)&gt;
#### Returns
@@ -66,8 +62,7 @@ table scan even if an index exists. This can be useful for benchmarking or when
the query optimizer chooses a suboptimal path.
#### Parameters
* **useIndex**: `boolean`
* **useIndex**: `boolean`
Whether to use indices for the merge operation. Defaults to `true`.
#### Returns
@@ -104,10 +99,8 @@ table (new data).
For example, "target.last_update < source.last_update"
#### Parameters
* **options?**
* **options.where?**: `string`
* **options?**
* **options.where?**: `string`
#### Returns
@@ -126,10 +119,8 @@ deleted. An optional condition can be provided to limit what
data is deleted.
#### Parameters
* **options?**
* **options.where?**: `string`
* **options?**
* **options.where?**: `string`
An optional condition to limit what data is deleted
#### Returns

View File

@@ -28,21 +28,16 @@ new MultiMatchQuery(
Creates an instance of MultiMatchQuery.
#### Parameters
* **query**: `string`
* **query**: `string`
The text query to search for across multiple columns.
* **columns**: `string`[]
* **columns**: `string`[]
An array of column names to search within.
* **options?**
* **options?**
Optional parameters for the multi-match query.
- `boosts`: An array of boost factors for each column (default is 1.0 for all).
- `operator`: The logical operator to use for combining terms in the query (default is "OR").
* **options.boosts?**: `number`[]
* **options.operator?**: [`Operator`](../enumerations/Operator.md)
* **options.boosts?**: `number`[]
* **options.operator?**: [`Operator`](../enumerations/Operator.md)
#### Returns

View File

@@ -21,8 +21,7 @@ new NativeJsHeaderProvider(getHeadersCallback): NativeJsHeaderProvider
Create a new JsHeaderProvider from a JavaScript callback
#### Parameters
* **getHeadersCallback**
* **getHeadersCallback**
#### Returns

View File

@@ -51,11 +51,9 @@ new OAuthHeaderProvider(tokenFetcher, refreshBufferSeconds): OAuthHeaderProvider
Initialize the OAuth provider.
#### Parameters
* **tokenFetcher**
* **tokenFetcher**
Function to fetch new tokens. Should return object with 'accessToken' and optionally 'expiresIn'.
* **refreshBufferSeconds**: `number` = `300`
* **refreshBufferSeconds**: `number` = `300`
Seconds before expiry to refresh token. Default 300 (5 minutes).
#### Returns

View File

@@ -46,8 +46,7 @@ filter(filter): PermutationBuilder
Configure filtering for the permutation.
#### Parameters
* **filter**: `string`
* **filter**: `string`
SQL filter expression
#### Returns
@@ -73,11 +72,9 @@ persist(connection, tableName): PermutationBuilder
Configure the permutation to be persisted.
#### Parameters
* **connection**: [`Connection`](Connection.md)
* **connection**: [`Connection`](Connection.md)
The connection to persist the permutation to
* **tableName**: `string`
* **tableName**: `string`
The name of the table to create
#### Returns
@@ -103,8 +100,7 @@ shuffle(options): PermutationBuilder
Configure shuffling for the permutation.
#### Parameters
* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md)
* **options**: [`ShuffleOptions`](../interfaces/ShuffleOptions.md)
Configuration for shuffling
#### Returns
@@ -134,8 +130,7 @@ splitCalculated(options): PermutationBuilder
Configure calculated splits for the permutation.
#### Parameters
* **options**: [`SplitCalculatedOptions`](../interfaces/SplitCalculatedOptions.md)
* **options**: [`SplitCalculatedOptions`](../interfaces/SplitCalculatedOptions.md)
Configuration for calculated splitting
#### Returns
@@ -161,8 +156,7 @@ splitHash(options): PermutationBuilder
Configure hash-based splits for the permutation.
#### Parameters
* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md)
* **options**: [`SplitHashOptions`](../interfaces/SplitHashOptions.md)
Configuration for hash-based splitting
#### Returns
@@ -192,8 +186,7 @@ splitRandom(options): PermutationBuilder
Configure random splits for the permutation.
#### Parameters
* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md)
* **options**: [`SplitRandomOptions`](../interfaces/SplitRandomOptions.md)
Configuration for random splitting
#### Returns
@@ -226,8 +219,7 @@ splitSequential(options): PermutationBuilder
Configure sequential splits for the permutation.
#### Parameters
* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md)
* **options**: [`SplitSequentialOptions`](../interfaces/SplitSequentialOptions.md)
Configuration for sequential splitting
#### Returns

View File

@@ -28,18 +28,14 @@ new PhraseQuery(
Creates an instance of `PhraseQuery`.
#### Parameters
* **query**: `string`
* **query**: `string`
The phrase to search for in the specified column.
* **column**: `string`
* **column**: `string`
The name of the column to search within.
* **options?**
* **options?**
Optional parameters for the phrase query.
- `slop`: The maximum number of intervening unmatched positions allowed between words in the phrase (default is 0).
* **options.slop?**: `number`
* **options.slop?**: `number`
#### Returns

View File

@@ -86,8 +86,7 @@ protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
Execute the query and return the results as an async generator of Arrow `RecordBatch` objects.
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -120,8 +119,7 @@ explainPlan(verbose): Promise<string>
Generates an explanation of the query execution plan.
#### Parameters
* **verbose**: `boolean` = `false`
* **verbose**: `boolean` = `false`
If true, provides a more detailed explanation. Defaults to false.
#### Returns
@@ -177,8 +175,7 @@ filter(predicate): this
A filter statement to be applied to this query.
#### Parameters
* **predicate**: `string`
* **predicate**: `string`
#### Returns
@@ -205,10 +202,8 @@ fullTextSearch(query, options?): this
```
#### Parameters
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt;
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt;
#### Returns
@@ -232,8 +227,7 @@ By default, a plain search has no limit. If this method is not
called then every valid row from the table will be returned.
#### Parameters
* **limit**: `number`
* **limit**: `number`
#### Returns
@@ -268,8 +262,7 @@ fixed size list of floats) then the column does not need to be specified.
If there is more than one vector column you must use
#### Parameters
* **vector**: [`IntoVector`](../type-aliases/IntoVector.md)
* **vector**: [`IntoVector`](../type-aliases/IntoVector.md)
#### Returns
@@ -308,10 +301,8 @@ nearestToText(query, columns?): Query
```
#### Parameters
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **columns?**: `string`[]
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **columns?**: `string`[]
#### Returns
@@ -330,8 +321,7 @@ Set the number of rows to skip before returning results.
This is useful for pagination.
#### Parameters
* **offset**: `number`
* **offset**: `number`
#### Returns
@@ -393,8 +383,7 @@ For example, an SQL query might state `SELECT a + b AS combined, c`. The equiva
input to this method would be:
#### Parameters
* **columns**: `string` \| `string`[] \| `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
* **columns**: `string` \| `string`[] \| `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
#### Returns
@@ -428,8 +417,7 @@ toArray(options?): Promise<any[]>
Collect the results as an array of objects.
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -450,8 +438,7 @@ toArrow(options?): Promise<Table<any>>
Collect the results as an Arrow `Table`.
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -478,8 +465,7 @@ A filter statement to be applied to this query.
The filter should be supplied as an SQL query string. For example:
#### Parameters
* **predicate**: `string`
* **predicate**: `string`
#### Returns

View File

@@ -87,8 +87,7 @@ protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
Execute the query and return the results as an async generator of Arrow `RecordBatch` objects.
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -117,8 +116,7 @@ explainPlan(verbose): Promise<string>
Generates an explanation of the query execution plan.
#### Parameters
* **verbose**: `boolean` = `false`
* **verbose**: `boolean` = `false`
If true, provides a more detailed explanation. Defaults to false.
#### Returns
@@ -186,8 +184,7 @@ For example, an SQL query might state `SELECT a + b AS combined, c`. The equiva
input to this method would be:
#### Parameters
* **columns**: `string` \| `string`[] \| `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
* **columns**: `string` \| `string`[] \| `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
#### Returns
@@ -217,8 +214,7 @@ toArray(options?): Promise<any[]>
Collect the results as an array of objects.
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -235,8 +231,7 @@ toArrow(options?): Promise<Table<any>>
Collect the results as an Arrow `Table`.
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns

View File

@@ -33,10 +33,8 @@ Create a new session with custom cache sizes.
Defaults to 1GB if not specified.
#### Parameters
* **indexCacheSizeBytes?**: `null` \| `bigint`
* **metadataCacheSizeBytes?**: `null` \| `bigint`
* **indexCacheSizeBytes?**: `null` \| `bigint`
* **metadataCacheSizeBytes?**: `null` \| `bigint`
#### Returns

View File

@@ -37,8 +37,7 @@ new StaticHeaderProvider(headers): StaticHeaderProvider
Initialize with static headers.
#### Parameters
* **headers**: `Record`&lt;`string`, `string`&gt;
* **headers**: `Record`&lt;`string`, `string`&gt;
Headers to return for every request.
#### Returns

View File

@@ -46,11 +46,9 @@ abstract add(data, options?): Promise<AddResult>
Insert records into this Table.
#### Parameters
* **data**: [`Data`](../type-aliases/Data.md)
* **data**: [`Data`](../type-aliases/Data.md)
Records to be inserted into the Table
* **options?**: `Partial`&lt;[`AddDataOptions`](../interfaces/AddDataOptions.md)&gt;
* **options?**: `Partial`&lt;[`AddDataOptions`](../interfaces/AddDataOptions.md)&gt;
#### Returns
@@ -70,8 +68,7 @@ abstract addColumns(newColumnTransforms): Promise<AddColumnsResult>
Add new columns with defined values.
#### Parameters
* **newColumnTransforms**: `Field`&lt;`any`&gt; \| `Field`&lt;`any`&gt;[] \| `Schema`&lt;`any`&gt; \| [`AddColumnsSql`](../interfaces/AddColumnsSql.md)[]
* **newColumnTransforms**: `Field`&lt;`any`&gt; \| `Field`&lt;`any`&gt;[] \| `Schema`&lt;`any`&gt; \| [`AddColumnsSql`](../interfaces/AddColumnsSql.md)[]
Either:
- An array of objects with column names and SQL expressions to calculate values
- A single Arrow Field defining one column with its data type (column will be initialized with null values)
@@ -96,8 +93,7 @@ abstract alterColumns(columnAlterations): Promise<AlterColumnsResult>
Alter the name or nullability of columns.
#### Parameters
* **columnAlterations**: [`ColumnAlteration`](../interfaces/ColumnAlteration.md)[]
* **columnAlterations**: [`ColumnAlteration`](../interfaces/ColumnAlteration.md)[]
One or more alterations to
apply to columns.
@@ -126,8 +122,7 @@ Calling this method will set the table into time-travel mode. If you
wish to return to standard mode, call `checkoutLatest`.
#### Parameters
* **version**: `string` \| `number`
* **version**: `string` \| `number`
The version to checkout, could be version number or tag
#### Returns
@@ -196,8 +191,7 @@ abstract countRows(filter?): Promise<number>
Count the total number of rows in the dataset.
#### Parameters
* **filter?**: `string`
* **filter?**: `string`
#### Returns
@@ -222,10 +216,8 @@ We currently don't support custom named indexes.
The index name will always be `${column}_idx`.
#### Parameters
* **column**: `string`
* **options?**: `Partial`&lt;[`IndexOptions`](../interfaces/IndexOptions.md)&gt;
* **column**: `string`
* **options?**: `Partial`&lt;[`IndexOptions`](../interfaces/IndexOptions.md)&gt;
#### Returns
@@ -268,8 +260,7 @@ abstract delete(predicate): Promise<DeleteResult>
Delete the rows that satisfy the predicate.
#### Parameters
* **predicate**: `string`
* **predicate**: `string`
#### Returns
@@ -308,8 +299,7 @@ call ``compact_files`` to rewrite the data without the removed columns and
then call ``cleanup_files`` to remove the old files.
#### Parameters
* **columnNames**: `string`[]
* **columnNames**: `string`[]
The names of the columns to drop. These can
be nested column references (e.g. "a.b.c") or top-level column names
(e.g. "a").
@@ -332,8 +322,7 @@ abstract dropIndex(name): Promise<void>
Drop an index from the table.
#### Parameters
* **name**: `string`
* **name**: `string`
The name of the index.
This does not delete the index from disk, it just removes it from the table.
To delete the index, run [Table#optimize](Table.md#optimize) after dropping the index.
@@ -354,8 +343,7 @@ abstract indexStats(name): Promise<undefined | IndexStatistics>
List all the stats of a specified index
#### Parameters
* **name**: `string`
* **name**: `string`
The name of the index.
#### Returns
@@ -460,8 +448,7 @@ abstract mergeInsert(on): MergeInsertBuilder
```
#### Parameters
* **on**: `string` \| `string`[]
* **on**: `string` \| `string`[]
#### Returns
@@ -492,8 +479,7 @@ Modeled after ``VACUUM`` in PostgreSQL.
modification operations.
#### Parameters
* **options?**: `Partial`&lt;[`OptimizeOptions`](../interfaces/OptimizeOptions.md)&gt;
* **options?**: `Partial`&lt;[`OptimizeOptions`](../interfaces/OptimizeOptions.md)&gt;
#### Returns
@@ -510,8 +496,7 @@ abstract prewarmIndex(name): Promise<void>
Prewarm an index in the table.
#### Parameters
* **name**: `string`
* **name**: `string`
The name of the index.
This will load the index into memory. This may reduce the cold-start time for
future queries. If the index does not fit in the cache then this call may be
@@ -643,14 +628,11 @@ Create a search query to find the nearest neighbors
of the given query
#### Parameters
* **query**: `string` \| [`IntoVector`](../type-aliases/IntoVector.md) \| [`MultiVector`](../type-aliases/MultiVector.md) \| [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **query**: `string` \| [`IntoVector`](../type-aliases/IntoVector.md) \| [`MultiVector`](../type-aliases/MultiVector.md) \| [`FullTextQuery`](../interfaces/FullTextQuery.md)
the query, a vector or string
* **queryType?**: `string`
* **queryType?**: `string`
the type of the query, "vector", "fts", or "auto"
* **ftsColumns?**: `string` \| `string`[]
* **ftsColumns?**: `string` \| `string`[]
the columns to search in for full text search
for now, only one column can be searched at a time.
when "auto" is used, if the query is a string and an embedding function is defined, it will be treated as a vector query
@@ -715,8 +697,7 @@ abstract takeOffsets(offsets): TakeQuery
Create a query that returns a subset of the rows in the table.
#### Parameters
* **offsets**: `number`[]
* **offsets**: `number`[]
The offsets of the rows to return.
#### Returns
@@ -736,8 +717,7 @@ abstract takeRowIds(rowIds): TakeQuery
Create a query that returns a subset of the rows in the table.
#### Parameters
* **rowIds**: readonly (`number` \| `bigint`)[]
* **rowIds**: readonly (`number` \| `bigint`)[]
The row ids of the rows to return.
Row ids returned by `withRowId()` are `bigint`, so `bigint[]` is supported.
For convenience / backwards compatibility, `number[]` is also accepted (for
@@ -776,8 +756,7 @@ abstract update(opts): Promise<UpdateResult>
Update existing records in the Table
##### Parameters
* **opts**: `object` & `Partial`&lt;[`UpdateOptions`](../interfaces/UpdateOptions.md)&gt;
* **opts**: `object` & `Partial`&lt;[`UpdateOptions`](../interfaces/UpdateOptions.md)&gt;
##### Returns
@@ -801,8 +780,7 @@ abstract update(opts): Promise<UpdateResult>
Update existing records in the Table
##### Parameters
* **opts**: `object` & `Partial`&lt;[`UpdateOptions`](../interfaces/UpdateOptions.md)&gt;
* **opts**: `object` & `Partial`&lt;[`UpdateOptions`](../interfaces/UpdateOptions.md)&gt;
##### Returns
@@ -839,12 +817,10 @@ better performance with a single [`merge_insert`] call instead of
repeatedly calling this method.
##### Parameters
* **updates**: `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
* **updates**: `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
the
columns to update
* **options?**: `Partial`&lt;[`UpdateOptions`](../interfaces/UpdateOptions.md)&gt;
* **options?**: `Partial`&lt;[`UpdateOptions`](../interfaces/UpdateOptions.md)&gt;
additional options to control
the update behavior
@@ -875,8 +851,7 @@ is the same thing as calling `nearestTo` on the builder returned
by `query`.
#### Parameters
* **vector**: [`IntoVector`](../type-aliases/IntoVector.md) \| [`MultiVector`](../type-aliases/MultiVector.md)
* **vector**: [`IntoVector`](../type-aliases/IntoVector.md) \| [`MultiVector`](../type-aliases/MultiVector.md)
#### Returns
@@ -911,11 +886,9 @@ abstract waitForIndex(indexNames, timeoutSeconds): Promise<void>
Waits for asynchronous indexing to complete on the table.
#### Parameters
* **indexNames**: `string`[]
* **indexNames**: `string`[]
The names of the indices to wait for
* **timeoutSeconds**: `number`
* **timeoutSeconds**: `number`
The number of seconds to wait before timing out
This will raise an error if the indices are not created and fully indexed within the timeout.

View File

@@ -27,10 +27,8 @@ create(tag, version): Promise<void>
```
#### Parameters
* **tag**: `string`
* **version**: `number`
* **tag**: `string`
* **version**: `number`
#### Returns
@@ -45,8 +43,7 @@ delete(tag): Promise<void>
```
#### Parameters
* **tag**: `string`
* **tag**: `string`
#### Returns
@@ -61,8 +58,7 @@ getVersion(tag): Promise<number>
```
#### Parameters
* **tag**: `string`
* **tag**: `string`
#### Returns
@@ -89,10 +85,8 @@ update(tag, version): Promise<void>
```
#### Parameters
* **tag**: `string`
* **version**: `number`
* **tag**: `string`
* **version**: `number`
#### Returns

View File

@@ -82,8 +82,7 @@ protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
Execute the query and return the results as an
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -116,8 +115,7 @@ explainPlan(verbose): Promise<string>
Generates an explanation of the query execution plan.
#### Parameters
* **verbose**: `boolean` = `false`
* **verbose**: `boolean` = `false`
If true, provides a more detailed explanation. Defaults to false.
#### Returns
@@ -193,8 +191,7 @@ For example, an SQL query might state `SELECT a + b AS combined, c`. The equiva
input to this method would be:
#### Parameters
* **columns**: `string` \| `string`[] \| `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
* **columns**: `string` \| `string`[] \| `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
#### Returns
@@ -228,8 +225,7 @@ toArray(options?): Promise<any[]>
Collect the results as an array of objects.
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -250,8 +246,7 @@ toArrow(options?): Promise<Table<any>>
Collect the results as an Arrow
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns

View File

@@ -15,8 +15,7 @@ new VectorColumnOptions(values?): VectorColumnOptions
```
#### Parameters
* **values?**: `Partial`&lt;[`VectorColumnOptions`](VectorColumnOptions.md)&gt;
* **values?**: `Partial`&lt;[`VectorColumnOptions`](VectorColumnOptions.md)&gt;
#### Returns

View File

@@ -39,8 +39,7 @@ addQueryVector(vector): VectorQuery
```
#### Parameters
* **vector**: [`IntoVector`](../type-aliases/IntoVector.md)
* **vector**: [`IntoVector`](../type-aliases/IntoVector.md)
#### Returns
@@ -127,8 +126,7 @@ This controls which column is compared to the query vector supplied in
the call to
#### Parameters
* **column**: `string`
* **column**: `string`
#### Returns
@@ -150,10 +148,8 @@ distanceRange(lowerBound?, upperBound?): VectorQuery
```
#### Parameters
* **lowerBound?**: `number`
* **upperBound?**: `number`
* **lowerBound?**: `number`
* **upperBound?**: `number`
#### Returns
@@ -174,8 +170,7 @@ to some kind of distance metric. This parameter controls which distance metric
use. See
#### Parameters
* **distanceType**: `"l2"` \| `"cosine"` \| `"dot"`
* **distanceType**: `"l2"` \| `"cosine"` \| `"dot"`
#### Returns
@@ -209,8 +204,7 @@ Increasing this value will increase the recall of your query but will
also increase the latency of your query. The default value is 1.5*limit.
#### Parameters
* **ef**: `number`
* **ef**: `number`
#### Returns
@@ -227,8 +221,7 @@ protected execute(options?): AsyncGenerator<RecordBatch<any>, void, unknown>
Execute the query and return the results as an
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -261,8 +254,7 @@ explainPlan(verbose): Promise<string>
Generates an explanation of the query execution plan.
#### Parameters
* **verbose**: `boolean` = `false`
* **verbose**: `boolean` = `false`
If true, provides a more detailed explanation. Defaults to false.
#### Returns
@@ -318,8 +310,7 @@ filter(predicate): this
A filter statement to be applied to this query.
#### Parameters
* **predicate**: `string`
* **predicate**: `string`
#### Returns
@@ -346,10 +337,8 @@ fullTextSearch(query, options?): this
```
#### Parameters
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt;
* **query**: `string` \| [`FullTextQuery`](../interfaces/FullTextQuery.md)
* **options?**: `Partial`&lt;[`FullTextSearchOptions`](../interfaces/FullTextSearchOptions.md)&gt;
#### Returns
@@ -373,8 +362,7 @@ By default, a plain search has no limit. If this method is not
called then every valid row from the table will be returned.
#### Parameters
* **limit**: `number`
* **limit**: `number`
#### Returns
@@ -401,8 +389,7 @@ a narrow filter to allow these queries to spend more time searching and avoid
potential false negatives.
#### Parameters
* **maximumNprobes**: `number`
* **maximumNprobes**: `number`
#### Returns
@@ -424,8 +411,7 @@ filter. See `nprobes` for more details. Higher values will increase recall
but will also increase latency.
#### Parameters
* **minimumNprobes**: `number`
* **minimumNprobes**: `number`
#### Returns
@@ -465,8 +451,7 @@ you can use `minimumNprobes` and `maximumNprobes`. This method sets both
the minimum and maximum to the same value.
#### Parameters
* **nprobes**: `number`
* **nprobes**: `number`
#### Returns
@@ -485,8 +470,7 @@ Set the number of rows to skip before returning results.
This is useful for pagination.
#### Parameters
* **offset**: `number`
* **offset**: `number`
#### Returns
@@ -590,8 +574,7 @@ and the quantized result vectors. This can be considerably different than the t
distance between the query vector and the actual uncompressed vector.
#### Parameters
* **refineFactor**: `number`
* **refineFactor**: `number`
#### Returns
@@ -606,8 +589,7 @@ rerank(reranker): VectorQuery
```
#### Parameters
* **reranker**: [`Reranker`](../namespaces/rerankers/interfaces/Reranker.md)
* **reranker**: [`Reranker`](../namespaces/rerankers/interfaces/Reranker.md)
#### Returns
@@ -642,8 +624,7 @@ For example, an SQL query might state `SELECT a + b AS combined, c`. The equiva
input to this method would be:
#### Parameters
* **columns**: `string` \| `string`[] \| `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
* **columns**: `string` \| `string`[] \| `Record`&lt;`string`, `string`&gt; \| `Map`&lt;`string`, `string`&gt;
#### Returns
@@ -677,8 +658,7 @@ toArray(options?): Promise<any[]>
Collect the results as an array of objects.
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -699,8 +679,7 @@ toArrow(options?): Promise<Table<any>>
Collect the results as an Arrow
#### Parameters
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
* **options?**: `Partial`&lt;[`QueryExecutionOptions`](../interfaces/QueryExecutionOptions.md)&gt;
#### Returns
@@ -727,8 +706,7 @@ A filter statement to be applied to this query.
The filter should be supplied as an SQL query string. For example:
#### Parameters
* **predicate**: `string`
* **predicate**: `string`
#### Returns

View File

@@ -0,0 +1,59 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / OAuthFlowType
# Enumeration: OAuthFlowType
OAuth authentication flow types.
## Enumeration Members
### AuthorizationCodePKCE
```ts
AuthorizationCodePKCE: "authorization_code_pkce";
```
Authorization Code with PKCE (interactive browser-based auth).
***
### AzureManagedIdentity
```ts
AzureManagedIdentity: "azure_managed_identity";
```
Azure Managed Identity via IMDS.
***
### ClientCredentials
```ts
ClientCredentials: "client_credentials";
```
Client Credentials grant (service-to-service / M2M).
***
### DeviceCode
```ts
DeviceCode: "device_code";
```
Device Code grant (CLI / headless environments).
***
### WorkloadIdentity
```ts
WorkloadIdentity: "workload_identity";
```
Workload Identity Federation (K8s, GitHub Actions).

View File

@@ -11,8 +11,7 @@ function RecordBatchIterator(promisedInner): AsyncGenerator<RecordBatch<any>, vo
```
## Parameters
* **promisedInner**: `Promise`&lt;`RecordBatchIterator`&gt;
* **promisedInner**: `Promise`&lt;`RecordBatchIterator`&gt;
## Returns

View File

@@ -25,17 +25,13 @@ Accepted formats:
- `db://host:port` - remote database (LanceDB cloud)
### Parameters
* **uri**: `string`
* **uri**: `string`
The uri of the database. If the database uri starts
with `db://` then it connects to a remote database.
* **options?**: `Partial`&lt;[`ConnectionOptions`](../interfaces/ConnectionOptions.md)&gt;
* **options?**: `Partial`&lt;[`ConnectionOptions`](../interfaces/ConnectionOptions.md)&gt;
The options to use when connecting to the database
* **session?**: [`Session`](../classes/Session.md)
* **headerProvider?**: [`HeaderProvider`](../classes/HeaderProvider.md) \| () => `Record`&lt;`string`, `string`&gt; \| () => `Promise`&lt;`Record`&lt;`string`, `string`&gt;&gt;
* **session?**: [`Session`](../classes/Session.md)
* **headerProvider?**: [`HeaderProvider`](../classes/HeaderProvider.md) \| () => `Record`&lt;`string`, `string`&gt; \| () => `Promise`&lt;`Record`&lt;`string`, `string`&gt;&gt;
### Returns
@@ -85,8 +81,7 @@ Accepted formats:
- `db://host:port` - remote database (LanceDB cloud)
### Parameters
* **options**: `Partial`&lt;[`ConnectionOptions`](../interfaces/ConnectionOptions.md)&gt; & `object`
* **options**: `Partial`&lt;[`ConnectionOptions`](../interfaces/ConnectionOptions.md)&gt; & `object`
The options to use when connecting to the database
### Returns

View File

@@ -46,12 +46,9 @@ rules are as follows:
- Array<any> => List
## Parameters
* **data**: `Record`&lt;`string`, `unknown`&gt;[]
* **options?**: `Partial`&lt;[`MakeArrowTableOptions`](../classes/MakeArrowTableOptions.md)&gt;
* **metadata?**: `Map`&lt;`string`, `string`&gt;
* **data**: `Record`&lt;`string`, `unknown`&gt;[]
* **options?**: `Partial`&lt;[`MakeArrowTableOptions`](../classes/MakeArrowTableOptions.md)&gt;
* **metadata?**: `Map`&lt;`string`, `string`&gt;
## Returns

View File

@@ -11,8 +11,7 @@ function packBits(data): number[]
```
## Parameters
* **data**: `number`[]
* **data**: `number`[]
## Returns

View File

@@ -13,8 +13,7 @@ function permutationBuilder(table): PermutationBuilder
Create a permutation builder for the given table.
## Parameters
* **table**: [`Table`](../classes/Table.md)
* **table**: [`Table`](../classes/Table.md)
The source table to create a permutation from
## Returns

View File

@@ -12,6 +12,7 @@
## Enumerations
- [FullTextQueryType](enumerations/FullTextQueryType.md)
- [OAuthFlowType](enumerations/OAuthFlowType.md)
- [Occur](enumerations/Occur.md)
- [Operator](enumerations/Operator.md)
@@ -70,6 +71,8 @@
- [IvfPqOptions](interfaces/IvfPqOptions.md)
- [IvfRqOptions](interfaces/IvfRqOptions.md)
- [MergeResult](interfaces/MergeResult.md)
- [NativeOAuthConfig](interfaces/NativeOAuthConfig.md)
- [OAuthConfig](interfaces/OAuthConfig.md)
- [OpenTableOptions](interfaces/OpenTableOptions.md)
- [OptimizeOptions](interfaces/OptimizeOptions.md)
- [OptimizeStats](interfaces/OptimizeStats.md)

View File

@@ -64,6 +64,18 @@ client used by manifest-enabled native connections.
***
### oauthConfig?
```ts
optional oauthConfig: NativeOAuthConfig;
```
(For LanceDB cloud only): OAuth configuration for IdP-based
authentication (e.g., Azure Entra ID). When set, token acquisition
and refresh are handled entirely in Rust.
***
### readConsistencyInterval?
```ts

View File

@@ -0,0 +1,112 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / NativeOAuthConfig
# Interface: NativeOAuthConfig
OAuth configuration for LanceDB authentication.
All token acquisition and refresh is handled in the Rust layer.
## Properties
### callbackPort?
```ts
optional callbackPort: number;
```
Port for local callback server (authorization_code_pkce, default: 8400).
***
### clientId
```ts
clientId: string;
```
Application / Client ID.
***
### clientSecret?
```ts
optional clientSecret: string;
```
Client secret (required for client_credentials).
***
### flow?
```ts
optional flow: string;
```
Authentication flow: "client_credentials", "authorization_code_pkce",
"device_code", "azure_managed_identity", "workload_identity".
Defaults to "client_credentials" when omitted.
***
### issuerUrl
```ts
issuerUrl: string;
```
OIDC issuer URL or OAuth authority URL.
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
***
### managedIdentityClientId?
```ts
optional managedIdentityClientId: string;
```
Client ID for user-assigned managed identity (azure_managed_identity).
***
### redirectUri?
```ts
optional redirectUri: string;
```
Redirect URI (authorization_code_pkce flow).
***
### refreshBufferSecs?
```ts
optional refreshBufferSecs: number;
```
Seconds before expiry to trigger proactive refresh (default: 300).
***
### scopes
```ts
scopes: string[];
```
OAuth scopes to request. For Azure: `["api://{app_id}/.default"]`
***
### tokenFile?
```ts
optional tokenFile: string;
```
Path to federated token file (workload_identity).

View File

@@ -0,0 +1,134 @@
[**@lancedb/lancedb**](../README.md) • **Docs**
***
[@lancedb/lancedb](../globals.md) / OAuthConfig
# Interface: OAuthConfig
OAuth configuration for LanceDB authentication.
All token acquisition and refresh is handled in the Rust layer.
This config is passed through to Rust via napi-rs.
## Examples
```typescript
const config: OAuthConfig = {
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
clientId: "app-id",
clientSecret: "secret",
scopes: ["api://lancedb-api/.default"],
};
```
```typescript
const config: OAuthConfig = {
issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
clientId: "app-id",
scopes: ["api://lancedb-api/.default"],
flow: OAuthFlowType.AzureManagedIdentity,
};
```
## Properties
### callbackPort?
```ts
optional callbackPort: number;
```
Port for local callback server (AuthorizationCodePKCE, default: 8400).
***
### clientId
```ts
clientId: string;
```
Application / Client ID.
***
### clientSecret?
```ts
optional clientSecret: string;
```
Client secret (required for ClientCredentials).
***
### flow?
```ts
optional flow: OAuthFlowType;
```
Authentication flow (default: ClientCredentials).
***
### issuerUrl
```ts
issuerUrl: string;
```
OIDC issuer URL or OAuth authority URL.
For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
***
### managedIdentityClientId?
```ts
optional managedIdentityClientId: string;
```
Client ID for user-assigned managed identity (AzureManagedIdentity).
***
### redirectUri?
```ts
optional redirectUri: string;
```
Redirect URI (AuthorizationCodePKCE flow).
***
### refreshBufferSecs?
```ts
optional refreshBufferSecs: number;
```
Seconds before expiry to trigger proactive refresh (default: 300).
***
### scopes
```ts
scopes: string[];
```
OAuth scopes to request.
For Azure: `["api://{app_id}/.default"]`
***
### tokenFile?
```ts
optional tokenFile: string;
```
Path to federated token file (WorkloadIdentity).

View File

@@ -58,8 +58,7 @@ computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Flo
Compute the embeddings for a single query
#### Parameters
* **data**: `T`
* **data**: `T`
#### Returns
@@ -76,8 +75,7 @@ abstract computeSourceEmbeddings(data): Promise<number[][] | Float32Array[] | Fl
Creates a vector representation for the given values.
#### Parameters
* **data**: `T`[]
* **data**: `T`[]
#### Returns
@@ -155,8 +153,7 @@ protected resolveVariables(config): Partial<M>
Apply variables to the config.
#### Parameters
* **config**: `Partial`&lt;`M`&gt;
* **config**: `Partial`&lt;`M`&gt;
#### Returns
@@ -173,8 +170,7 @@ sourceField(optionsOrDatatype): [DataType<Type, any>, Map<string, EmbeddingFunct
sourceField is used in combination with `LanceSchema` to provide a declarative data model
#### Parameters
* **optionsOrDatatype**: `DataType`&lt;`Type`, `any`&gt; \| `Partial`&lt;[`FieldOptions`](../interfaces/FieldOptions.md)&lt;`DataType`&lt;`Type`, `any`&gt;&gt;&gt;
* **optionsOrDatatype**: `DataType`&lt;`Type`, `any`&gt; \| `Partial`&lt;[`FieldOptions`](../interfaces/FieldOptions.md)&lt;`DataType`&lt;`Type`, `any`&gt;&gt;&gt;
The options for the field or the datatype
#### Returns
@@ -211,8 +207,7 @@ vectorField(optionsOrDatatype?): [DataType<Type, any>, Map<string, EmbeddingFunc
vectorField is used in combination with `LanceSchema` to provide a declarative data model
#### Parameters
* **optionsOrDatatype?**: `DataType`&lt;`Type`, `any`&gt; \| `Partial`&lt;[`FieldOptions`](../interfaces/FieldOptions.md)&lt;`DataType`&lt;`Type`, `any`&gt;&gt;&gt;
* **optionsOrDatatype?**: `DataType`&lt;`Type`, `any`&gt; \| `Partial`&lt;[`FieldOptions`](../interfaces/FieldOptions.md)&lt;`DataType`&lt;`Type`, `any`&gt;&gt;&gt;
The options for the field
#### Returns

View File

@@ -32,8 +32,7 @@ functionToMetadata(conf): Record<string, any>
```
#### Parameters
* **conf**: [`EmbeddingFunctionConfig`](../interfaces/EmbeddingFunctionConfig.md)
* **conf**: [`EmbeddingFunctionConfig`](../interfaces/EmbeddingFunctionConfig.md)
#### Returns
@@ -54,8 +53,7 @@ Fetch an embedding function by name
**T** *extends* [`EmbeddingFunction`](EmbeddingFunction.md)&lt;`unknown`, [`FunctionOptions`](../interfaces/FunctionOptions.md)&gt;
#### Parameters
* **name**: `string`
* **name**: `string`
The name of the function
#### Returns
@@ -71,8 +69,7 @@ getTableMetadata(functions): Map<string, string>
```
#### Parameters
* **functions**: [`EmbeddingFunctionConfig`](../interfaces/EmbeddingFunctionConfig.md)[]
* **functions**: [`EmbeddingFunctionConfig`](../interfaces/EmbeddingFunctionConfig.md)[]
#### Returns
@@ -89,8 +86,7 @@ getVar(name): undefined | string
Get a variable.
#### Parameters
* **name**: `string`
* **name**: `string`
#### Returns
@@ -129,18 +125,15 @@ Register an embedding function
**T** *extends* [`EmbeddingFunctionConstructor`](../interfaces/EmbeddingFunctionConstructor.md)&lt;[`EmbeddingFunction`](EmbeddingFunction.md)&lt;`any`, [`FunctionOptions`](../interfaces/FunctionOptions.md)&gt;&gt; = [`EmbeddingFunctionConstructor`](../interfaces/EmbeddingFunctionConstructor.md)&lt;[`EmbeddingFunction`](EmbeddingFunction.md)&lt;`any`, [`FunctionOptions`](../interfaces/FunctionOptions.md)&gt;&gt;
#### Parameters
* **this**: [`EmbeddingFunctionRegistry`](EmbeddingFunctionRegistry.md)
* **alias?**: `string`
* **this**: [`EmbeddingFunctionRegistry`](EmbeddingFunctionRegistry.md)
* **alias?**: `string`
#### Returns
`Function`
##### Parameters
* **ctor**: `T`
* **ctor**: `T`
##### Returns
@@ -161,8 +154,7 @@ reset(this): void
reset the registry to the initial state
#### Parameters
* **this**: [`EmbeddingFunctionRegistry`](EmbeddingFunctionRegistry.md)
* **this**: [`EmbeddingFunctionRegistry`](EmbeddingFunctionRegistry.md)
#### Returns
@@ -187,10 +179,8 @@ whether to use a GPU for inference.
The name must not contain colons. The default value can contain colons.
#### Parameters
* **name**: `string`
* **value**: `string`
* **name**: `string`
* **value**: `string`
#### Returns

View File

@@ -43,8 +43,7 @@ computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Flo
Compute the embeddings for a single query
#### Parameters
* **data**: `string`
* **data**: `string`
#### Returns
@@ -65,8 +64,7 @@ computeSourceEmbeddings(data): Promise<number[][] | Float32Array[] | Float64Arra
Creates a vector representation for the given values.
#### Parameters
* **data**: `string`[]
* **data**: `string`[]
#### Returns
@@ -103,10 +101,8 @@ abstract generateEmbeddings(texts, ...args): Promise<number[][] | Float32Array[]
```
#### Parameters
* **texts**: `string`[]
* ...**args**: `any`[]
* **texts**: `string`[]
* ...**args**: `any`[]
#### Returns
@@ -182,8 +178,7 @@ protected resolveVariables(config): Partial<M>
Apply variables to the config.
#### Parameters
* **config**: `Partial`&lt;`M`&gt;
* **config**: `Partial`&lt;`M`&gt;
#### Returns
@@ -245,8 +240,7 @@ vectorField(optionsOrDatatype?): [DataType<Type, any>, Map<string, EmbeddingFunc
vectorField is used in combination with `LanceSchema` to provide a declarative data model
#### Parameters
* **optionsOrDatatype?**: `DataType`&lt;`Type`, `any`&gt; \| `Partial`&lt;[`FieldOptions`](../interfaces/FieldOptions.md)&lt;`DataType`&lt;`Type`, `any`&gt;&gt;&gt;
* **optionsOrDatatype?**: `DataType`&lt;`Type`, `any`&gt; \| `Partial`&lt;[`FieldOptions`](../interfaces/FieldOptions.md)&lt;`DataType`&lt;`Type`, `any`&gt;&gt;&gt;
The options for the field
#### Returns

View File

@@ -13,8 +13,7 @@ function LanceSchema(fields): Schema
Create a schema with embedding functions.
## Parameters
* **fields**: `Record`&lt;`string`, `object` \| [`object`, `Map`&lt;`string`, [`EmbeddingFunction`](../classes/EmbeddingFunction.md)&lt;`any`, [`FunctionOptions`](../interfaces/FunctionOptions.md)&gt;&gt;]&gt;
* **fields**: `Record`&lt;`string`, `object` \| [`object`, `Map`&lt;`string`, [`EmbeddingFunction`](../classes/EmbeddingFunction.md)&lt;`any`, [`FunctionOptions`](../interfaces/FunctionOptions.md)&gt;&gt;]&gt;
## Returns

View File

@@ -11,16 +11,14 @@ function register(name?): (ctor) => any
```
## Parameters
* **name?**: `string`
* **name?**: `string`
## Returns
`Function`
### Parameters
* **ctor**: [`EmbeddingFunctionConstructor`](../interfaces/EmbeddingFunctionConstructor.md)&lt;[`EmbeddingFunction`](../classes/EmbeddingFunction.md)&lt;`any`, [`FunctionOptions`](../interfaces/FunctionOptions.md)&gt;&gt;
* **ctor**: [`EmbeddingFunctionConstructor`](../interfaces/EmbeddingFunctionConstructor.md)&lt;[`EmbeddingFunction`](../classes/EmbeddingFunction.md)&lt;`any`, [`FunctionOptions`](../interfaces/FunctionOptions.md)&gt;&gt;
### Returns

View File

@@ -19,8 +19,7 @@ new EmbeddingFunctionConstructor(modelOptions?): T
```
#### Parameters
* **modelOptions?**: `T`\[`"TOptions"`\]
* **modelOptions?**: `T`\[`"TOptions"`\]
#### Returns

View File

@@ -19,8 +19,7 @@ create(options?): CreateReturnType<T>
```
#### Parameters
* **options?**: `T`\[`"TOptions"`\]
* **options?**: `T`\[`"TOptions"`\]
#### Returns

View File

@@ -20,12 +20,9 @@ rerankHybrid(
```
#### Parameters
* **query**: `string`
* **vecResults**: `RecordBatch`&lt;`any`&gt;
* **ftsResults**: `RecordBatch`&lt;`any`&gt;
* **query**: `string`
* **vecResults**: `RecordBatch`&lt;`any`&gt;
* **ftsResults**: `RecordBatch`&lt;`any`&gt;
#### Returns
@@ -40,8 +37,7 @@ static create(k): Promise<RRFReranker>
```
#### Parameters
* **k**: `number` = `60`
* **k**: `number` = `60`
#### Returns

View File

@@ -18,12 +18,9 @@ rerankHybrid(
```
#### Parameters
* **query**: `string`
* **vecResults**: `RecordBatch`&lt;`any`&gt;
* **ftsResults**: `RecordBatch`&lt;`any`&gt;
* **query**: `string`
* **vecResults**: `RecordBatch`&lt;`any`&gt;
* **ftsResults**: `RecordBatch`&lt;`any`&gt;
#### Returns

View File

@@ -48,6 +48,7 @@ export {
SplitHashOptions,
SplitSequentialOptions,
ShuffleOptions,
OAuthConfig as NativeOAuthConfig,
} from "./native.js";
export {
@@ -113,6 +114,8 @@ export {
TokenResponse,
} from "./header";
export { OAuthConfig, OAuthFlowType } from "./oauth";
export { MergeInsertBuilder, WriteExecutionOptions } from "./merge";
export * as embedding from "./embedding";

82
nodejs/lancedb/oauth.ts Normal file
View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
/**
* OAuth authentication flow types.
*
* Each member's string value is one of the flow identifiers listed on the
* native (napi) `OAuthConfig.flow` field, so a member can be handed to the
* native layer as a plain string.
*/
export enum OAuthFlowType {
/** Client Credentials grant (service-to-service / M2M). */
ClientCredentials = "client_credentials",
/** Authorization Code with PKCE (interactive browser-based auth). */
AuthorizationCodePKCE = "authorization_code_pkce",
/** Device Code grant (CLI / headless environments). */
DeviceCode = "device_code",
/** Azure Managed Identity via IMDS. */
AzureManagedIdentity = "azure_managed_identity",
/** Workload Identity Federation (K8s, GitHub Actions). */
WorkloadIdentity = "workload_identity",
}
/**
* OAuth configuration for LanceDB authentication.
*
* All token acquisition and refresh is handled in the Rust layer.
* This config is passed through to Rust via napi-rs.
*
* `issuerUrl`, `clientId` and `scopes` are always required. Every other
* field is optional and is only relevant to the flow named in its
* description.
*
* @example Client Credentials (service-to-service):
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* clientSecret: "secret",
* scopes: ["api://lancedb-api/.default"],
* };
* ```
*
* @example Azure Managed Identity:
* ```typescript
* const config: OAuthConfig = {
* issuerUrl: "https://login.microsoftonline.com/{tenant}/v2.0",
* clientId: "app-id",
* scopes: ["api://lancedb-api/.default"],
* flow: OAuthFlowType.AzureManagedIdentity,
* };
* ```
*/
export interface OAuthConfig {
/**
* OIDC issuer URL or OAuth authority URL.
* For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
*/
issuerUrl: string;
/** Application / Client ID. */
clientId: string;
/**
* OAuth scopes to request.
* For Azure: `["api://{app_id}/.default"]`
*/
scopes: string[];
/** Authentication flow (default: ClientCredentials). */
flow?: OAuthFlowType;
/** Client secret (required for ClientCredentials). */
clientSecret?: string;
/** Redirect URI (AuthorizationCodePKCE flow). */
redirectUri?: string;
/**
* Port for local callback server (AuthorizationCodePKCE, default: 8400).
* NOTE(review): the native config declares this field as `u16`, so values
* outside 0-65535 presumably cannot be represented; confirm how napi-rs
* reports that.
*/
callbackPort?: number;
/** Client ID for user-assigned managed identity (AzureManagedIdentity). */
managedIdentityClientId?: string;
/** Path to federated token file (WorkloadIdentity). */
tokenFile?: string;
/**
* Seconds before expiry to trigger proactive refresh (default: 300).
* NOTE(review): the native config declares this field as `u32`, so it is
* presumably expected to be a non-negative integer; confirm.
*/
refreshBufferSecs?: number;
}

View File

@@ -85,6 +85,11 @@ impl Connection {
builder = builder.client_config(rust_config);
if let Some(oauth_config) = options.oauth_config {
let config: lancedb::remote::oauth::OAuthConfig = oauth_config.into();
builder = builder.oauth_config(config);
}
if let Some(api_key) = options.api_key {
builder = builder.api_key(&api_key);
}

View File

@@ -60,6 +60,10 @@ pub struct ConnectionOptions {
/// (For LanceDB cloud only): the host to use for LanceDB cloud. Used
/// for testing purposes.
pub host_override: Option<String>,
/// (For LanceDB cloud only): OAuth configuration for IdP-based
/// authentication (e.g., Azure Entra ID). When set, token acquisition
/// and refresh are handled entirely in Rust.
pub oauth_config: Option<remote::OAuthConfig>,
}
#[napi(object)]

View File

@@ -140,6 +140,67 @@ impl From<TlsConfig> for lancedb::remote::TlsConfig {
}
}
/// OAuth configuration for LanceDB authentication.
/// All token acquisition and refresh is handled in the Rust layer.
///
/// This is the JavaScript-facing shape; it is converted into
/// `lancedb::remote::oauth::OAuthConfig` before being handed to the
/// connection builder.
// NOTE(review): the derived `Debug` prints `client_secret` verbatim; avoid
// logging this struct with `{:?}`.
#[napi(object)]
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// OIDC issuer URL or OAuth authority URL.
    /// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
    pub issuer_url: String,
    /// Application / Client ID.
    pub client_id: String,
    /// OAuth scopes to request. For Azure: `["api://{app_id}/.default"]`
    pub scopes: Vec<String>,
    /// Authentication flow: "client_credentials", "authorization_code_pkce",
    /// "device_code", "azure_managed_identity", "workload_identity".
    /// When omitted, "client_credentials" is assumed.
    pub flow: Option<String>,
    /// Client secret (required for client_credentials).
    pub client_secret: Option<String>,
    /// Redirect URI (authorization_code_pkce flow).
    pub redirect_uri: Option<String>,
    /// Port for local callback server (authorization_code_pkce, default: 8400).
    pub callback_port: Option<u16>,
    /// Client ID for user-assigned managed identity (azure_managed_identity).
    pub managed_identity_client_id: Option<String>,
    /// Path to federated token file (workload_identity). The conversion to
    /// the core config panics if this is missing for that flow.
    pub token_file: Option<String>,
    /// Seconds before expiry to trigger proactive refresh (default: 300).
    pub refresh_buffer_secs: Option<u32>,
}
/// Converts the JavaScript-facing OAuth options into the core configuration.
///
/// `flow` defaults to `"client_credentials"` when it is omitted.
///
/// # Panics
///
/// Panics if `flow` is not one of the recognised flow names, or if `flow` is
/// `"workload_identity"` and `token_file` is not set.
impl From<OAuthConfig> for lancedb::remote::oauth::OAuthConfig {
    fn from(config: OAuthConfig) -> Self {
        use lancedb::remote::oauth::OAuthFlow;

        let flow = match config.flow.as_deref().unwrap_or("client_credentials") {
            // The default flow must be matched explicitly; without this arm
            // an omitted `flow` would fall through to the panic below.
            "client_credentials" => OAuthFlow::ClientCredentials,
            "authorization_code_pkce" => OAuthFlow::AuthorizationCodePKCE {
                redirect_uri: config.redirect_uri,
                callback_port: config.callback_port,
            },
            "device_code" => OAuthFlow::DeviceCode,
            "azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
                client_id: config.managed_identity_client_id,
            },
            "workload_identity" => OAuthFlow::WorkloadIdentity {
                token_file: config
                    .token_file
                    .expect("tokenFile is required for workload_identity flow"),
            },
            other => panic!("Unknown OAuth flow type: {other}"),
        };

        Self {
            issuer_url: config.issuer_url,
            client_id: config.client_id,
            client_secret: config.client_secret,
            scopes: config.scopes,
            flow,
            // Lossless widening; `u64::from` cannot truncate.
            refresh_buffer_secs: config.refresh_buffer_secs.map(u64::from),
        }
    }
}
impl From<ClientConfig> for lancedb::remote::ClientConfig {
fn from(config: ClientConfig) -> Self {
Self {

View File

@@ -320,6 +320,7 @@ async def connect_async(
session: Optional[Session] = None,
manifest_enabled: bool = False,
namespace_client_properties: Optional[Dict[str, str]] = None,
oauth_config=None,
) -> AsyncConnection:
"""Connect to a LanceDB database.
@@ -410,6 +411,7 @@ async def connect_async(
session,
manifest_enabled,
namespace_client_properties,
oauth_config,
)
)

View File

@@ -247,6 +247,7 @@ async def connect(
session: Optional[Session],
manifest_enabled: bool = False,
namespace_client_properties: Optional[Dict[str, str]] = None,
oauth_config: Optional[Any] = None,
) -> Connection: ...
class RecordBatchStream:

View File

@@ -9,6 +9,7 @@ from typing import List, Optional
from lancedb import __version__
from .header import HeaderProvider
from .oauth import OAuthConfig, OAuthFlowType
__all__ = [
"TimeoutConfig",
@@ -16,6 +17,8 @@ __all__ = [
"TlsConfig",
"ClientConfig",
"HeaderProvider",
"OAuthConfig",
"OAuthFlowType",
]

View File

@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
class OAuthFlowType(str, Enum):
    """OAuth authentication flow types.

    Every member is also a ``str``; its value is the flow name that the
    native layer matches on (for example ``"client_credentials"``).
    """

    CLIENT_CREDENTIALS = "client_credentials"
    """Client Credentials grant (service-to-service / M2M)."""
    AUTHORIZATION_CODE_PKCE = "authorization_code_pkce"
    """Authorization Code with PKCE (interactive browser-based auth)."""
    DEVICE_CODE = "device_code"
    """Device Code grant (CLI / headless environments)."""
    AZURE_MANAGED_IDENTITY = "azure_managed_identity"
    """Azure Managed Identity via IMDS."""
    WORKLOAD_IDENTITY = "workload_identity"
    """Workload Identity Federation (K8s, GitHub Actions)."""
@dataclass
class OAuthConfig:
    """OAuth configuration for LanceDB authentication.

    All token acquisition and refresh is handled in the Rust layer.
    This config is passed through to Rust via PyO3, which reads the
    attributes below by name.

    Parameters
    ----------
    issuer_url : str
        OIDC issuer URL or OAuth authority URL.
        For Azure: ``https://login.microsoftonline.com/{tenant_id}/v2.0``
    client_id : str
        Application / Client ID.
    scopes : List[str]
        OAuth scopes to request. At least one scope is required.
        For Azure: ``["api://{app_id}/.default"]``
    flow : OAuthFlowType
        Authentication flow to use. Default: CLIENT_CREDENTIALS.
    client_secret : Optional[str]
        Client secret (required for CLIENT_CREDENTIALS).
    redirect_uri : Optional[str]
        Redirect URI for AUTHORIZATION_CODE_PKCE flow.
    callback_port : Optional[int]
        Port for local HTTP callback server (AUTHORIZATION_CODE_PKCE, default: 8400).
        The native layer reads this as an unsigned 16-bit integer.
    managed_identity_client_id : Optional[str]
        Client ID for user-assigned managed identity (AZURE_MANAGED_IDENTITY).
    token_file : Optional[str]
        Path to federated token file (required for WORKLOAD_IDENTITY).
    refresh_buffer_secs : Optional[int]
        Seconds before expiry to trigger proactive refresh (default: 300).

    Examples
    --------
    Client Credentials (service-to-service):

    >>> config = OAuthConfig(
    ...     issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
    ...     client_id="app-id",
    ...     client_secret="secret",
    ...     scopes=["api://lancedb-api/.default"],
    ... )

    Azure Managed Identity:

    >>> config = OAuthConfig(
    ...     issuer_url="https://login.microsoftonline.com/{tenant}/v2.0",
    ...     client_id="app-id",
    ...     scopes=["api://lancedb-api/.default"],
    ...     flow=OAuthFlowType.AZURE_MANAGED_IDENTITY,
    ... )
    """

    issuer_url: str
    client_id: str
    scopes: List[str]
    flow: OAuthFlowType = OAuthFlowType.CLIENT_CREDENTIALS
    client_secret: Optional[str] = None
    redirect_uri: Optional[str] = None
    callback_port: Optional[int] = None
    managed_identity_client_id: Optional[str] = None
    token_file: Optional[str] = None
    refresh_buffer_secs: Optional[int] = None

View File

@@ -524,7 +524,7 @@ impl Connection {
}
#[pyfunction]
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None))]
#[pyo3(signature = (uri, api_key=None, region=None, host_override=None, read_consistency_interval=None, client_config=None, storage_options=None, session=None, manifest_enabled=false, namespace_client_properties=None, oauth_config=None))]
#[allow(clippy::too_many_arguments)]
pub fn connect(
py: Python<'_>,
@@ -538,6 +538,7 @@ pub fn connect(
session: Option<crate::session::Session>,
manifest_enabled: bool,
namespace_client_properties: Option<HashMap<String, String>>,
oauth_config: Option<crate::oauth::PyOAuthConfig>,
) -> PyResult<Bound<'_, PyAny>> {
future_into_py(py, async move {
let mut builder = lancedb::connect(&uri);
@@ -567,6 +568,10 @@ pub fn connect(
if let Some(client_config) = client_config {
builder = builder.client_config(client_config.into());
}
if let Some(oauth_config) = oauth_config {
let config: lancedb::remote::oauth::OAuthConfig = oauth_config.into();
builder = builder.oauth_config(config);
}
if let Some(session) = session {
builder = builder.session(session.inner.clone());
}

View File

@@ -26,6 +26,7 @@ pub mod expr;
pub mod header;
pub mod index;
pub mod namespace;
pub mod oauth;
pub mod permutation;
pub mod query;
pub mod runtime;

53
python/src/oauth.rs Normal file
View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use pyo3::FromPyObject;
use lancedb::remote::oauth::{OAuthConfig, OAuthFlow};
/// Python-side OAuth configuration, extracted via FromPyObject.
/// Maps to `lancedb.remote.oauth.OAuthConfig` Python dataclass.
///
/// The field names mirror the dataclass attributes one-to-one.
#[derive(FromPyObject)]
pub struct PyOAuthConfig {
    /// OIDC issuer URL or OAuth authority URL.
    pub issuer_url: String,
    /// Application / Client ID.
    pub client_id: String,
    /// OAuth scopes to request.
    pub scopes: Vec<String>,
    /// Flow name, e.g. `"client_credentials"`; matched by the conversion
    /// into the core `OAuthConfig`.
    pub flow: String,
    /// Client secret (used by the client-credentials flow).
    pub client_secret: Option<String>,
    /// Redirect URI for the `authorization_code_pkce` flow.
    pub redirect_uri: Option<String>,
    /// Local callback port for the `authorization_code_pkce` flow.
    pub callback_port: Option<u16>,
    /// Client ID of a user-assigned identity (`azure_managed_identity`).
    pub managed_identity_client_id: Option<String>,
    /// Federated token file; the conversion panics if this is missing for
    /// the `workload_identity` flow.
    pub token_file: Option<String>,
    /// Seconds before expiry at which a refresh is triggered.
    pub refresh_buffer_secs: Option<u64>,
}
impl From<PyOAuthConfig> for OAuthConfig {
fn from(py: PyOAuthConfig) -> Self {
let flow = match py.flow.as_str() {
"client_credentials" => OAuthFlow::ClientCredentials,
"authorization_code_pkce" => OAuthFlow::AuthorizationCodePKCE {
redirect_uri: py.redirect_uri,
callback_port: py.callback_port,
},
"device_code" => OAuthFlow::DeviceCode,
"azure_managed_identity" => OAuthFlow::AzureManagedIdentity {
client_id: py.managed_identity_client_id,
},
"workload_identity" => OAuthFlow::WorkloadIdentity {
token_file: py
.token_file
.expect("token_file is required for workload_identity flow"),
},
other => panic!("Unknown OAuth flow type: {other}"),
};
OAuthConfig {
issuer_url: py.issuer_url,
client_id: py.client_id,
client_secret: py.client_secret,
scopes: py.scopes,
flow,
refresh_buffer_secs: py.refresh_buffer_secs,
}
}
}

View File

@@ -75,6 +75,11 @@ reqwest = { version = "0.12.0", default-features = false, features = [
"stream",
], optional = true }
http = { version = "1", optional = true } # Matching what is in reqwest
# OAuth dependencies (used by remote feature)
sha2 = { version = "0.10", optional = true }
base64 = { version = "0.22", optional = true }
urlencoding = { version = "2", optional = true }
open = { version = "5", optional = true }
uuid = { version = "1.7.0", features = ["v4"] }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true }
@@ -128,7 +133,7 @@ huggingface = [
"lance-namespace-impls/dir-huggingface",
]
dynamodb = ["lance/dynamodb", "aws"]
remote = ["dep:reqwest", "dep:http", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
remote = ["dep:reqwest", "dep:http", "dep:sha2", "dep:base64", "dep:urlencoding", "dep:open", "lance-namespace-impls/rest", "lance-namespace-impls/rest-adapter"]
fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = []
bedrock = ["dep:aws-sdk-bedrockruntime"]

View File

@@ -622,6 +622,8 @@ pub struct ConnectRequest {
pub struct ConnectBuilder {
request: ConnectRequest,
embedding_registry: Option<Arc<dyn EmbeddingRegistry>>,
#[cfg(feature = "remote")]
oauth_config: Option<crate::remote::oauth::OAuthConfig>,
}
#[cfg(feature = "remote")]
@@ -643,6 +645,8 @@ impl ConnectBuilder {
session: None,
},
embedding_registry: None,
#[cfg(feature = "remote")]
oauth_config: None,
}
}
@@ -731,6 +735,19 @@ impl ConnectBuilder {
self
}
/// Configure OAuth authentication for LanceDB Cloud/Enterprise.
///
/// The config is only stored here. When the remote connection is
/// established, an [`OAuthHeaderProvider`](crate::remote::OAuthHeaderProvider)
/// is built from it and installed as the client's header provider,
/// overriding any header provider already set on the client config.
///
/// With OAuth configured an API key is no longer required; if one is
/// supplied as well it is still passed to the remote database.
///
/// Token acquisition and refresh are handled entirely in Rust.
#[cfg(feature = "remote")]
pub fn oauth_config(mut self, config: crate::remote::oauth::OAuthConfig) -> Self {
    self.oauth_config = Some(config);
    self
}
/// Provide a custom [`EmbeddingRegistry`] to use for this connection.
pub fn embedding_registry(mut self, registry: Arc<dyn EmbeddingRegistry>) -> Self {
self.embedding_registry = Some(registry);
@@ -874,9 +891,29 @@ impl ConnectBuilder {
let region = options.region.ok_or_else(|| Error::InvalidInput {
message: "A region is required when connecting to LanceDb Cloud".to_string(),
})?;
let api_key = options.api_key.ok_or_else(|| Error::InvalidInput {
message: "An api_key is required when connecting to LanceDb Cloud".to_string(),
})?;
// When OAuth is configured, api_key is not required
let api_key = match (&self.oauth_config, &options.api_key) {
(Some(_), None) => String::new(),
(Some(_), Some(key)) => key.clone(),
(None, Some(key)) => key.clone(),
(None, None) => {
return Err(Error::InvalidInput {
message:
"An api_key or oauth_config is required when connecting to LanceDb Cloud"
.to_string(),
});
}
};
let mut client_config = self.request.client_config;
// Apply OAuth header provider if configured
if let Some(oauth_config) = self.oauth_config {
let provider = crate::remote::oauth::OAuthHeaderProvider::new(oauth_config)?;
client_config.header_provider =
Some(Arc::new(provider) as Arc<dyn crate::remote::client::HeaderProvider>);
}
let storage_options = StorageOptions(options.storage_options.clone());
let internal = Arc::new(crate::remote::db::RemoteDatabase::try_new(
@@ -884,7 +921,7 @@ impl ConnectBuilder {
&api_key,
&region,
options.host_override,
self.request.client_config,
client_config,
storage_options.into(),
)?);
Ok(Connection {

View File

@@ -8,6 +8,7 @@
pub(crate) mod client;
pub(crate) mod db;
pub mod oauth;
mod retry;
pub(crate) mod table;
pub(crate) mod util;
@@ -20,3 +21,4 @@ const JSON_CONTENT_TYPE: &str = "application/json";
pub use client::{ClientConfig, HeaderProvider, RetryConfig, TimeoutConfig, TlsConfig};
pub use db::{RemoteDatabaseOptions, RemoteDatabaseOptionsBuilder};
pub use oauth::{OAuthConfig, OAuthFlow, OAuthHeaderProvider};

View File

@@ -0,0 +1,906 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use log::{debug, info, warn};
use reqwest::Client;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::error::{Error, Result};
use crate::remote::client::HeaderProvider;
const DEFAULT_REFRESH_BUFFER_SECS: u64 = 300;
const DEFAULT_CALLBACK_PORT: u16 = 8400;
const AZURE_IMDS_ENDPOINT: &str = "http://169.254.169.254/metadata/identity/oauth2/token";
const AZURE_IMDS_API_VERSION: &str = "2018-02-01";
/// OAuth authentication flow configuration.
///
/// Selects how [`OAuthHeaderProvider`] obtains its first access token.
#[derive(Debug, Clone)]
pub enum OAuthFlow {
    /// Client Credentials grant (service-to-service / M2M).
    /// Requires `client_secret` in [`OAuthConfig`].
    ClientCredentials,
    /// Authorization Code with PKCE (interactive browser-based auth).
    /// A local listener on `127.0.0.1:{callback_port}` receives the redirect.
    AuthorizationCodePKCE {
        /// Redirect URI (default: `http://localhost:{callback_port}/callback`)
        redirect_uri: Option<String>,
        /// Port for the local HTTP callback server (default: 8400)
        callback_port: Option<u16>,
    },
    /// Device Code grant (CLI / headless environments).
    /// The verification URL and user code for out-of-band authentication are
    /// emitted through the `log` crate at info level.
    DeviceCode,
    /// Azure Managed Identity via IMDS.
    /// Works on Azure VMs, AKS, App Service, and Azure Functions.
    AzureManagedIdentity {
        /// Client ID for user-assigned managed identity.
        /// Omit for system-assigned managed identity.
        client_id: Option<String>,
    },
    /// Workload Identity Federation.
    /// Exchanges a platform token (K8s service account, GitHub OIDC) for an IdP token.
    WorkloadIdentity {
        /// Path to the federated token file
        /// (e.g. `AZURE_FEDERATED_TOKEN_FILE`).
        token_file: String,
    },
}
/// OAuth configuration for LanceDB authentication.
///
/// All token acquisition and refresh is handled in the Rust layer.
/// Python and TypeScript bindings expose this as a plain config object.
///
/// The [`Debug`](std::fmt::Debug) output never contains the client secret,
/// so a config can be logged safely.
#[derive(Clone)]
pub struct OAuthConfig {
    /// OIDC issuer URL or OAuth authority URL.
    /// For Azure: `https://login.microsoftonline.com/{tenant_id}/v2.0`
    pub issuer_url: String,
    /// Application / Client ID.
    pub client_id: String,
    /// Client secret (required for `ClientCredentials`, optional for others).
    pub client_secret: Option<String>,
    /// OAuth scopes to request.
    /// For Azure: `["api://{app_id}/.default"]`
    pub scopes: Vec<String>,
    /// Authentication flow to use.
    pub flow: OAuthFlow,
    /// Seconds before token expiry to trigger proactive refresh (default: 300).
    pub refresh_buffer_secs: Option<u64>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a secret is present, never its value.
        let secret = self.client_secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("OAuthConfig")
            .field("issuer_url", &self.issuer_url)
            .field("client_id", &self.client_id)
            .field("client_secret", &secret)
            .field("scopes", &self.scopes)
            .field("flow", &self.flow)
            .field("refresh_buffer_secs", &self.refresh_buffer_secs)
            .finish()
    }
}
// -- OIDC Discovery --
/// The subset of the OIDC discovery document
/// (`{issuer}/.well-known/openid-configuration`) that this module reads.
#[derive(Debug, Deserialize)]
struct OidcDiscovery {
    /// Endpoint that token requests are POSTed to.
    token_endpoint: String,
    /// Endpoint the browser is sent to in the Authorization Code + PKCE flow.
    authorization_endpoint: Option<String>,
    /// Endpoint that starts the Device Code flow.
    device_authorization_endpoint: Option<String>,
}
// -- Token Response --
/// Successful response body of a token request (also used for the Azure
/// IMDS response).
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Bearer token placed in the `authorization` header.
    access_token: String,
    /// Refresh token, when the identity provider issues one.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Token lifetime in seconds.
    /// Some providers (Azure IMDS) return this as a string, so we accept both.
    #[serde(default, deserialize_with = "deserialize_optional_u64_or_string")]
    expires_in: Option<u64>,
    /// Parsed for completeness; not read anywhere, hence the `dead_code` allow.
    #[serde(default)]
    #[allow(dead_code)]
    token_type: Option<String>,
}
/// Deserializes an optional non-negative integer that may be encoded as a
/// JSON number, a numeric string, or `null`.
///
/// # Errors
///
/// Fails when the value is a negative number or a string that does not parse
/// as a `u64`.
fn deserialize_optional_u64_or_string<'de, D>(
    deserializer: D,
) -> std::result::Result<Option<u64>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::de;

    struct U64OrString;

    impl<'de> de::Visitor<'de> for U64OrString {
        type Value = Option<u64>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a u64, a numeric string, or null")
        }

        fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Self::Value, E> {
            Ok(Some(v))
        }

        fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Self::Value, E> {
            // An `as u64` cast would turn a negative value into an enormous
            // lifetime; reject it instead.
            u64::try_from(v)
                .map(Some)
                .map_err(|_| E::invalid_value(de::Unexpected::Signed(v), &self))
        }

        fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Self::Value, E> {
            v.parse::<u64>().map(Some).map_err(de::Error::custom)
        }

        fn visit_none<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
            Ok(None)
        }

        fn visit_unit<E: de::Error>(self) -> std::result::Result<Self::Value, E> {
            Ok(None)
        }
    }

    deserializer.deserialize_any(U64OrString)
}
// -- Device Code Response --
/// Response of the device authorization endpoint (Device Code flow).
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    /// Opaque code sent back when polling the token endpoint.
    device_code: String,
    /// Code the user types in at `verification_uri`.
    user_code: String,
    /// URL the user is asked to visit.
    verification_uri: String,
    /// Optional second URL; logged as an alternative when present.
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Seconds until the device code expires; bounds the polling loop.
    expires_in: u64,
    /// Seconds to wait between polls (5 is used when absent).
    interval: Option<u64>,
}
// -- Internal Token State --
/// Cached tokens together with the moment the access token expires.
struct TokenState {
    /// Most recently issued access token, if any.
    access_token: Option<String>,
    /// Most recently issued refresh token, if any.
    refresh_token: Option<String>,
    /// Expiry of `access_token`; `None` when the provider gave no usable
    /// lifetime.
    expires_at: Option<Instant>,
}

impl TokenState {
    /// Creates an empty state with no tokens.
    fn new() -> Self {
        Self {
            access_token: None,
            refresh_token: None,
            expires_at: None,
        }
    }

    /// Returns `true` when there is no access token, or when the token
    /// expires within `buffer` from now.
    ///
    /// A token without expiry information is assumed to be valid.
    fn is_expired(&self, buffer: Duration) -> bool {
        match (self.access_token.as_ref(), self.expires_at) {
            (None, _) => true,
            (Some(_), None) => false, // no expiry info, assume valid
            (Some(_), Some(expires_at)) => {
                // Compare the remaining lifetime with the buffer rather than
                // computing `now + buffer`, which panics for a very large
                // buffer. A deadline already in the past saturates to zero.
                expires_at.saturating_duration_since(Instant::now()) <= buffer
            }
        }
    }

    /// Stores the tokens from `resp`.
    ///
    /// An existing refresh token is kept when the response carries none.
    fn update(&mut self, resp: &TokenResponse) {
        self.access_token = Some(resp.access_token.clone());
        if let Some(refresh_token) = resp.refresh_token.as_ref() {
            self.refresh_token = Some(refresh_token.clone());
        }
        // `checked_add` keeps an absurd server-supplied lifetime from
        // panicking; such a token is treated as having no known expiry.
        self.expires_at = resp
            .expires_in
            .and_then(|secs| Instant::now().checked_add(Duration::from_secs(secs)));
    }
}
/// OAuth header provider that manages the full token lifecycle.
///
/// Implements [`HeaderProvider`] to inject `Authorization: Bearer <token>`
/// headers into every LanceDB request, with automatic token refresh.
pub struct OAuthHeaderProvider {
    /// Configuration this provider was created from.
    config: OAuthConfig,
    /// HTTP client used for discovery and token requests.
    http_client: Client,
    /// Current tokens, guarded by an async read/write lock.
    token_state: Arc<RwLock<TokenState>>,
    /// Cached OIDC discovery document
    discovery: Arc<RwLock<Option<OidcDiscovery>>>,
    /// How long before expiry a token is considered due for refresh.
    refresh_buffer: Duration,
}
/// Debug output is limited to non-secret configuration: issuer, client id
/// and flow. Tokens and the client secret are never printed.
impl std::fmt::Debug for OAuthHeaderProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &self.config;
        let mut out = f.debug_struct("OAuthHeaderProvider");
        out.field("issuer_url", &cfg.issuer_url);
        out.field("client_id", &cfg.client_id);
        out.field("flow", &cfg.flow);
        out.finish()
    }
}
impl OAuthHeaderProvider {
/// Create a new OAuth header provider from configuration.
pub fn new(config: OAuthConfig) -> Result<Self> {
// Validate config upfront
if matches!(config.flow, OAuthFlow::ClientCredentials) && config.client_secret.is_none() {
return Err(Error::InvalidInput {
message: "client_secret is required for ClientCredentials flow".to_string(),
});
}
if config.scopes.is_empty() {
return Err(Error::InvalidInput {
message: "At least one OAuth scope is required".to_string(),
});
}
let http_client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::Runtime {
message: format!("Failed to create HTTP client for OAuth: {e}"),
})?;
let refresh_buffer = Duration::from_secs(
config
.refresh_buffer_secs
.unwrap_or(DEFAULT_REFRESH_BUFFER_SECS),
);
Ok(Self {
config,
http_client,
token_state: Arc::new(RwLock::new(TokenState::new())),
discovery: Arc::new(RwLock::new(None)),
refresh_buffer,
})
}
/// Get a valid access token, refreshing if necessary.
async fn get_valid_token(&self) -> Result<String> {
// Fast path: check if current token is still valid
{
let state = self.token_state.read().await;
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
}
// Slow path: acquire or refresh token
let mut state = self.token_state.write().await;
// Double-check after acquiring write lock
if !state.is_expired(self.refresh_buffer)
&& let Some(ref token) = state.access_token
{
return Ok(token.clone());
}
let uses_refresh_token = !matches!(
self.config.flow,
OAuthFlow::ClientCredentials
| OAuthFlow::AzureManagedIdentity { .. }
| OAuthFlow::WorkloadIdentity { .. }
);
let resp = if let Some(ref refresh_token) = state.refresh_token
&& uses_refresh_token
{
debug!("Refreshing OAuth token using refresh_token");
self.refresh_with_token(refresh_token).await?
} else {
debug!("Acquiring new OAuth token via {:?} flow", self.config.flow);
self.acquire_token().await?
};
state.update(&resp);
Ok(resp.access_token)
}
/// Runs the flow selected in the configuration and returns its token
/// response.
async fn acquire_token(&self) -> Result<TokenResponse> {
    let flow = &self.config.flow;
    match flow {
        OAuthFlow::ClientCredentials => self.acquire_client_credentials().await,
        OAuthFlow::DeviceCode => self.acquire_device_code().await,
        OAuthFlow::AzureManagedIdentity { client_id } => {
            let mi_client_id = client_id.as_deref();
            self.acquire_managed_identity(mi_client_id).await
        }
        OAuthFlow::WorkloadIdentity { token_file } => {
            self.acquire_workload_identity(token_file).await
        }
        OAuthFlow::AuthorizationCodePKCE {
            redirect_uri,
            callback_port,
        } => {
            // Fall back to the default port when none was configured.
            let port = callback_port.unwrap_or(DEFAULT_CALLBACK_PORT);
            self.acquire_authorization_code_pkce(redirect_uri.as_deref(), port)
                .await
        }
    }
}
// -- OIDC Discovery --
/// Returns the issuer's OIDC discovery document, fetching it on first use
/// and serving a copy of the cached document afterwards.
///
/// # Errors
///
/// Returns [`Error::Runtime`] when the request fails, the server answers
/// with a non-success status, or the body cannot be parsed.
async fn get_discovery(&self) -> Result<OidcDiscovery> {
    // `OidcDiscovery` does not derive `Clone`, so copies are made field by
    // field.
    fn duplicate(doc: &OidcDiscovery) -> OidcDiscovery {
        OidcDiscovery {
            token_endpoint: doc.token_endpoint.clone(),
            authorization_endpoint: doc.authorization_endpoint.clone(),
            device_authorization_endpoint: doc.device_authorization_endpoint.clone(),
        }
    }

    // Read-locked lookup first; the guard is released before the write lock
    // is requested.
    {
        let cached = self.discovery.read().await;
        if let Some(doc) = cached.as_ref() {
            return Ok(duplicate(doc));
        }
    }

    let mut slot = self.discovery.write().await;
    // Another task may have filled the cache while we waited.
    if let Some(doc) = slot.as_ref() {
        return Ok(duplicate(doc));
    }

    let discovery_url = format!(
        "{}/.well-known/openid-configuration",
        self.config.issuer_url.trim_end_matches('/')
    );
    debug!("Fetching OIDC discovery from {}", discovery_url);

    let request = self.http_client.get(&discovery_url);
    let resp = match request.send().await {
        Ok(resp) => resp,
        Err(e) => {
            return Err(Error::Runtime {
                message: format!("Failed to fetch OIDC discovery document: {e}"),
            });
        }
    };

    let status = resp.status();
    if !status.is_success() {
        let body = resp.text().await.unwrap_or_default();
        return Err(Error::Runtime {
            message: format!("OIDC discovery failed with status {}: {}", status, body),
        });
    }

    let fetched: OidcDiscovery = match resp.json().await {
        Ok(doc) => doc,
        Err(e) => {
            return Err(Error::Runtime {
                message: format!("Failed to parse OIDC discovery document: {e}"),
            });
        }
    };

    let copy = duplicate(&fetched);
    *slot = Some(fetched);
    Ok(copy)
}
/// Derives the Azure v2.0 token endpoint from the issuer URL.
///
/// ```text
/// issuer: https://login.microsoftonline.com/{tenant}/v2.0
/// token:  https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token
/// ```
///
/// Trailing slashes on the issuer are ignored, matching how the discovery
/// URL is built.
fn get_token_endpoint_from_issuer(&self) -> String {
    let issuer = self.config.issuer_url.trim_end_matches('/');
    // Drop a single trailing `/v2.0` segment if present.
    let base = issuer.strip_suffix("/v2.0").unwrap_or(issuer);
    format!("{base}/oauth2/v2.0/token")
}
/// Resolves the token endpoint, preferring OIDC discovery.
///
/// When discovery fails, a warning is logged and the endpoint derived from
/// the issuer URL is used instead, so this never returns an error.
async fn get_token_endpoint(&self) -> Result<String> {
    let endpoint = self
        .get_discovery()
        .await
        .map(|disc| disc.token_endpoint)
        .unwrap_or_else(|e| {
            warn!("OIDC discovery failed, falling back to derived endpoint: {e}");
            self.get_token_endpoint_from_issuer()
        });
    Ok(endpoint)
}
/// Joins the configured scopes with single spaces, the form used in the
/// `scope` request parameter.
fn scopes_string(&self) -> String {
    let scopes = &self.config.scopes;
    scopes.join(" ")
}
// -- Client Credentials Flow --
async fn acquire_client_credentials(&self) -> Result<TokenResponse> {
let client_secret = self
.config
.client_secret
.as_ref()
.ok_or(Error::InvalidInput {
message: "client_secret is required for ClientCredentials flow".to_string(),
})?;
let token_endpoint = self.get_token_endpoint().await?;
let params = [
("grant_type", "client_credentials"),
("client_id", &self.config.client_id),
("client_secret", client_secret),
("scope", &self.scopes_string()),
];
self.post_token_request(&token_endpoint, &params).await
}
// -- Authorization Code + PKCE Flow --
async fn acquire_authorization_code_pkce(
&self,
redirect_uri: Option<&str>,
callback_port: u16,
) -> Result<TokenResponse> {
use rand::Rng;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
let discovery = self.get_discovery().await?;
let auth_endpoint = discovery.authorization_endpoint.ok_or(Error::Runtime {
message: "OIDC discovery did not provide authorization_endpoint".to_string(),
})?;
let default_redirect = format!("http://localhost:{callback_port}/callback");
let redirect = redirect_uri.unwrap_or(&default_redirect);
// Generate PKCE code verifier and challenge (S256)
const PKCE_CHARSET: &[u8] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
let code_verifier: String = {
let mut rng = rand::rng();
(0..128)
.map(|_| {
let idx = rng.random_range(0..PKCE_CHARSET.len());
PKCE_CHARSET[idx] as char
})
.collect()
};
let code_challenge = {
use sha2::{Digest, Sha256};
let hash = Sha256::digest(code_verifier.as_bytes());
base64_url_encode(&hash)
};
let state: String = {
let mut rng = rand::rng();
(0..32)
.map(|_| {
let idx = rng.random_range(0..16u8);
b"0123456789abcdef"[idx as usize] as char
})
.collect()
};
// Build authorization URL
let auth_url = format!(
"{auth_endpoint}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={state}",
urlencoding::encode(&self.config.client_id),
urlencoding::encode(redirect),
urlencoding::encode(&self.scopes_string()),
urlencoding::encode(&code_challenge),
);
info!("Opening browser for OAuth login...");
info!("If the browser doesn't open, visit: {auth_url}");
// Try to open browser
let _ = open::that(&auth_url);
// Start local callback server
let listener = TcpListener::bind(format!("127.0.0.1:{callback_port}"))
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to bind callback server on port {callback_port}: {e}"),
})?;
info!("Waiting for OAuth callback on port {callback_port}...");
let (mut stream, _) = listener.accept().await.map_err(|e| Error::Runtime {
message: format!("Failed to accept callback connection: {e}"),
})?;
// Read the HTTP request
let mut buf = vec![0u8; 4096];
let n = tokio::io::AsyncReadExt::read(&mut stream, &mut buf)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to read callback request: {e}"),
})?;
let request_str = String::from_utf8_lossy(&buf[..n]);
// Extract authorization code from query params
let code = extract_query_param(&request_str, "code").ok_or(Error::Runtime {
message: "No authorization code in callback".to_string(),
})?;
let returned_state = extract_query_param(&request_str, "state");
if returned_state.as_deref() != Some(&state) {
return Err(Error::Runtime {
message: "OAuth state mismatch — possible CSRF attack".to_string(),
});
}
// Send success response to browser
let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>";
let _ = stream.write_all(response.as_bytes()).await;
// Exchange code for tokens
let token_endpoint = self.get_token_endpoint().await?;
let mut params = vec![
("grant_type", "authorization_code"),
("client_id", self.config.client_id.as_str()),
("code", &code),
("redirect_uri", redirect),
("code_verifier", &code_verifier),
];
if let Some(ref secret) = self.config.client_secret {
params.push(("client_secret", secret));
}
self.post_token_request(&token_endpoint, &params).await
}
// -- Device Code Flow --
async fn acquire_device_code(&self) -> Result<TokenResponse> {
let discovery = self.get_discovery().await?;
let device_endpoint = discovery
.device_authorization_endpoint
.ok_or(Error::Runtime {
message: "OIDC discovery did not provide device_authorization_endpoint".to_string(),
})?;
let params = [
("client_id", self.config.client_id.as_str()),
("scope", &self.scopes_string()),
];
let resp = self
.http_client
.post(&device_endpoint)
.form(&params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Device code request failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Device code request failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
let device_resp: DeviceCodeResponse = resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse device code response: {e}"),
})?;
// Display instructions to user
info!(
"To sign in, visit {} and enter code: {}",
device_resp.verification_uri, device_resp.user_code
);
if let Some(ref uri) = device_resp.verification_uri_complete {
info!("Or visit: {uri}");
}
// Poll token endpoint
let token_endpoint = self.get_token_endpoint().await?;
let poll_interval = Duration::from_secs(device_resp.interval.unwrap_or(5));
let deadline = Instant::now() + Duration::from_secs(device_resp.expires_in);
loop {
if Instant::now() >= deadline {
return Err(Error::Runtime {
message: "Device code flow timed out waiting for user authentication"
.to_string(),
});
}
tokio::time::sleep(poll_interval).await;
let poll_params = [
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
("client_id", self.config.client_id.as_str()),
("device_code", &device_resp.device_code),
];
let poll_resp = self
.http_client
.post(&token_endpoint)
.form(&poll_params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Device code poll failed: {e}"),
})?;
if poll_resp.status().is_success() {
return poll_resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse token response: {e}"),
});
}
// Check for pending / slow_down errors
let body = poll_resp.text().await.unwrap_or_default();
if body.contains("authorization_pending") {
continue;
}
if body.contains("slow_down") {
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
return Err(Error::Runtime {
message: format!("Device code poll failed: {body}"),
});
}
}
// -- Azure Managed Identity Flow --
async fn acquire_managed_identity(&self, mi_client_id: Option<&str>) -> Result<TokenResponse> {
let resource = self.scopes_string().replace("/.default", "");
let mut url = format!(
"{AZURE_IMDS_ENDPOINT}?api-version={AZURE_IMDS_API_VERSION}&resource={}",
urlencoding::encode(&resource),
);
if let Some(cid) = mi_client_id {
url.push_str(&format!("&client_id={}", urlencoding::encode(cid)));
}
let resp = self
.http_client
.get(&url)
.header("Metadata", "true")
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Azure IMDS request failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Azure IMDS returned status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse IMDS token response: {e}"),
})
}
// -- Workload Identity Federation Flow --
async fn acquire_workload_identity(&self, token_file: &str) -> Result<TokenResponse> {
let federated_token =
tokio::fs::read_to_string(token_file)
.await
.map_err(|e| Error::Runtime {
message: format!("Failed to read federated token file '{token_file}': {e}"),
})?;
let token_endpoint = self.get_token_endpoint().await?;
let params = [
("grant_type", "client_credentials"),
("client_id", self.config.client_id.as_str()),
(
"client_assertion_type",
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
),
("client_assertion", federated_token.trim()),
("scope", &self.scopes_string()),
];
self.post_token_request(&token_endpoint, &params).await
}
// -- Refresh Token Flow --
/// Redeems `refresh_token` at the token endpoint using the `refresh_token` grant.
///
/// The form always carries `grant_type`, `client_id` and `refresh_token`;
/// `client_secret` is appended only when one is configured.
///
/// # Errors
///
/// Propagates any error from resolving the token endpoint or from the token
/// request itself.
async fn refresh_with_token(&self, refresh_token: &str) -> Result<TokenResponse> {
    let endpoint = self.get_token_endpoint().await?;
    let optional_secret = self
        .config
        .client_secret
        .as_deref()
        .map(|secret| ("client_secret", secret));
    let form: Vec<(&str, &str)> = [
        ("grant_type", "refresh_token"),
        ("client_id", self.config.client_id.as_str()),
        ("refresh_token", refresh_token),
    ]
    .into_iter()
    // An `Option` iterates over zero or one item, so the secret is only
    // present in the form when it is configured.
    .chain(optional_secret)
    .collect();
    self.post_token_request(&endpoint, &form).await
}
// -- Shared Helpers --
async fn post_token_request(
&self,
endpoint: &str,
params: &[(&str, &str)],
) -> Result<TokenResponse> {
let resp = self
.http_client
.post(endpoint)
.form(params)
.send()
.await
.map_err(|e| Error::Runtime {
message: format!("Token request to {endpoint} failed: {e}"),
})?;
if !resp.status().is_success() {
return Err(Error::Runtime {
message: format!(
"Token request failed with status {}: {}",
resp.status(),
resp.text().await.unwrap_or_default()
),
});
}
resp.json().await.map_err(|e| Error::Runtime {
message: format!("Failed to parse token response: {e}"),
})
}
}
#[async_trait::async_trait]
impl HeaderProvider for OAuthHeaderProvider {
    /// Builds the request headers for an authenticated call.
    ///
    /// The returned map has exactly one entry: `authorization`, whose value is
    /// `Bearer ` followed by the token obtained from `get_valid_token`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `get_valid_token`.
    async fn get_headers(&self) -> Result<HashMap<String, String>> {
        let bearer = self
            .get_valid_token()
            .await
            .map(|token| format!("Bearer {token}"))?;
        let mut headers = HashMap::with_capacity(1);
        headers.insert("authorization".to_string(), bearer);
        Ok(headers)
    }
}
// -- Utility functions --
/// Encodes `input` as base64 using the URL-safe alphabet without padding.
fn base64_url_encode(input: &[u8]) -> String {
    use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};

    URL_SAFE_NO_PAD.encode(input)
}
/// Extract a query parameter value from a raw HTTP GET request line.
fn extract_query_param(request: &str, param: &str) -> Option<String> {
let first_line = request.lines().next()?;
let path = first_line.split_whitespace().nth(1)?;
let query = path.split('?').nth(1)?;
for pair in query.split('&') {
let mut kv = pair.splitn(2, '=');
if let (Some(key), Some(value)) = (kv.next(), kv.next())
&& key == param
{
return Some(urlencoding::decode(value).ok()?.into_owned());
}
}
None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain (unencoded) values are returned verbatim; unknown keys yield `None`.
    #[test]
    fn test_extract_query_param() {
        let request = "GET /callback?code=abc123&state=xyz HTTP/1.1\r\nHost: localhost\r\n";
        assert_eq!(
            extract_query_param(request, "code"),
            Some("abc123".to_string())
        );
        assert_eq!(
            extract_query_param(request, "state"),
            Some("xyz".to_string())
        );
        assert_eq!(extract_query_param(request, "missing"), None);
    }

    /// Percent-encoded values are decoded, including an encoded `&` that must
    /// not be treated as a pair separator.
    #[test]
    fn test_extract_query_param_encoded() {
        let request = "GET /callback?code=abc%20123&state=x%26y HTTP/1.1\r\n";
        assert_eq!(
            extract_query_param(request, "code"),
            Some("abc 123".to_string())
        );
        assert_eq!(
            extract_query_param(request, "state"),
            Some("x&y".to_string())
        );
    }

    /// A fresh state counts as expired; a token expiring in 600s is not
    /// expired for a 300s argument but is for a 601s argument.
    #[test]
    fn test_token_state_expiry() {
        let mut state = TokenState::new();
        assert!(state.is_expired(Duration::from_secs(0)));
        state.access_token = Some("tok".to_string());
        state.expires_at = Some(Instant::now() + Duration::from_secs(600));
        assert!(!state.is_expired(Duration::from_secs(300)));
        assert!(state.is_expired(Duration::from_secs(601)));
    }

    /// Pins exact outputs so that both the missing padding and the URL-safe
    /// alphabet are actually exercised.
    #[test]
    fn test_base64_url_encode() {
        // Standard base64 would end this 11-byte input with `=` padding.
        assert_eq!(base64_url_encode(b"hello world"), "aGVsbG8gd29ybGQ");
        // These bytes map to alphabet indices 62, 63, 63, 62, i.e. `+//+` in
        // the standard alphabet and `-__-` in the URL-safe one.
        assert_eq!(base64_url_encode(&[0xfb, 0xff, 0xfe]), "-__-");
        assert_eq!(base64_url_encode(b""), "");
    }

    /// Multiple scopes are joined with a single space.
    #[test]
    fn test_scopes_string() {
        let config = OAuthConfig {
            issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
            client_id: "app-id".to_string(),
            client_secret: Some("secret".to_string()),
            scopes: vec!["scope1".to_string(), "scope2".to_string()],
            flow: OAuthFlow::ClientCredentials,
            refresh_buffer_secs: None,
        };
        let provider = OAuthHeaderProvider::new(config).unwrap();
        assert_eq!(provider.scopes_string(), "scope1 scope2");
    }

    /// The token endpoint derived from a Microsoft v2.0 issuer URL.
    #[test]
    fn test_token_endpoint_derivation() {
        let config = OAuthConfig {
            issuer_url: "https://login.microsoftonline.com/my-tenant/v2.0".to_string(),
            client_id: "id".to_string(),
            client_secret: None,
            scopes: vec!["api://test/.default".to_string()],
            flow: OAuthFlow::DeviceCode,
            refresh_buffer_secs: None,
        };
        let provider = OAuthHeaderProvider::new(config).unwrap();
        assert_eq!(
            provider.get_token_endpoint_from_issuer(),
            "https://login.microsoftonline.com/my-tenant/oauth2/v2.0/token"
        );
    }

    /// The client-credentials flow cannot be constructed without a secret.
    #[test]
    fn test_client_credentials_requires_secret() {
        let config = OAuthConfig {
            issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
            client_id: "app-id".to_string(),
            client_secret: None,
            scopes: vec!["scope".to_string()],
            flow: OAuthFlow::ClientCredentials,
            refresh_buffer_secs: None,
        };
        assert!(OAuthHeaderProvider::new(config).is_err());
    }

    /// An empty scope list is rejected at construction time.
    #[test]
    fn test_empty_scopes_rejected() {
        let config = OAuthConfig {
            issuer_url: "https://login.microsoftonline.com/tenant/v2.0".to_string(),
            client_id: "app-id".to_string(),
            client_secret: None,
            scopes: vec![],
            flow: OAuthFlow::DeviceCode,
            refresh_buffer_secs: None,
        };
        assert!(OAuthHeaderProvider::new(config).is_err());
    }
}