fix(rust): return typed errors instead of panicking in Bedrock embedding path (#3512)

Closes #3506

## Problem

The Bedrock embedding compute path
(`rust/lancedb/src/embeddings/bedrock.rs`) panics instead of returning a
typed error in several places:

- `serde_json::to_vec(&request_body).unwrap()`: request serialization.
- `block_in_place(...).unwrap()`: the AWS `invoke_model` send result;
any API error terminates the worker instead of propagating.
- `v.as_f64().unwrap() as f32`: panics on non-numeric values in the
returned embedding array.
- `Handle::current()` + `block_in_place` assume a multi-threaded Tokio
runtime and panic when that assumption does not hold (no runtime, or a
current-thread runtime).

Malformed payloads, non-numeric embedding values, or an incompatible
runtime should surface as typed errors and never panic.

## Fix

- Serialize the request body before the blocking section so a
serialization failure returns `Error::Runtime` via `?`.
- Map the `invoke_model` send error to `Error::Runtime` instead of
`unwrap`.
- Add a `json_array_to_f32` helper that converts the response array to
`Vec<f32>`, returning `Error::Runtime` for a missing/non-array field or
a non-numeric element (used by both the Titan and Cohere paths).
- Add `current_multi_thread_handle()` (`Handle::try_current()` + a
`RuntimeFlavor::CurrentThread` guard) so an absent or incompatible
runtime returns a typed error rather than panicking in `block_in_place`.

Scope note: the sibling `openai.rs` provider uses the same
`block_in_place` + `block_on` bridge, so the bridge pattern itself is
kept; this change only removes the panic paths that are specific to the
Bedrock provider.

## Testing

Added 6 unit tests (no AWS credentials required):

- `json_array_to_f32`: valid numbers, non-array payload, and non-numeric
element.
- `current_multi_thread_handle`: errors with no runtime, errors on a
current-thread runtime, and succeeds on a multi-threaded runtime.

All pass; `cargo fmt` and `cargo clippy` clean. Build/test with
`--features bedrock,lance/protoc`.
This commit is contained in:
Armaan Sandhu
2026-06-18 03:36:44 +05:30
committed by GitHub
parent 217fd8491d
commit 1f8ebef3cd

View File

@@ -13,7 +13,7 @@ use serde_json::{Value, json};
use super::EmbeddingFunction;
use crate::{Error, Result};
use tokio::runtime::Handle;
use tokio::runtime::{Handle, RuntimeFlavor};
use tokio::task::block_in_place;
#[derive(Debug)]
@@ -148,6 +148,12 @@ impl BedrockEmbeddingFunction {
_ => unreachable!(),
};
// Bedrock's SDK is async but this trait method is synchronous, so we
// bridge with `block_in_place` + `block_on`. That requires a
// multi-threaded Tokio runtime; return a typed error instead of
// panicking when no compatible runtime is available.
let handle = current_multi_thread_handle()?;
for text in texts {
let request_body = match self.model {
BedrockEmbeddingModel::TitanEmbedding => {
@@ -163,24 +169,28 @@ impl BedrockEmbeddingFunction {
}
};
// Serialize before entering the blocking section so a serialization
// failure surfaces as a typed error rather than an `unwrap` panic.
let body = serde_json::to_vec(&request_body).map_err(|e| Error::Runtime {
message: format!("Failed to serialize Bedrock request: {e}"),
})?;
let client = self.client.clone();
let model_id = self.model.model_id().to_string();
let request_body = request_body.clone();
let response = block_in_place(move || {
Handle::current().block_on(async move {
let response = block_in_place(|| {
handle.block_on(async move {
client
.invoke_model()
.model_id(model_id)
.body(aws_sdk_bedrockruntime::primitives::Blob::new(
serde_json::to_vec(&request_body).unwrap(),
))
.body(aws_sdk_bedrockruntime::primitives::Blob::new(body))
.send()
.await
.map_err(Box::new)
.map_err(|e| Error::Runtime {
message: format!("Bedrock invoke_model request failed: {e}"),
})
})
})
.unwrap();
})?;
let response_json: Value =
serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime {
@@ -188,22 +198,12 @@ impl BedrockEmbeddingFunction {
})?;
let embedding = match self.model {
BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"]
.as_array()
.ok_or_else(|| Error::Runtime {
message: "Missing embedding in response".to_string(),
})?
.iter()
.map(|v| v.as_f64().unwrap() as f32)
.collect::<Vec<f32>>(),
BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0]
.as_array()
.ok_or_else(|| Error::Runtime {
message: "Missing embeddings in response".to_string(),
})?
.iter()
.map(|v| v.as_f64().unwrap() as f32)
.collect::<Vec<f32>>(),
BedrockEmbeddingModel::TitanEmbedding => {
json_array_to_f32(&response_json["embedding"], "embedding")?
}
BedrockEmbeddingModel::CohereLarge => {
json_array_to_f32(&response_json["embeddings"][0], "embeddings")?
}
};
builder.append_slice(&embedding);
@@ -212,3 +212,86 @@ impl BedrockEmbeddingFunction {
Ok(builder.finish())
}
}
/// Returns a handle to the current multi-threaded Tokio runtime, or a typed
/// [`Error::Runtime`] when called outside a runtime or on the current-thread
/// runtime. This keeps the synchronous-over-async bridge in
/// [`BedrockEmbeddingFunction::compute_inner`] from panicking on runtime
/// configurations that cannot support `block_in_place`.
fn current_multi_thread_handle() -> Result<Handle> {
let handle = Handle::try_current().map_err(|e| Error::Runtime {
message: format!("Bedrock embedding must be called from within a Tokio runtime: {e}"),
})?;
if handle.runtime_flavor() == RuntimeFlavor::CurrentThread {
return Err(Error::Runtime {
message: "Bedrock embedding requires a multi-threaded Tokio runtime; the \
current-thread runtime cannot use `block_in_place`"
.to_string(),
});
}
Ok(handle)
}
/// Converts a JSON value expected to be an array of numbers into `Vec<f32>`.
///
/// Returns a typed [`Error::Runtime`] (rather than panicking) when the value is
/// not an array or contains a non-numeric element, so malformed provider
/// responses degrade gracefully.
fn json_array_to_f32(value: &Value, field: &str) -> Result<Vec<f32>> {
let arr = value.as_array().ok_or_else(|| Error::Runtime {
message: format!("Missing or non-array '{field}' field in Bedrock response"),
})?;
arr.iter()
.map(|v| {
v.as_f64().map(|f| f as f32).ok_or_else(|| Error::Runtime {
message: format!("Non-numeric value in Bedrock '{field}' embedding: {v}"),
})
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn json_array_to_f32_parses_numbers() {
let v = json!([1.0, 2, -3.5]);
let out = json_array_to_f32(&v, "embedding").unwrap();
assert_eq!(out, vec![1.0_f32, 2.0, -3.5]);
}
#[test]
fn json_array_to_f32_rejects_non_array() {
// Missing field indexes to `Value::Null`; a malformed payload should be
// a typed error, not a panic.
let v = json!({"unexpected": "shape"});
let err = json_array_to_f32(&v["embedding"], "embedding").unwrap_err();
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
}
#[test]
fn json_array_to_f32_rejects_non_numeric_element() {
let v = json!([1.0, "not-a-number", 3.0]);
let err = json_array_to_f32(&v, "embedding").unwrap_err();
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
}
#[test]
fn handle_errors_without_runtime() {
// No Tokio runtime in scope -> typed error instead of a panic.
let err = current_multi_thread_handle().unwrap_err();
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
}
#[tokio::test(flavor = "current_thread")]
async fn handle_errors_on_current_thread_runtime() {
let err = current_multi_thread_handle().unwrap_err();
assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
}
#[tokio::test(flavor = "multi_thread")]
async fn handle_ok_on_multi_thread_runtime() {
current_multi_thread_handle().expect("multi-threaded runtime should be accepted");
}
}