From 1f8ebef3cd3f63b5cc3fcc8e8aaf23c8e8793618 Mon Sep 17 00:00:00 2001 From: Armaan Sandhu <74664101+Ar-maan05@users.noreply.github.com> Date: Thu, 18 Jun 2026 03:36:44 +0530 Subject: [PATCH] fix(rust): return typed errors instead of panicking in Bedrock embedding path (#3512) Closes #3506 ## Problem The Bedrock embedding compute path (`rust/lancedb/src/embeddings/bedrock.rs`) panics instead of returning a typed error in several places: - `serde_json::to_vec(&request_body).unwrap()`: request serialization. - `block_in_place(...).unwrap()`: the AWS `invoke_model` send result; any API error terminates the worker instead of propagating. - `v.as_f64().unwrap() as f32`: panics on non-numeric values in the returned embedding array. - `Handle::current()` + `block_in_place` assume a multi-threaded Tokio runtime and panic when that assumption does not hold (no runtime, or a current-thread runtime). Malformed payloads, non-numeric embedding values, or an incompatible runtime should surface as typed errors and never panic. ## Fix - Serialize the request body before the blocking section so a serialization failure returns `Error::Runtime` via `?`. - Map the `invoke_model` send error to `Error::Runtime` instead of `unwrap`. - Add a `json_array_to_f32` helper that converts the response array to `Vec<f32>`, returning `Error::Runtime` for a missing/non-array field or a non-numeric element (used by both the Titan and Cohere paths). - Add `current_multi_thread_handle()` (`Handle::try_current()` + a `RuntimeFlavor::CurrentThread` guard) so an absent or incompatible runtime returns a typed error rather than panicking in `block_in_place`. Scope note: the sibling `openai.rs` provider uses the same `block_in_place` + `block_on` bridge, so the bridge pattern itself is kept; this change only removes the panic paths that are specific to the Bedrock provider. 
## Testing Added 6 unit tests (no AWS credentials required): - `json_array_to_f32`: valid numbers, non-array payload, and non-numeric element. - `current_multi_thread_handle`: errors with no runtime, errors on a current-thread runtime, and succeeds on a multi-threaded runtime. All pass; `cargo fmt` and `cargo clippy` clean. Build/test with `--features bedrock,lance/protoc`. --- rust/lancedb/src/embeddings/bedrock.rs | 135 ++++++++++++++++++++----- 1 file changed, 109 insertions(+), 26 deletions(-) diff --git a/rust/lancedb/src/embeddings/bedrock.rs b/rust/lancedb/src/embeddings/bedrock.rs index ca80de9e7..0fe832aa1 100644 --- a/rust/lancedb/src/embeddings/bedrock.rs +++ b/rust/lancedb/src/embeddings/bedrock.rs @@ -13,7 +13,7 @@ use serde_json::{Value, json}; use super::EmbeddingFunction; use crate::{Error, Result}; -use tokio::runtime::Handle; +use tokio::runtime::{Handle, RuntimeFlavor}; use tokio::task::block_in_place; #[derive(Debug)] @@ -148,6 +148,12 @@ impl BedrockEmbeddingFunction { _ => unreachable!(), }; + // Bedrock's SDK is async but this trait method is synchronous, so we + // bridge with `block_in_place` + `block_on`. That requires a + // multi-threaded Tokio runtime; return a typed error instead of + // panicking when no compatible runtime is available. + let handle = current_multi_thread_handle()?; + for text in texts { let request_body = match self.model { BedrockEmbeddingModel::TitanEmbedding => { @@ -163,24 +169,28 @@ impl BedrockEmbeddingFunction { } }; + // Serialize before entering the blocking section so a serialization + // failure surfaces as a typed error rather than an `unwrap` panic. 
+ let body = serde_json::to_vec(&request_body).map_err(|e| Error::Runtime { + message: format!("Failed to serialize Bedrock request: {e}"), + })?; + let client = self.client.clone(); let model_id = self.model.model_id().to_string(); - let request_body = request_body.clone(); - let response = block_in_place(move || { - Handle::current().block_on(async move { + let response = block_in_place(|| { + handle.block_on(async move { client .invoke_model() .model_id(model_id) - .body(aws_sdk_bedrockruntime::primitives::Blob::new( - serde_json::to_vec(&request_body).unwrap(), - )) + .body(aws_sdk_bedrockruntime::primitives::Blob::new(body)) .send() .await - .map_err(Box::new) + .map_err(|e| Error::Runtime { + message: format!("Bedrock invoke_model request failed: {e}"), + }) }) - }) - .unwrap(); + })?; let response_json: Value = serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime { @@ -188,22 +198,12 @@ impl BedrockEmbeddingFunction { })?; let embedding = match self.model { - BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"] - .as_array() - .ok_or_else(|| Error::Runtime { - message: "Missing embedding in response".to_string(), - })? - .iter() - .map(|v| v.as_f64().unwrap() as f32) - .collect::<Vec<f32>>(), - BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0] - .as_array() - .ok_or_else(|| Error::Runtime { - message: "Missing embeddings in response".to_string(), - })? - .iter() - .map(|v| v.as_f64().unwrap() as f32) - .collect::<Vec<f32>>(), + BedrockEmbeddingModel::TitanEmbedding => { + json_array_to_f32(&response_json["embedding"], "embedding")? + } + BedrockEmbeddingModel::CohereLarge => { + json_array_to_f32(&response_json["embeddings"][0], "embeddings")? 
+ } }; builder.append_slice(&embedding); @@ -212,3 +212,86 @@ impl BedrockEmbeddingFunction { Ok(builder.finish()) } } + +/// Returns a handle to the current multi-threaded Tokio runtime, or a typed +/// [`Error::Runtime`] when called outside a runtime or on the current-thread +/// runtime. This keeps the synchronous-over-async bridge in +/// [`BedrockEmbeddingFunction::compute_inner`] from panicking on runtime +/// configurations that cannot support `block_in_place`. +fn current_multi_thread_handle() -> Result<Handle> { + let handle = Handle::try_current().map_err(|e| Error::Runtime { + message: format!("Bedrock embedding must be called from within a Tokio runtime: {e}"), + })?; + if handle.runtime_flavor() == RuntimeFlavor::CurrentThread { + return Err(Error::Runtime { + message: "Bedrock embedding requires a multi-threaded Tokio runtime; the \ + current-thread runtime cannot use `block_in_place`" + .to_string(), + }); + } + Ok(handle) +} + +/// Converts a JSON value expected to be an array of numbers into `Vec<f32>`. +/// +/// Returns a typed [`Error::Runtime`] (rather than panicking) when the value is +/// not an array or contains a non-numeric element, so malformed provider +/// responses degrade gracefully. 
+fn json_array_to_f32(value: &Value, field: &str) -> Result<Vec<f32>> { + let arr = value.as_array().ok_or_else(|| Error::Runtime { + message: format!("Missing or non-array '{field}' field in Bedrock response"), + })?; + arr.iter() + .map(|v| { + v.as_f64().map(|f| f as f32).ok_or_else(|| Error::Runtime { + message: format!("Non-numeric value in Bedrock '{field}' embedding: {v}"), + }) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn json_array_to_f32_parses_numbers() { + let v = json!([1.0, 2, -3.5]); + let out = json_array_to_f32(&v, "embedding").unwrap(); + assert_eq!(out, vec![1.0_f32, 2.0, -3.5]); + } + + #[test] + fn json_array_to_f32_rejects_non_array() { + // Missing field indexes to `Value::Null`; a malformed payload should be + // a typed error, not a panic. + let v = json!({"unexpected": "shape"}); + let err = json_array_to_f32(&v["embedding"], "embedding").unwrap_err(); + assert!(matches!(err, Error::Runtime { .. }), "got {err:?}"); + } + + #[test] + fn json_array_to_f32_rejects_non_numeric_element() { + let v = json!([1.0, "not-a-number", 3.0]); + let err = json_array_to_f32(&v, "embedding").unwrap_err(); + assert!(matches!(err, Error::Runtime { .. }), "got {err:?}"); + } + + #[test] + fn handle_errors_without_runtime() { + // No Tokio runtime in scope -> typed error instead of a panic. + let err = current_multi_thread_handle().unwrap_err(); + assert!(matches!(err, Error::Runtime { .. }), "got {err:?}"); + } + + #[tokio::test(flavor = "current_thread")] + async fn handle_errors_on_current_thread_runtime() { + let err = current_multi_thread_handle().unwrap_err(); + assert!(matches!(err, Error::Runtime { .. }), "got {err:?}"); + } + + #[tokio::test(flavor = "multi_thread")] + async fn handle_ok_on_multi_thread_runtime() { + current_multi_thread_handle().expect("multi-threaded runtime should be accepted"); + } +}