mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-23 15:00:39 +00:00
feat: parallelize embedding computations (#2896)
Implement parallel execution of multiple embedding functions using
std::thread::scope to improve performance when a table has multiple
embedding columns.
Key changes:
- Add compute_embeddings_parallel() helper method to WithEmbeddings
- Use fast path for single embeddings (no threading overhead)
- Use scoped threads for parallel execution of multiple embeddings
- Add comprehensive tests including parallelization timing verification
- Update WithEmbeddings documentation
Performance improvements:
- I/O-bound embeddings (OpenAI, Bedrock): High benefit from concurrent
API calls
- CPU-bound embeddings (sentence-transformers): Medium benefit from core
utilization
- Single embedding: No overhead (fast path)
Closes TODO on line 266 in rust/lancedb/src/embeddings.rs
This commit is contained in:
@@ -120,8 +120,13 @@ impl MemoryRegistry {
|
||||
}
|
||||
|
||||
/// A record batch reader that has embeddings applied to it
|
||||
/// This is a wrapper around another record batch reader that applies an embedding function
|
||||
/// when reading from the record batch
|
||||
///
|
||||
/// This is a wrapper around another record batch reader that applies embedding functions
|
||||
/// when reading from the record batch.
|
||||
///
|
||||
/// When multiple embedding functions are defined, they are computed in parallel using
|
||||
/// scoped threads to improve performance. For a single embedding function, computation
|
||||
/// is done inline without threading overhead.
|
||||
pub struct WithEmbeddings<R: RecordBatchReader> {
|
||||
inner: R,
|
||||
embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
|
||||
@@ -235,6 +240,48 @@ impl<R: RecordBatchReader> WithEmbeddings<R> {
|
||||
column_definitions,
|
||||
})
|
||||
}
|
||||
|
||||
fn compute_embeddings_parallel(&self, batch: &RecordBatch) -> Result<Vec<Arc<dyn Array>>> {
|
||||
if self.embeddings.len() == 1 {
|
||||
let (fld, func) = &self.embeddings[0];
|
||||
let src_column =
|
||||
batch
|
||||
.column_by_name(&fld.source_column)
|
||||
.ok_or_else(|| Error::InvalidInput {
|
||||
message: format!("Source column '{}' not found", fld.source_column),
|
||||
})?;
|
||||
return Ok(vec![func.compute_source_embeddings(src_column.clone())?]);
|
||||
}
|
||||
|
||||
// Parallel path: multiple embeddings
|
||||
std::thread::scope(|s| {
|
||||
let handles: Vec<_> = self
|
||||
.embeddings
|
||||
.iter()
|
||||
.map(|(fld, func)| {
|
||||
let src_column = batch.column_by_name(&fld.source_column).ok_or_else(|| {
|
||||
Error::InvalidInput {
|
||||
message: format!("Source column '{}' not found", fld.source_column),
|
||||
}
|
||||
})?;
|
||||
|
||||
let handle =
|
||||
s.spawn(move || func.compute_source_embeddings(src_column.clone()));
|
||||
|
||||
Ok(handle)
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
handles
|
||||
.into_iter()
|
||||
.map(|h| {
|
||||
h.join().map_err(|e| Error::Runtime {
|
||||
message: format!("Thread panicked during embedding computation: {:?}", e),
|
||||
})?
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: RecordBatchReader> Iterator for MaybeEmbedded<R> {
|
||||
@@ -262,19 +309,19 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let batch = self.inner.next()?;
|
||||
match batch {
|
||||
Ok(mut batch) => {
|
||||
// todo: parallelize this
|
||||
for (fld, func) in self.embeddings.iter() {
|
||||
let src_column = batch.column_by_name(&fld.source_column).unwrap();
|
||||
let embedding = match func.compute_source_embeddings(src_column.clone()) {
|
||||
Ok(embedding) => embedding,
|
||||
Err(e) => {
|
||||
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
|
||||
"Error computing embedding: {}",
|
||||
e
|
||||
))))
|
||||
}
|
||||
};
|
||||
Ok(batch) => {
|
||||
let embeddings = match self.compute_embeddings_parallel(&batch) {
|
||||
Ok(emb) => emb,
|
||||
Err(e) => {
|
||||
return Some(Err(arrow_schema::ArrowError::ComputeError(format!(
|
||||
"Error computing embedding: {}",
|
||||
e
|
||||
))))
|
||||
}
|
||||
};
|
||||
|
||||
let mut batch = batch;
|
||||
for ((fld, _), embedding) in self.embeddings.iter().zip(embeddings.iter()) {
|
||||
let dst_field_name = fld
|
||||
.dest_column
|
||||
.clone()
|
||||
@@ -286,7 +333,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
|
||||
embedding.nulls().is_some(),
|
||||
);
|
||||
|
||||
match batch.try_with_column(dst_field.clone(), embedding) {
|
||||
match batch.try_with_column(dst_field.clone(), embedding.clone()) {
|
||||
Ok(b) => batch = b,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
|
||||
253
rust/lancedb/tests/embeddings_parallel_test.rs
Normal file
253
rust/lancedb/tests/embeddings_parallel_test.rs
Normal file
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use arrow::buffer::NullBuffer;
|
||||
use arrow_array::{
|
||||
Array, FixedSizeListArray, Float32Array, RecordBatch, RecordBatchIterator, StringArray,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lancedb::{
|
||||
embeddings::{EmbeddingDefinition, EmbeddingFunction, MaybeEmbedded, WithEmbeddings},
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
/// Mock embedding function with a configurable artificial delay and a call counter.
#[derive(Debug)]
struct SlowMockEmbed {
    /// Name reported for this embedding function.
    name: String,
    /// Number of values in each generated embedding vector.
    dim: usize,
    /// Artificial delay, in milliseconds.
    delay_ms: u64,
    /// Shared counter of embedding computations.
    call_count: Arc<AtomicUsize>,
}

impl SlowMockEmbed {
    /// Builds a mock with the given name, vector width and delay; the call
    /// counter starts at zero.
    pub fn new(name: String, dim: usize, delay_ms: u64) -> Self {
        let call_count = Arc::new(AtomicUsize::default());
        Self {
            name,
            dim,
            delay_ms,
            call_count,
        }
    }

    /// Returns the current value of the call counter.
    pub fn get_call_count(&self) -> usize {
        let counter: &AtomicUsize = &self.call_count;
        counter.load(Ordering::SeqCst)
    }
}
|
||||
|
||||
impl EmbeddingFunction for SlowMockEmbed {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn source_type(&self) -> Result<Cow<'_, DataType>> {
|
||||
Ok(Cow::Owned(DataType::Utf8))
|
||||
}
|
||||
|
||||
fn dest_type(&self) -> Result<Cow<'_, DataType>> {
|
||||
Ok(Cow::Owned(DataType::new_fixed_size_list(
|
||||
DataType::Float32,
|
||||
self.dim as _,
|
||||
true,
|
||||
)))
|
||||
}
|
||||
|
||||
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||
// Simulate slow embedding computation
|
||||
std::thread::sleep(Duration::from_millis(self.delay_ms));
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
let len = source.len();
|
||||
let inner = Arc::new(Float32Array::from(vec![Some(1.0); len * self.dim]));
|
||||
let field = Field::new("item", inner.data_type().clone(), false);
|
||||
let arr = FixedSizeListArray::new(
|
||||
Arc::new(field),
|
||||
self.dim as _,
|
||||
inner,
|
||||
Some(NullBuffer::new_valid(len)),
|
||||
);
|
||||
|
||||
Ok(Arc::new(arr))
|
||||
}
|
||||
|
||||
fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_batch() -> Result<RecordBatch> {
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("text", DataType::Utf8, false)]));
|
||||
let text = StringArray::from(vec!["hello", "world"]);
|
||||
RecordBatch::try_new(schema, vec![Arc::new(text)]).map_err(|e| Error::Runtime {
|
||||
message: format!("Failed to create test batch: {}", e),
|
||||
})
|
||||
}
|
||||
|
||||
/// A reader with exactly one embedding definition yields the destination column
/// and invokes the embedding function once.
#[test]
fn test_single_embedding_fast_path() {
    let input = create_test_batch().unwrap();
    let input_schema = input.schema();
    let source = RecordBatchIterator::new(vec![Ok(input)], input_schema);

    let mock = Arc::new(SlowMockEmbed::new("test".to_string(), 2, 10));
    let definition = EmbeddingDefinition::new("text", "test", Some("embedding"));
    let configured = vec![(definition, mock.clone() as Arc<dyn EmbeddingFunction>)];

    let mut embedded = WithEmbeddings::new(source, configured);
    let output = embedded.next().unwrap().unwrap();

    assert!(output.column_by_name("embedding").is_some());
    assert_eq!(mock.get_call_count(), 1);
}
|
||||
|
||||
/// Three embedding functions on the same batch must all run exactly once,
/// produce their destination columns, and overlap in time.
#[test]
fn test_multiple_embeddings_parallel() {
    // Each mock sleeps this long; running them back to back would therefore take
    // at least three times this value.
    const DELAY_MS: u64 = 100;

    let batch = create_test_batch().unwrap();
    let schema = batch.schema();

    let embed1 = Arc::new(SlowMockEmbed::new("embed1".to_string(), 2, DELAY_MS));
    let embed2 = Arc::new(SlowMockEmbed::new("embed2".to_string(), 3, DELAY_MS));
    let embed3 = Arc::new(SlowMockEmbed::new("embed3".to_string(), 4, DELAY_MS));

    let def1 = EmbeddingDefinition::new("text", "embed1", Some("emb1"));
    let def2 = EmbeddingDefinition::new("text", "embed2", Some("emb2"));
    let def3 = EmbeddingDefinition::new("text", "embed3", Some("emb3"));

    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
    let embeddings = vec![
        (def1, embed1.clone() as Arc<dyn EmbeddingFunction>),
        (def2, embed2.clone() as Arc<dyn EmbeddingFunction>),
        (def3, embed3.clone() as Arc<dyn EmbeddingFunction>),
    ];
    let mut with_embeddings = WithEmbeddings::new(reader, embeddings);

    let started = std::time::Instant::now();
    let result = with_embeddings.next().unwrap().unwrap();
    let elapsed = started.elapsed();

    // Verify all embedding columns are present
    assert!(result.column_by_name("emb1").is_some());
    assert!(result.column_by_name("emb2").is_some());
    assert!(result.column_by_name("emb3").is_some());

    // Verify all embeddings were computed
    assert_eq!(embed1.get_call_count(), 1);
    assert_eq!(embed2.get_call_count(), 1);
    assert_eq!(embed3.get_call_count(), 1);

    // Verify the computations overlapped: a sequential run cannot finish in less
    // than the sum of the three delays.
    let sequential_floor = Duration::from_millis(3 * DELAY_MS);
    assert!(
        elapsed < sequential_floor,
        "embeddings appear to have run sequentially: took {:?}, sequential floor is {:?}",
        elapsed,
        sequential_floor
    );
}
|
||||
|
||||
/// Destination columns are appended after the original column, in the order the
/// embedding definitions were supplied.
#[test]
fn test_embedding_column_order_preserved() {
    let input = create_test_batch().unwrap();
    let input_schema = input.schema();

    // (embedding name, vector width, destination column)
    let specs = [
        ("embed1", 2, "first"),
        ("embed2", 3, "second"),
        ("embed3", 4, "third"),
    ];
    let configured: Vec<_> = specs
        .iter()
        .map(|&(name, dim, dest)| {
            let function =
                Arc::new(SlowMockEmbed::new(name.to_string(), dim, 10)) as Arc<dyn EmbeddingFunction>;
            (EmbeddingDefinition::new("text", name, Some(dest)), function)
        })
        .collect();

    let source = RecordBatchIterator::new(vec![Ok(input)], input_schema);
    let mut embedded = WithEmbeddings::new(source, configured);

    let output = embedded.next().unwrap().unwrap();
    let output_schema = output.schema();

    // Index 0 is the original column; the embedding columns follow in order.
    let expected_names = ["text", "first", "second", "third"];
    for (idx, expected) in expected_names.iter().enumerate() {
        assert_eq!(output_schema.field(idx).name(), *expected);
    }
}
|
||||
|
||||
/// An error returned by an embedding function surfaces from the iterator with
/// its message intact.
#[test]
fn test_embedding_error_propagation() {
    /// Embedding function whose source computation always fails.
    #[derive(Debug)]
    struct FailingEmbed {
        name: String,
    }

    impl EmbeddingFunction for FailingEmbed {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn source_type(&self) -> Result<Cow<'_, DataType>> {
            Ok(Cow::Owned(DataType::Utf8))
        }

        fn dest_type(&self) -> Result<Cow<'_, DataType>> {
            let list_type = DataType::new_fixed_size_list(DataType::Float32, 2, true);
            Ok(Cow::Owned(list_type))
        }

        fn compute_source_embeddings(&self, _source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            let failure = Error::Runtime {
                message: "Intentional failure".to_string(),
            };
            Err(failure)
        }

        fn compute_query_embeddings(&self, _input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
            unimplemented!()
        }
    }

    let input = create_test_batch().unwrap();
    let input_schema = input.schema();
    let source = RecordBatchIterator::new(vec![Ok(input)], input_schema);

    let failing = Arc::new(FailingEmbed {
        name: "failing".to_string(),
    });
    let definition = EmbeddingDefinition::new("text", "failing", Some("emb"));
    let configured = vec![(definition, failing as Arc<dyn EmbeddingFunction>)];
    let mut embedded = WithEmbeddings::new(source, configured);

    let outcome = embedded.next().unwrap();
    assert!(outcome.is_err());

    let rendered = outcome.err().unwrap().to_string();
    assert!(rendered.contains("Intentional failure"));
}
|
||||
|
||||
/// With no embedding registry supplied, `MaybeEmbedded` passes the batch through
/// unchanged.
#[test]
fn test_maybe_embedded_with_no_embeddings() {
    use lancedb::table::{ColumnDefinition, ColumnKind, TableDefinition};

    let input = create_test_batch().unwrap();
    let input_schema = input.schema();

    // A table definition whose only column is a plain physical column.
    let physical_column = ColumnDefinition {
        kind: ColumnKind::Physical,
    };
    let table_def = TableDefinition {
        schema: input_schema.clone(),
        column_definitions: vec![physical_column],
    };

    let source = RecordBatchIterator::new(vec![Ok(input.clone())], input_schema);
    let mut passthrough = MaybeEmbedded::try_new(source, table_def, None).unwrap();

    let output = passthrough.next().unwrap().unwrap();
    assert_eq!(output.num_columns(), 1);
    assert_eq!(output.column(0).as_ref(), input.column(0).as_ref());
}
|
||||
Reference in New Issue
Block a user