feat: parallelize embedding computations (#2896)

Implement parallel execution of multiple embedding functions using
std::thread::scope to improve performance when a table has multiple
embedding columns.

Key changes:
- Add compute_embeddings_parallel() helper method to WithEmbeddings
- Use fast path for single embeddings (no threading overhead)
- Use scoped threads for parallel execution of multiple embeddings
- Add comprehensive tests including parallelization timing verification
- Update WithEmbeddings documentation

Performance improvements:
- I/O-bound embeddings (OpenAI, Bedrock): High benefit from concurrent
API calls
- CPU-bound embeddings (sentence-transformers): Medium benefit from core
utilization
- Single embedding: No overhead (fast path)

Closes TODO on line 266 in rust/lancedb/src/embeddings.rs
This commit is contained in:
Qichao Chu
2026-01-06 14:35:56 -08:00
committed by GitHub
parent d67a8743ba
commit 4494eb9e56
2 changed files with 316 additions and 16 deletions

View File

@@ -120,8 +120,13 @@ impl MemoryRegistry {
}
/// A record batch reader that has embeddings applied to it
/// This is a wrapper around another record batch reader that applies an embedding function
/// when reading from the record batch
///
/// This is a wrapper around another record batch reader that applies embedding functions
/// when reading from the record batch.
///
/// When multiple embedding functions are defined, they are computed in parallel using
/// scoped threads to improve performance. For a single embedding function, computation
/// is done inline without threading overhead.
pub struct WithEmbeddings<R: RecordBatchReader> {
inner: R,
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
@@ -235,6 +240,48 @@ impl<R: RecordBatchReader> WithEmbeddings<R> {
column_definitions,
})
}
fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result<Vec<Arc<dyn Array>>> {
if self.embeddings.len() == 1 {
let (fld, func) = &self.embeddings[0];
let src_column =
batch
.column_by_name(&fld.source_column)
.ok_or_else(|| Error::InvalidInput {
message: format!("Source column '{}' not found", fld.source_column),
})?;
return Ok(vec![func.compute_source_embeddings(src_column.clone())?]);
}
// Parallel path: multiple embeddings
std::thread::scope(|s| {
let handles: Vec<_> = self
.embeddings
.iter()
.map(|(fld, func)| {
let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| {
Error::InvalidInput {
message: format!("Source column '{}' not found", fld.source_column),
}
})?;
let handle =
s.spawn(move || func.compute_source_embeddings(src_column.clone()));
Ok(handle)
})
.collect::<Result<_>>()?;
handles
.into_iter()
.map(|h| {
h.join().map_err(|e| Error::Runtime {
message: format!("Thread panicked during embedding computation: {:?}", e),
})?
})
.collect()
})
}
}
impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
@@ -262,19 +309,19 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
fn next(&mut self) -> Option<Self::Item> {
let batch = self.inner.next()?;
match batch {
Ok(mut batch) => {
// todo: parallelize this
for (fld, func) in self.embeddings.iter() {
let src_column = batch.column_by_name(&fld.source_column).unwrap();
let embedding = match func.compute_source_embeddings(src_column.clone()) {
Ok(embedding) => embedding,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
"Error computing embedding: {}",
e
))))
}
};
Ok(batch) => {
let embeddings = match self.compute_embeddings_parallel(&batch) {
Ok(emb) => emb,
Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
"Error computing embedding: {}",
e
))))
}
};
let mut batch = batch;
for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) {
let dst_field_name = fld
.dest_column
.clone()
@@ -286,7 +333,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
embedding.nulls().is_some(),
);
match batch.try_with_column(dst_field.clone(), embedding) {
match batch.try_with_column(dst_field.clone(), embedding.clone()) {
Ok(b) => batch = b,
Err(e) => return Some(Err(e)),
};

View File

@@ -0,0 +1,253 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use std::{
borrow::Cow,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use arrow::buffer::NullBuffer;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use lancedb::{
embeddings::{EmbeddingDefinition, EmbeddingFunction, MaybeEmbedded, WithEmbeddings},
Error, Result,
};
/// Mock embedding function for tests that sleeps before producing output and
/// counts how many times it has been asked to compute source embeddings.
#[derive(Debug)]
struct SlowMockEmbed {
    /// Name reported through `EmbeddingFunction::name`.
    name: String,
    /// Number of float values emitted per input row.
    dim: usize,
    /// Artificial delay, in milliseconds, applied to every source computation.
    delay_ms: u64,
    /// Count of completed `compute_source_embeddings` calls.
    call_count: Arc<AtomicUsize>,
}

impl SlowMockEmbed {
    /// Builds a mock with the given name, vector width and delay; the call
    /// counter starts at zero.
    pub fn new(name: String, dim: usize, delay_ms: u64) -> Self {
        let call_count = Arc::default();
        Self {
            name,
            dim,
            delay_ms,
            call_count,
        }
    }

    /// Returns how many source-embedding computations have run so far.
    pub fn get_call_count(&self) -> usize {
        let counter: &AtomicUsize = &self.call_count;
        counter.load(Ordering::SeqCst)
    }
}
impl EmbeddingFunction for SlowMockEmbed {
    /// Name this mock was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// The mock declares UTF-8 strings as its input type.
    fn source_type(&self) -> Result<Cow<'_, DataType>> {
        Ok(Cow::Owned(DataType::Utf8))
    }

    /// Declares a fixed-size list of `Float32` with `dim` items per row.
    fn dest_type(&self) -> Result<Cow<'_, DataType>> {
        let list_type = DataType::new_fixed_size_list(DataType::Float32, self.dim as _, true);
        Ok(Cow::Owned(list_type))
    }

    /// Sleeps for `delay_ms`, bumps the call counter, then returns a
    /// fixed-size list array with one all-`1.0` vector per input row.
    fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        // Stand-in for an expensive embedding call.
        std::thread::sleep(Duration::from_millis(self.delay_ms));
        self.call_count.fetch_add(1, Ordering::SeqCst);

        let rows = source.len();
        let values = Arc::new(Float32Array::from(vec![Some(1.0); rows * self.dim]));
        let item_field = Arc::new(Field::new("item", values.data_type().clone(), false));
        let validity = Some(NullBuffer::new_valid(rows));
        let list = FixedSizeListArray::new(item_field, self.dim as _, values, validity);
        Ok(Arc::new(list))
    }

    /// Query embeddings are not needed by these tests.
    fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        unimplemented!()
    }
}
fn create_test_batch() -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
let text = StringArray::from(vec!["hello", "world"]);
RecordBatch::try_new(schema, vec![Arc::new(text)]).map_err(|e| Error::Runtime {
message: format!("Failed to create test batch: {}", e),
})
}
#[test]
fn test_single_embedding_fast_path() {
    // With exactly one embedding configured, the output must gain the
    // destination column and the function must be invoked exactly once.
    let input = create_test_batch().unwrap();
    let input_schema = input.schema();

    let mock = Arc::new(SlowMockEmbed::new("test".to_string(), 2, 10));
    let functions = vec![(
        EmbeddingDefinition::new("text", "test", Some("embedding")),
        mock.clone() as Arc<dyn EmbeddingFunction>,
    )];

    let source = RecordBatchIterator::new(vec![Ok(input)], input_schema);
    let mut embedded = WithEmbeddings::new(source, functions);
    let output = embedded.next().unwrap().unwrap();

    assert!(output.column_by_name("embedding").is_some());
    assert_eq!(mock.get_call_count(), 1);
}
#[test]
fn test_multiple_embeddings_parallel() {
    // Multiple embeddings should execute in parallel. Each mock sleeps for
    // DELAY_MS, so running the three of them one after another can never
    // finish in less than 3 * DELAY_MS; finishing sooner proves overlap.
    const DELAY_MS: u64 = 100;

    let batch = create_test_batch().unwrap();
    let schema = batch.schema();
    let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, DELAY_MS));
    let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, DELAY_MS));
    let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, DELAY_MS));
    let def1 = EmbeddingDefinition::new("text", "embed1", Some("emb1"));
    let def2 = EmbeddingDefinition::new("text", "embed2", Some("emb2"));
    let def3 = EmbeddingDefinition::new("text", "embed3", Some("emb3"));
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![
        (def1, embed1.clone() as Arc<dyn EmbeddingFunction>),
        (def2, embed2.clone() as Arc<dyn EmbeddingFunction>),
        (def3, embed3.clone() as Arc<dyn EmbeddingFunction>),
    ];
    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);

    let started = std::time::Instant::now();
    let result = with_embeddings.next().unwrap().unwrap();
    let elapsed = started.elapsed();

    // Verify all embedding columns are present
    assert!(result.column_by_name("emb1").is_some());
    assert!(result.column_by_name("emb2").is_some());
    assert!(result.column_by_name("emb3").is_some());
    // Verify all embeddings were computed
    assert_eq!(embed1.get_call_count(), 1);
    assert_eq!(embed2.get_call_count(), 1);
    assert_eq!(embed3.get_call_count(), 1);
    // Verify the computations overlapped: the bound is the minimum time a
    // sequential run would need, which leaves two full delays of slack over
    // the ideal parallel time.
    let sequential_minimum = Duration::from_millis(3 * DELAY_MS);
    assert!(
        elapsed < sequential_minimum,
        "embeddings did not run in parallel: took {:?}, sequential minimum is {:?}",
        elapsed,
        sequential_minimum
    );
}
#[test]
fn test_embedding_column_order_preserved() {
    // Destination columns must be appended after the original column, in the
    // same order in which the embedding definitions were supplied.
    let specs: [(&str, usize, &str); 3] = [
        ("embed1", 2, "first"),
        ("embed2", 3, "second"),
        ("embed3", 4, "third"),
    ];

    let input = create_test_batch().unwrap();
    let input_schema = input.schema();
    let functions: Vec<_> = specs
        .iter()
        .map(|&(name, dim, dest)| {
            (
                EmbeddingDefinition::new("text", name, Some(dest)),
                Arc::new(SlowMockEmbed::new(name.to_string(), dim, 10))
                    as Arc<dyn EmbeddingFunction>,
            )
        })
        .collect();

    let source = RecordBatchIterator::new(vec![Ok(input)], input_schema);
    let mut embedded = WithEmbeddings::new(source, functions);
    let output_schema = embedded.next().unwrap().unwrap().schema();

    // Position 0 is the original column; the embeddings follow in order.
    let expected_names = ["text", "first", "second", "third"];
    for (position, expected) in expected_names.iter().enumerate() {
        assert_eq!(output_schema.field(position).name(), *expected);
    }
}
#[test]
fn test_embedding_error_propagation() {
    /// Embedding function whose source computation always fails.
    #[derive(Debug)]
    struct FailingEmbed {
        name: String,
    }

    impl EmbeddingFunction for FailingEmbed {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn source_type(&self) -> Result<Cow<'_, DataType>> {
            Ok(Cow::Owned(DataType::Utf8))
        }

        fn dest_type(&self) -> Result<Cow<'_, DataType>> {
            let list_type = DataType::new_fixed_size_list(DataType::Float32, 2, true);
            Ok(Cow::Owned(list_type))
        }

        fn compute_source_embeddings(&self, _source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            let message = "Intentional failure".to_string();
            Err(Error::Runtime { message })
        }

        fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            unimplemented!()
        }
    }

    // A failure inside the embedding function must surface as an `Err` item
    // from the reader, with the original message still visible.
    let input = create_test_batch().unwrap();
    let input_schema = input.schema();
    let failing: Arc<dyn EmbeddingFunction> = Arc::new(FailingEmbed {
        name: "failing".to_string(),
    });
    let functions = vec![(
        EmbeddingDefinition::new("text", "failing", Some("emb")),
        failing,
    )];

    let source = RecordBatchIterator::new(vec![Ok(input)], input_schema);
    let mut embedded = WithEmbeddings::new(source, functions);
    let outcome = embedded.next().unwrap();

    assert!(outcome.is_err());
    let rendered = outcome.err().unwrap().to_string();
    assert!(rendered.contains("Intentional failure"));
}
#[test]
fn test_maybe_embedded_with_no_embeddings() {
    use lancedb::table::{ColumnDefinition, ColumnKind, TableDefinition};

    // With a table definition that has only a physical column and no
    // registry, the batch must come back with its single column unchanged.
    let original = create_test_batch().unwrap();
    let schema = original.schema();

    let physical_only = vec![ColumnDefinition {
        kind: ColumnKind::Physical,
    }];
    let table_def = TableDefinition {
        schema: schema.clone(),
        column_definitions: physical_only,
    };
    let source = RecordBatchIterator::new(vec![Ok(original.clone())], schema);

    let mut passthrough = MaybeEmbedded::try_new(source, table_def, None).unwrap();
    let output = passthrough.next().unwrap().unwrap();

    assert_eq!(output.num_columns(), 1);
    assert_eq!(output.column(0).as_ref(), original.column(0).as_ref());
}