Compare commits

...

3 Commits

Author SHA1 Message Date
LanceDB Robot
c1c19cd133 chore: update lance dependency to v8.0.0-beta.17 (#3552)
Updates the Lance Rust workspace dependencies and Java lance-core
dependency to v8.0.0-beta.17.

No LanceDB compatibility code changes were required; validation passed
with cargo clippy and cargo fmt. Triggering Lance tag:
https://github.com/lance-format/lance/releases/tag/v8.0.0-beta.17
2026-06-17 16:08:09 -07:00
Will Jones
ce5dadd386 fix(ci): allow shell pre-commit hooks in bumpversion configs (#3554)
The "Create release commit" workflow (`make-release-commit.yml`) has
failed on its last two runs; no release tags have been created since
June 4. Since this workflow creates the tag that the cargo/npm/pypi/java
publish workflows trigger off of, all recent releases are effectively
blocked.

The workflow installs `bump-my-version` unpinned. Version `1.4.0` added
a check that refuses to run `pre_commit_hooks` containing shell syntax
(pipes, `&&`, `if`, variable expansion) unless `allow_shell_hooks =
true` is set. Both bumpversion configs use such hooks:

- `python/.bumpversion.toml` — updates `Cargo.lock` after the bump
(fails first)
- `.bumpversion.toml` — runs `mvn versions:set` for the Java packages

The job dies at the version-bump step with:

> Hook '…' contains shell syntax (pipes, redirects, or variable
expansion). Set `allow_shell_hooks = true` in your configuration to
enable shell execution…

This sets `allow_shell_hooks = true` in both configs to restore the
previous behavior.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-17 15:22:05 -07:00
Armaan Sandhu
1f8ebef3cd fix(rust): return typed errors instead of panicking in Bedrock embedding path (#3512)
Closes #3506

## Problem

The Bedrock embedding compute path
(`rust/lancedb/src/embeddings/bedrock.rs`) panics instead of returning a
typed error in several places:

- `serde_json::to_vec(&request_body).unwrap()`: request serialization.
- `block_in_place(...).unwrap()`: the AWS `invoke_model` send result;
any API error terminates the worker instead of propagating.
- `v.as_f64().unwrap() as f32`: panics on non-numeric values in the
returned embedding array.
- `Handle::current()` + `block_in_place` assume a multi-threaded Tokio
runtime and panic when that assumption does not hold (no runtime, or a
current-thread runtime).

Malformed payloads, non-numeric embedding values, or an incompatible
runtime should surface as typed errors and never panic.

## Fix

- Serialize the request body before the blocking section so a
serialization failure returns `Error::Runtime` via `?`.
- Map the `invoke_model` send error to `Error::Runtime` instead of
`unwrap`.
- Add a `json_array_to_f32` helper that converts the response array to
`Vec<f32>`, returning `Error::Runtime` for a missing/non-array field or
a non-numeric element (used by both the Titan and Cohere paths).
- Add `current_multi_thread_handle()` (`Handle::try_current()` + a
`RuntimeFlavor::CurrentThread` guard) so an absent or incompatible
runtime returns a typed error rather than panicking in `block_in_place`.

Scope note: the sibling `openai.rs` provider uses the same
`block_in_place` + `block_on` bridge, so the bridge pattern itself is
kept; this change only removes the panic paths that are specific to the
Bedrock provider.

## Testing

Added 6 unit tests (no AWS credentials required):

- `json_array_to_f32`: valid numbers, non-array payload, and non-numeric
element.
- `current_multi_thread_handle`: errors with no runtime, errors on a
current-thread runtime, and succeeds on a multi-threaded runtime.

All six tests pass, and `cargo fmt` and `cargo clippy` report no issues.
Build and test with `--features bedrock,lance/protoc`.
2026-06-17 15:06:44 -07:00
6 changed files with 171 additions and 83 deletions

View File

@@ -23,6 +23,8 @@ allow_dirty = true
commit = true
message = "Bump version: {current_version} → {new_version}"
commit_args = ""
# bump-my-version >=1.4.0 rejects pre_commit_hooks containing shell syntax unless opted in.
allow_shell_hooks = true
# Java maven files
pre_commit_hooks = [

85
Cargo.lock generated
View File

@@ -3432,8 +3432,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-array",
"rand 0.9.4",
@@ -4735,8 +4735,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
[[package]]
name = "lance"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arc-swap",
"arrow",
@@ -4810,8 +4810,8 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4832,7 +4832,7 @@ dependencies = [
[[package]]
name = "lance-arrow-scalar"
version = "58.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4846,7 +4846,7 @@ dependencies = [
[[package]]
name = "lance-arrow-stats"
version = "58.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -4855,8 +4855,8 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrayref",
"paste",
@@ -4865,8 +4865,8 @@ dependencies = [
[[package]]
name = "lance-core"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4904,8 +4904,8 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow",
"arrow-array",
@@ -4935,8 +4935,8 @@ dependencies = [
[[package]]
name = "lance-datagen"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow",
"arrow-array",
@@ -4953,8 +4953,8 @@ dependencies = [
[[package]]
name = "lance-derive"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"proc-macro2",
"quote",
@@ -4963,8 +4963,8 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4999,8 +4999,8 @@ dependencies = [
[[package]]
name = "lance-file"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -5030,8 +5030,8 @@ dependencies = [
[[package]]
name = "lance-index"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arc-swap",
"arrow",
@@ -5083,6 +5083,7 @@ dependencies = [
"rand_distr 0.5.1",
"rangemap",
"rayon",
"regex-syntax",
"roaring",
"serde",
"serde_json",
@@ -5095,8 +5096,8 @@ dependencies = [
[[package]]
name = "lance-io"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow",
"arrow-arith",
@@ -5137,8 +5138,8 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5153,8 +5154,8 @@ dependencies = [
[[package]]
name = "lance-namespace"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow",
"async-trait",
@@ -5166,8 +5167,8 @@ dependencies = [
[[package]]
name = "lance-namespace-impls"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow",
"arrow-ipc",
@@ -5207,9 +5208,9 @@ dependencies = [
[[package]]
name = "lance-namespace-reqwest-client"
version = "0.8.5"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d287494559c22838ce34e51ea0fa29dc780d5be8283de5ab33e9395623000c8"
checksum = "ba3f0a235e3ed5f8805205649ccc7d7d0f3df23ce1294242c9265ad488d7f19d"
dependencies = [
"reqwest 0.12.28",
"serde",
@@ -5221,8 +5222,8 @@ dependencies = [
[[package]]
name = "lance-select"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5237,8 +5238,8 @@ dependencies = [
[[package]]
name = "lance-table"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow",
"arrow-array",
@@ -5277,8 +5278,8 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -5291,8 +5292,8 @@ dependencies = [
[[package]]
name = "lance-tokenizer"
version = "8.0.0-beta.14"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.14#c188de59fcf0976a0a9fef53ae67ae7ae8bcb61a"
version = "8.0.0-beta.17"
source = "git+https://github.com/lance-format/lance.git?tag=v8.0.0-beta.17#0f2745d10a0fe5b34a1cf214466bbc0c0d13c90c"
dependencies = [
"icu_segmenter",
"jieba-rs",

View File

@@ -13,20 +13,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=8.0.0-beta.14", default-features = false, "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=8.0.0-beta.14", default-features = false, "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=8.0.0-beta.14", default-features = false, "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=8.0.0-beta.14", "tag" = "v8.0.0-beta.14", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=8.0.0-beta.17", default-features = false, "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=8.0.0-beta.17", default-features = false, "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=8.0.0-beta.17", default-features = false, "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=8.0.0-beta.17", "tag" = "v8.0.0-beta.17", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "58.0.0", optional = false }

View File

@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>8.0.0-beta.14</lance-core.version>
<lance-core.version>8.0.0-beta.17</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -23,6 +23,8 @@ allow_dirty = true
commit = true
message = "Bump version: {current_version} → {new_version}"
commit_args = ""
# bump-my-version >=1.4.0 rejects pre_commit_hooks containing shell syntax unless opted in.
allow_shell_hooks = true
# Update Cargo.lock after version bump
pre_commit_hooks = [

View File

@@ -13,7 +13,7 @@ use serde_json::{Value, json};
use super::EmbeddingFunction;
use crate::{Error, Result};
use tokio::runtime::Handle;
use tokio::runtime::{Handle, RuntimeFlavor};
use tokio::task::block_in_place;
#[derive(Debug)]
@@ -148,6 +148,12 @@ impl BedrockEmbeddingFunction {
_ => unreachable!(),
};
// Bedrock's SDK is async but this trait method is synchronous, so we
// bridge with `block_in_place` + `block_on`. That requires a
// multi-threaded Tokio runtime; return a typed error instead of
// panicking when no compatible runtime is available.
let handle = current_multi_thread_handle()?;
for text in texts {
let request_body = match self.model {
BedrockEmbeddingModel::TitanEmbedding => {
@@ -163,24 +169,28 @@ impl BedrockEmbeddingFunction {
}
};
// Serialize before entering the blocking section so a serialization
// failure surfaces as a typed error rather than an `unwrap` panic.
let body = serde_json::to_vec(&request_body).map_err(|e| Error::Runtime {
message: format!("Failed to serialize Bedrock request: {e}"),
})?;
let client = self.client.clone();
let model_id = self.model.model_id().to_string();
let request_body = request_body.clone();
let response = block_in_place(move || {
Handle::current().block_on(async move {
let response = block_in_place(|| {
handle.block_on(async move {
client
.invoke_model()
.model_id(model_id)
.body(aws_sdk_bedrockruntime::primitives::Blob::new(
serde_json::to_vec(&request_body).unwrap(),
))
.body(aws_sdk_bedrockruntime::primitives::Blob::new(body))
.send()
.await
.map_err(Box::new)
.map_err(|e| Error::Runtime {
message: format!("Bedrock invoke_model request failed: {e}"),
})
})
})
.unwrap();
})?;
let response_json: Value =
serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime {
@@ -188,22 +198,12 @@ impl BedrockEmbeddingFunction {
})?;
let embedding = match self.model {
BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"]
.as_array()
.ok_or_else(|| Error::Runtime {
message: "Missing embedding in response".to_string(),
})?
.iter()
.map(|v| v.as_f64().unwrap() as f32)
.collect::<Vec<f32>>(),
BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0]
.as_array()
.ok_or_else(|| Error::Runtime {
message: "Missing embeddings in response".to_string(),
})?
.iter()
.map(|v| v.as_f64().unwrap() as f32)
.collect::<Vec<f32>>(),
BedrockEmbeddingModel::TitanEmbedding => {
json_array_to_f32(&response_json["embedding"], "embedding")?
}
BedrockEmbeddingModel::CohereLarge => {
json_array_to_f32(&response_json["embeddings"][0], "embeddings")?
}
};
builder.append_slice(&embedding);
@@ -212,3 +212,86 @@ impl BedrockEmbeddingFunction {
Ok(builder.finish())
}
}
/// Returns a handle to the current multi-threaded Tokio runtime, or a typed
/// [`Error::Runtime`] when called outside a runtime or on the current-thread
/// runtime. This keeps the synchronous-over-async bridge in
/// [`BedrockEmbeddingFunction::compute_inner`] from panicking on runtime
/// configurations that cannot support `block_in_place`.
fn current_multi_thread_handle() -> Result<Handle> {
let handle = Handle::try_current().map_err(|e| Error::Runtime {
message: format!("Bedrock embedding must be called from within a Tokio runtime: {e}"),
})?;
if handle.runtime_flavor() == RuntimeFlavor::CurrentThread {
return Err(Error::Runtime {
message: "Bedrock embedding requires a multi-threaded Tokio runtime; the \
current-thread runtime cannot use `block_in_place`"
.to_string(),
});
}
Ok(handle)
}
/// Converts a JSON value expected to be an array of numbers into `Vec<f32>`.
///
/// Returns a typed [`Error::Runtime`] (rather than panicking) when the value is
/// not an array or contains a non-numeric element, so malformed provider
/// responses degrade gracefully.
fn json_array_to_f32(value: &Value, field: &str) -> Result<Vec<f32>> {
let arr = value.as_array().ok_or_else(|| Error::Runtime {
message: format!("Missing or non-array '{field}' field in Bedrock response"),
})?;
arr.iter()
.map(|v| {
v.as_f64().map(|f| f as f32).ok_or_else(|| Error::Runtime {
message: format!("Non-numeric value in Bedrock '{field}' embedding: {v}"),
})
})
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test unless `err` is the `Error::Runtime` variant, printing
    /// the actual error on failure.
    fn assert_runtime_error(err: Error) {
        assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
    }

    #[test]
    fn json_array_to_f32_parses_numbers() {
        // Mixes float and integer literals; both must come back as f32.
        let payload = json!([1.0, 2, -3.5]);
        let parsed = json_array_to_f32(&payload, "embedding").unwrap();
        assert_eq!(parsed, vec![1.0_f32, 2.0, -3.5]);
    }

    #[test]
    fn json_array_to_f32_rejects_non_array() {
        // Indexing a missing key yields `Value::Null`, which is not an array;
        // that must be reported as an error value rather than a panic.
        let payload = json!({"unexpected": "shape"});
        let missing = &payload["embedding"];
        assert_runtime_error(json_array_to_f32(missing, "embedding").unwrap_err());
    }

    #[test]
    fn json_array_to_f32_rejects_non_numeric_element() {
        let payload = json!([1.0, "not-a-number", 3.0]);
        assert_runtime_error(json_array_to_f32(&payload, "embedding").unwrap_err());
    }

    #[test]
    fn handle_errors_without_runtime() {
        // Plain `#[test]`: no Tokio runtime is in scope here.
        assert_runtime_error(current_multi_thread_handle().unwrap_err());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handle_errors_on_current_thread_runtime() {
        // A runtime exists, but it is the flavor the helper must refuse.
        assert_runtime_error(current_multi_thread_handle().unwrap_err());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn handle_ok_on_multi_thread_runtime() {
        current_multi_thread_handle().expect("multi-threaded runtime should be accepted");
    }
}