mirror of
https://github.com/lancedb/lancedb.git
synced 2026-07-03 19:10:41 +00:00
fix(rust): return typed errors instead of panicking in Bedrock embedding path (#3512)
Closes #3506

## Problem

The Bedrock embedding compute path (`rust/lancedb/src/embeddings/bedrock.rs`) panics instead of returning a typed error in several places:

- `serde_json::to_vec(&request_body).unwrap()`: request serialization.
- `block_in_place(...).unwrap()`: the AWS `invoke_model` send result; any API error terminates the worker instead of propagating.
- `v.as_f64().unwrap() as f32`: panics on non-numeric values in the returned embedding array.
- `Handle::current()` + `block_in_place` assume a multi-threaded Tokio runtime and panic when that assumption does not hold (no runtime, or a current-thread runtime).

Malformed payloads, non-numeric embedding values, or an incompatible runtime should surface as typed errors and never panic.

## Fix

- Serialize the request body before the blocking section so a serialization failure returns `Error::Runtime` via `?`.
- Map the `invoke_model` send error to `Error::Runtime` instead of `unwrap`.
- Add a `json_array_to_f32` helper that converts the response array to `Vec<f32>`, returning `Error::Runtime` for a missing/non-array field or a non-numeric element (used by both the Titan and Cohere paths).
- Add `current_multi_thread_handle()` (`Handle::try_current()` + a `RuntimeFlavor::CurrentThread` guard) so an absent or incompatible runtime returns a typed error rather than panicking in `block_in_place`.

Scope note: the sibling `openai.rs` provider uses the same `block_in_place` + `block_on` bridge, so the bridge pattern itself is kept; this change only removes the panic paths that are specific to the Bedrock provider.

## Testing

Added 6 unit tests (no AWS credentials required):

- `json_array_to_f32`: valid numbers, non-array payload, and non-numeric element.
- `current_multi_thread_handle`: errors with no runtime, errors on a current-thread runtime, and succeeds on a multi-threaded runtime.

All pass; `cargo fmt` and `cargo clippy` clean. Build/test with `--features bedrock,lance/protoc`.
This commit is contained in:
@@ -13,7 +13,7 @@ use serde_json::{Value, json};
|
||||
use super::EmbeddingFunction;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::runtime::{Handle, RuntimeFlavor};
|
||||
use tokio::task::block_in_place;
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -148,6 +148,12 @@ impl BedrockEmbeddingFunction {
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
// Bedrock's SDK is async but this trait method is synchronous, so we
|
||||
// bridge with `block_in_place` + `block_on`. That requires a
|
||||
// multi-threaded Tokio runtime; return a typed error instead of
|
||||
// panicking when no compatible runtime is available.
|
||||
let handle = current_multi_thread_handle()?;
|
||||
|
||||
for text in texts {
|
||||
let request_body = match self.model {
|
||||
BedrockEmbeddingModel::TitanEmbedding => {
|
||||
@@ -163,24 +169,28 @@ impl BedrockEmbeddingFunction {
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize before entering the blocking section so a serialization
|
||||
// failure surfaces as a typed error rather than an `unwrap` panic.
|
||||
let body = serde_json::to_vec(&request_body).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to serialize Bedrock request: {e}"),
|
||||
})?;
|
||||
|
||||
let client = self.client.clone();
|
||||
let model_id = self.model.model_id().to_string();
|
||||
let request_body = request_body.clone();
|
||||
|
||||
let response = block_in_place(move || {
|
||||
Handle::current().block_on(async move {
|
||||
let response = block_in_place(|| {
|
||||
handle.block_on(async move {
|
||||
client
|
||||
.invoke_model()
|
||||
.model_id(model_id)
|
||||
.body(aws_sdk_bedrockruntime::primitives::Blob::new(
|
||||
serde_json::to_vec(&request_body).unwrap(),
|
||||
))
|
||||
.body(aws_sdk_bedrockruntime::primitives::Blob::new(body))
|
||||
.send()
|
||||
.await
|
||||
.map_err(Box::new)
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("Bedrock invoke_model request failed: {e}"),
|
||||
})
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
})?;
|
||||
|
||||
let response_json: Value =
|
||||
serde_json::from_slice(response.body.as_ref()).map_err(|e| Error::Runtime {
|
||||
@@ -188,22 +198,12 @@ impl BedrockEmbeddingFunction {
|
||||
})?;
|
||||
|
||||
let embedding = match self.model {
|
||||
BedrockEmbeddingModel::TitanEmbedding => response_json["embedding"]
|
||||
.as_array()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Missing embedding in response".to_string(),
|
||||
})?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap() as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
BedrockEmbeddingModel::CohereLarge => response_json["embeddings"][0]
|
||||
.as_array()
|
||||
.ok_or_else(|| Error::Runtime {
|
||||
message: "Missing embeddings in response".to_string(),
|
||||
})?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap() as f32)
|
||||
.collect::<Vec<f32>>(),
|
||||
BedrockEmbeddingModel::TitanEmbedding => {
|
||||
json_array_to_f32(&response_json["embedding"], "embedding")?
|
||||
}
|
||||
BedrockEmbeddingModel::CohereLarge => {
|
||||
json_array_to_f32(&response_json["embeddings"][0], "embeddings")?
|
||||
}
|
||||
};
|
||||
|
||||
builder.append_slice(&embedding);
|
||||
@@ -212,3 +212,86 @@ impl BedrockEmbeddingFunction {
|
||||
Ok(builder.finish())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a handle to the current multi-threaded Tokio runtime, or a typed
|
||||
/// [`Error::Runtime`] when called outside a runtime or on the current-thread
|
||||
/// runtime. This keeps the synchronous-over-async bridge in
|
||||
/// [`BedrockEmbeddingFunction::compute_inner`] from panicking on runtime
|
||||
/// configurations that cannot support `block_in_place`.
|
||||
fn current_multi_thread_handle() -> Result<Handle> {
|
||||
let handle = Handle::try_current().map_err(|e| Error::Runtime {
|
||||
message: format!("Bedrock embedding must be called from within a Tokio runtime: {e}"),
|
||||
})?;
|
||||
if handle.runtime_flavor() == RuntimeFlavor::CurrentThread {
|
||||
return Err(Error::Runtime {
|
||||
message: "Bedrock embedding requires a multi-threaded Tokio runtime; the \
|
||||
current-thread runtime cannot use `block_in_place`"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
/// Converts a JSON value expected to be an array of numbers into `Vec<f32>`.
|
||||
///
|
||||
/// Returns a typed [`Error::Runtime`] (rather than panicking) when the value is
|
||||
/// not an array or contains a non-numeric element, so malformed provider
|
||||
/// responses degrade gracefully.
|
||||
fn json_array_to_f32(value: &Value, field: &str) -> Result<Vec<f32>> {
|
||||
let arr = value.as_array().ok_or_else(|| Error::Runtime {
|
||||
message: format!("Missing or non-array '{field}' field in Bedrock response"),
|
||||
})?;
|
||||
arr.iter()
|
||||
.map(|v| {
|
||||
v.as_f64().map(|f| f as f32).ok_or_else(|| Error::Runtime {
|
||||
message: format!("Non-numeric value in Bedrock '{field}' embedding: {v}"),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `err` is the `Error::Runtime` variant.
    fn assert_runtime_error(err: &Error) {
        assert!(matches!(err, Error::Runtime { .. }), "got {err:?}");
    }

    #[test]
    fn json_array_to_f32_parses_numbers() {
        // Mixes float and integer literals; all must come back as `f32`.
        let parsed = json_array_to_f32(&json!([1.0, 2, -3.5]), "embedding").unwrap();
        assert_eq!(parsed, vec![1.0_f32, 2.0, -3.5]);
    }

    #[test]
    fn json_array_to_f32_rejects_non_array() {
        // Indexing a key that is absent yields `Value::Null`, which must be
        // reported as a typed error rather than a panic.
        let payload = json!({"unexpected": "shape"});
        let failure = json_array_to_f32(&payload["embedding"], "embedding").unwrap_err();
        assert_runtime_error(&failure);
    }

    #[test]
    fn json_array_to_f32_rejects_non_numeric_element() {
        let payload = json!([1.0, "not-a-number", 3.0]);
        let failure = json_array_to_f32(&payload, "embedding").unwrap_err();
        assert_runtime_error(&failure);
    }

    #[test]
    fn handle_errors_without_runtime() {
        // A plain `#[test]` has no Tokio runtime, so the lookup must fail
        // with a typed error instead of panicking.
        let failure = current_multi_thread_handle().unwrap_err();
        assert_runtime_error(&failure);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handle_errors_on_current_thread_runtime() {
        let failure = current_multi_thread_handle().unwrap_err();
        assert_runtime_error(&failure);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn handle_ok_on_multi_thread_runtime() {
        let _handle =
            current_multi_thread_handle().expect("multi-threaded runtime should be accepted");
    }
}
|
||||
|
||||
Reference in New Issue
Block a user