diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index c5b4e874..83a4eacb 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -53,7 +53,10 @@ jobs:
         run: cargo clippy --all --all-features -- -D warnings
   linux:
     timeout-minutes: 30
-    runs-on: ubuntu-22.04
+    # To build all features, we need more disk space than is available
+    # on the GitHub-provided runner. This is mostly due to the
+    # sentence-transformers feature.
+    runs-on: warp-ubuntu-latest-x64-4x
     defaults:
       run:
         shell: bash
@@ -131,4 +134,3 @@
         $env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
         cargo build
         cargo test
-
\ No newline at end of file
diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml
index c3f8c929..4f588f32 100644
--- a/rust/lancedb/Cargo.toml
+++ b/rust/lancedb/Cargo.toml
@@ -46,6 +46,11 @@ serde_with = { version = "3.8.1" }
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
 polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
 polars = { version = ">=0.37,<0.40.0", optional = true }
+hf-hub = { version = "0.3.2", optional = true }
+candle-core = { version = "0.6.0", optional = true }
+candle-transformers = { version = "0.6.0", optional = true }
+candle-nn = { version = "0.6.0", optional = true }
+tokenizers = { version = "0.19.1", optional = true }
 
 [dev-dependencies]
 tempfile = "3.5.0"
@@ -68,8 +73,12 @@ fp16kernels = ["lance-linalg/fp16kernels"]
 s3-test = []
 openai = ["dep:async-openai", "dep:reqwest"]
 polars = ["dep:polars-arrow", "dep:polars"]
-
+sentence-transformers = ["dep:hf-hub", "dep:candle-core", "dep:candle-transformers", "dep:candle-nn", "dep:tokenizers"]
 
 [[example]]
 name = "openai"
 required-features = ["openai"]
+
+[[example]]
+name = "sentence_transformers"
+required-features = ["sentence-transformers"]
diff --git a/rust/lancedb/examples/sentence_transformers.rs b/rust/lancedb/examples/sentence_transformers.rs
new file mode 100644
index 00000000..13b270c4
--- /dev/null
+++ b/rust/lancedb/examples/sentence_transformers.rs
@@ -0,0 +1,92 @@
+use std::{iter::once, sync::Arc};
+
+use arrow_array::{RecordBatch, RecordBatchIterator, StringArray};
+use arrow_schema::{DataType, Field, Schema};
+use futures::StreamExt;
+use lancedb::{
+    arrow::IntoArrow,
+    connect,
+    embeddings::{
+        sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition,
+        EmbeddingFunction,
+    },
+    query::{ExecutableQuery, QueryBase},
+    Result,
+};
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let tempdir = tempfile::tempdir().unwrap();
+    let tempdir = tempdir.path().to_str().unwrap();
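+    // With no overrides, the builder uses the `all-MiniLM-L6-v2`
+    // sentence-transformers model on CPU; the model, tokenizer, and config
+    // are fetched from the Hugging Face Hub on first use, so this step needs
+    // network access (or a warm hf-hub cache).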
+    let embedding = SentenceTransformersEmbeddings::builder().build()?;
+    let embedding = Arc::new(embedding);
+    let db = connect(tempdir).execute().await?;
+    db.embedding_registry()
+        .register("sentence-transformers", embedding.clone())?;
+
+    let table = db
+        .create_table("vectors", make_data())
+        .add_embedding(EmbeddingDefinition::new(
+            "facts",
+            "sentence-transformers",
+            Some("embeddings"),
+        ))?
+        .execute()
+        .await?;
+
+    let query = Arc::new(StringArray::from_iter_values(once(
+        "How many bones are in the human body?",
+    )));
+    let query_vector = embedding.compute_query_embeddings(query)?;
+    let mut results = table
+        .vector_search(query_vector)?
+        .limit(1)
+        .execute()
+        .await?;
+
+    let rb = results.next().await.unwrap()?;
+    let out = rb
+        .column_by_name("facts")
+        .unwrap()
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .unwrap();
+    let text = out.iter().next().unwrap().unwrap();
+    println!("Answer: {}", text);
+    Ok(())
+}
+
+fn make_data() -> impl IntoArrow {
+    let schema = Schema::new(vec![Field::new("facts", DataType::Utf8, false)]);
+
+    let facts = StringArray::from_iter_values(vec![
+        "Albert Einstein was a theoretical physicist.",
+        "The capital of France is Paris.",
+        "The Great Wall of China is one of the Seven Wonders of the World.",
+        "Python is a popular programming language.",
+        "Mount Everest is the highest mountain in the world.",
+        "Leonardo da Vinci painted the Mona Lisa.",
+        "Shakespeare wrote Hamlet.",
+        "The human body has 206 bones.",
+        "The speed of light is approximately 299,792 kilometers per second.",
+        "Water boils at 100 degrees Celsius.",
+        "The Earth orbits the Sun.",
+        "The Pyramids of Giza are located in Egypt.",
+        "Coffee is one of the most popular beverages in the world.",
+        "Tokyo is the capital city of Japan.",
+        "Photosynthesis is the process by which plants make their food.",
+        "The Pacific Ocean is the largest ocean on Earth.",
+        "Mozart was a prolific composer of classical music.",
+        "The Internet is a global network of computers.",
+        "Basketball is a sport played with a ball and a hoop.",
+        "The first computer virus was created in 1983.",
+        "Artificial neural networks are inspired by the human brain.",
+        "Deep learning is a subset of machine learning.",
+        "IBM's Watson won Jeopardy! in 2011.",
+        "The first computer programmer was Ada Lovelace.",
+        "The first chatbot was ELIZA, created in the 1960s.",
+    ]);
+    let schema = Arc::new(schema);
+    let rb = RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap();
+    Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema))
+}
diff --git a/rust/lancedb/src/embeddings.rs b/rust/lancedb/src/embeddings.rs
index 614f2ec3..5093e931 100644
--- a/rust/lancedb/src/embeddings.rs
+++ b/rust/lancedb/src/embeddings.rs
@@ -14,6 +14,9 @@
 #[cfg(feature = "openai")]
 pub mod openai;
 
+#[cfg(feature = "sentence-transformers")]
+pub mod sentence_transformers;
+
 use lance::arrow::RecordBatchExt;
 use std::{
     borrow::Cow,
diff --git a/rust/lancedb/src/embeddings/openai.rs b/rust/lancedb/src/embeddings/openai.rs
index 88528919..fea84260 100644
--- a/rust/lancedb/src/embeddings/openai.rs
+++ b/rust/lancedb/src/embeddings/openai.rs
@@ -176,6 +176,7 @@ impl EmbeddingFunction for OpenAIEmbeddingFunction {
         Ok(Arc::new(arr))
     }
 }
+
 impl OpenAIEmbeddingFunction {
     fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
         // OpenAI only supports non-nullable string arrays
diff --git a/rust/lancedb/src/embeddings/sentence_transformers.rs b/rust/lancedb/src/embeddings/sentence_transformers.rs
new file mode 100644
index 00000000..7a21130d
--- /dev/null
+++ b/rust/lancedb/src/embeddings/sentence_transformers.rs
@@ -0,0 +1,470 @@
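+//! Embedding functions backed by Hugging Face sentence-transformers models,
+//! executed locally with `candle`.
+//!
+//! A minimal usage sketch (assuming the default `all-MiniLM-L6-v2` model and
+//! Hugging Face Hub access on first use):
+//!
+//! ```no_run
+//! use lancedb::embeddings::sentence_transformers::SentenceTransformersEmbeddings;
+//!
+//! fn main() -> lancedb::Result<()> {
+//!     let embedding = SentenceTransformersEmbeddings::builder()
+//!         .model("all-MiniLM-L6-v2")
+//!         .build()?;
+//!     Ok(())
+//! }
+//! ```
+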
+use std::{borrow::Cow, sync::Arc};
+
+use super::EmbeddingFunction;
+use arrow::{
+    array::{AsArray, PrimitiveBuilder},
+    datatypes::{
+        ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, Int64Type, UInt32Type,
+        UInt8Type,
+    },
+};
+use arrow_array::{Array, FixedSizeListArray, PrimitiveArray};
+use arrow_data::ArrayData;
+use arrow_schema::DataType;
+use candle_core::{CpuStorage, Device, Layout, Storage, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::bert::{BertModel, DTYPE};
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::{tokenizer::Tokenizer, PaddingParams};
+
+/// Compute embeddings using huggingface sentence-transformers.
+pub struct SentenceTransformersEmbeddingsBuilder {
+    /// The sentence-transformers model to use.
+    /// Defaults to 'all-MiniLM-L6-v2'
+    model: Option<String>,
+    /// The device to use for computation.
+    /// Defaults to 'cpu'
+    device: Option<Device>,
+    /// Defaults to true
+    normalize: bool,
+    n_dims: Option<usize>,
+    revision: Option<String>,
+    /// Path to the configuration file.
+    /// Defaults to `config.json`
+    config_path: Option<String>,
+    /// Path to the tokenizer file.
+    /// Defaults to `tokenizer.json`
+    tokenizer_path: Option<String>,
+    /// Path to the model file.
+    /// Defaults to `model.safetensors`
+    model_path: Option<String>,
+    /// Padding parameters for the tokenizer.
+    padding: Option<PaddingParams>,
+}
+
+pub struct SentenceTransformersEmbeddings {
+    model: BertModel,
+    tokenizer: Tokenizer,
+    device: Device,
+    n_dims: Option<usize>,
+}
+
+impl std::fmt::Debug for SentenceTransformersEmbeddings {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SentenceTransformersEmbeddings")
+            .field("tokenizer", &self.tokenizer)
+            .field("device", &self.device)
+            .field("n_dims", &self.n_dims)
+            .finish()
+    }
+}
+
+impl Default for SentenceTransformersEmbeddingsBuilder {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SentenceTransformersEmbeddingsBuilder {
+    pub fn new() -> Self {
+        Self {
+            model: None,
+            device: None,
+            normalize: true,
+            n_dims: None,
+            revision: None,
+            config_path: None,
+            tokenizer_path: None,
+            model_path: None,
+            padding: None,
+        }
+    }
+
+    pub fn model<S: Into<String>>(mut self, name: S) -> Self {
+        self.model = Some(name.into());
+        self
+    }
+
+    pub fn device<D: Into<Device>>(mut self, device: D) -> Self {
+        self.device = Some(device.into());
+        self
+    }
+
+    pub fn normalize(mut self, normalize: bool) -> Self {
+        self.normalize = normalize;
+        self
+    }
+
+    /// If you know the number of dimensions of the embeddings, you can set it here.
+    /// This will avoid a call to the model to determine the number of dimensions.
+    pub fn ndims(mut self, n_dims: usize) -> Self {
+        self.n_dims = Some(n_dims);
+        self
+    }
+
+    /// If you want to use a specific revision of the model, you can set it here.
+    pub fn revision<S: Into<String>>(mut self, revision: S) -> Self {
+        self.revision = Some(revision.into());
+        self
+    }
+
+    /// Set the path to the configuration file.
+    /// Defaults to `config.json`
+    ///
+    /// Note: this is the path inside the huggingface repo, **NOT the path on disk**.
+    pub fn config_path<S: Into<String>>(mut self, config: S) -> Self {
+        self.config_path = Some(config.into());
+        self
+    }
+
+    /// Set the path to the tokenizer file.
+    /// Defaults to `tokenizer.json`
+    ///
+    /// Note: this is the path inside the huggingface repo, **NOT the path on disk**.
+    pub fn tokenizer_path<S: Into<String>>(mut self, tokenizer: S) -> Self {
+        self.tokenizer_path = Some(tokenizer.into());
+        self
+    }
+
+    /// Set the path inside the huggingface repo to the model file.
+    /// Defaults to `model.safetensors`
+    ///
+    /// Note: this is the path inside the huggingface repo, **NOT the path on disk**.
+    ///
+    /// Note: we currently only support a single model file.
+    pub fn model_path<S: Into<String>>(mut self, model: S) -> Self {
+        self.model_path = Some(model.into());
+        self
+    }
+
+    pub fn build(mut self) -> crate::Result<SentenceTransformersEmbeddings> {
+        let model_id = self.model.as_deref().unwrap_or("all-MiniLM-L6-v2");
+        let model_id = format!("sentence-transformers/{}", model_id);
+        let config = self.config_path.as_deref().unwrap_or("config.json");
+        let tokenizer = self.tokenizer_path.as_deref().unwrap_or("tokenizer.json");
+        let model_path = self.model_path.as_deref().unwrap_or("model.safetensors");
+        let device = self.device.unwrap_or(Device::Cpu);
+
+        let repo = if let Some(revision) = self.revision {
+            Repo::with_revision(model_id, RepoType::Model, revision.to_string())
+        } else {
+            Repo::new(model_id, RepoType::Model)
+        };
+
+        let (config_filename, tokenizer_filename, weights_filename) = {
+            let api = Api::new()?;
+            let api = api.repo(repo);
+            let config = api.get(config)?;
+            let tokenizer = api.get(tokenizer)?;
+            let weights = api.get(model_path)?;
+
+            (config, tokenizer, weights)
+        };
+
+        let config = std::fs::read_to_string(config_filename)
+            .map_err(|e| crate::Error::Runtime {
+                message: format!("Error reading config file: {}", e),
+            })
+            .and_then(|s| {
+                serde_json::from_str(&s).map_err(|e| crate::Error::Runtime {
+                    message: format!("Error deserializing config file: {}", e),
+                })
+            })?;
+        let mut tokenizer =
+            Tokenizer::from_file(tokenizer_filename).map_err(|e| crate::Error::Runtime {
+                message: format!("Error loading tokenizer: {}", e),
+            })?;
+        if self.padding.is_some() {
+            tokenizer.with_padding(self.padding.take());
+        }
+
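+        // SAFETY: `from_mmaped_safetensors` memory-maps the weight file. This
+        // assumes the file fetched into the hf-hub cache is not modified or
+        // truncated while mapped, which is the soundness condition candle
+        // documents for this constructor.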
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
+        let model = BertModel::load(vb, &config)?;
+        Ok(SentenceTransformersEmbeddings {
+            model,
+            tokenizer,
+            device,
+            n_dims: self.n_dims,
+        })
+    }
+}
+
+impl SentenceTransformersEmbeddings {
+    pub fn builder() -> SentenceTransformersEmbeddingsBuilder {
+        SentenceTransformersEmbeddingsBuilder::new()
+    }
+
+    fn ndims(&self) -> crate::Result<usize> {
+        if let Some(n_dims) = self.n_dims {
+            Ok(n_dims)
+        } else {
+            Ok(self.compute_ndims_and_dtype()?.0)
+        }
+    }
+
+    fn compute_ndims_and_dtype(&self) -> crate::Result<(usize, DataType)> {
+        let token = self.tokenizer.encode("hello", true).unwrap();
+        let token = token.get_ids().to_vec();
+        let input_ids = Tensor::new(vec![token], &self.device)?;
+
+        let token_type_ids = input_ids.zeros_like()?;
+
+        let embeddings = self
+            .model
+            .forward(&input_ids, &token_type_ids)
+            // TODO: it'd be nice to support other devices
+            .and_then(|output| output.to_device(&Device::Cpu))?;
+
+        let (_, _, n_dims) = embeddings.dims3().unwrap();
+        let (storage, _) = embeddings.storage_and_layout();
+        let dtype = match &*storage {
+            Storage::Cpu(CpuStorage::U8(_)) => DataType::UInt8,
+            Storage::Cpu(CpuStorage::U32(_)) => DataType::UInt32,
+            Storage::Cpu(CpuStorage::I64(_)) => DataType::Int64,
+            Storage::Cpu(CpuStorage::F16(_)) => DataType::Float16,
+            Storage::Cpu(CpuStorage::F32(_)) => DataType::Float32,
+            Storage::Cpu(CpuStorage::F64(_)) => DataType::Float64,
+            Storage::Cpu(CpuStorage::BF16(_)) => {
+                return Err(crate::Error::Runtime {
+                    message: "unsupported data type".to_string(),
+                })
+            }
+            _ => unreachable!("we already moved the tensor to the CPU device"),
+        };
+        Ok((n_dims, dtype))
+    }
+
+    fn compute_inner(&self, source: Arc<dyn Array>) -> crate::Result<(Arc<dyn Array>, DataType)> {
+        if source.is_nullable() {
+            return Err(crate::Error::InvalidInput {
+                message: "Expected non-nullable data type".to_string(),
+            });
+        }
+        if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
+            return Err(crate::Error::InvalidInput {
+                message: "Expected Utf8 data type".to_string(),
+            });
+        }
+        let check_nulls = |source: &dyn Array| {
+            if source.null_count() > 0 {
+                return Err(crate::Error::Runtime {
+                    message: "null values not supported".to_string(),
+                });
+            }
+            Ok(())
+        };
+        let tokens = match source.data_type() {
+            DataType::Utf8 => {
+                check_nulls(&*source)?;
+                source
+                    .as_string::<i32>()
+                    // TODO: should we do this in parallel? (e.g. using rayon)
+                    .into_iter()
+                    .map(|v| {
+                        let value = v.unwrap();
+                        let token = self.tokenizer.encode(value, true).map_err(|e| {
+                            crate::Error::Runtime {
+                                message: format!("failed to encode value: {}", e),
+                            }
+                        })?;
+                        let token = token.get_ids().to_vec();
+                        Ok(Tensor::new(token.as_slice(), &self.device)?)
+                    })
+                    .collect::<crate::Result<Vec<Tensor>>>()?
+            }
+            DataType::LargeUtf8 => {
+                check_nulls(&*source)?;
+
+                source
+                    .as_string::<i64>()
+                    // TODO: should we do this in parallel? (e.g. using rayon)
+                    .into_iter()
+                    .map(|v| {
+                        let value = v.unwrap();
+                        let token = self.tokenizer.encode(value, true).map_err(|e| {
+                            crate::Error::Runtime {
+                                message: format!("failed to encode value: {}", e),
+                            }
+                        })?;
+
+                        let token = token.get_ids().to_vec();
+                        Ok(Tensor::new(token.as_slice(), &self.device)?)
+                    })
+                    .collect::<crate::Result<Vec<Tensor>>>()?
+            }
+            DataType::Utf8View => {
+                return Err(crate::Error::Runtime {
+                    message: "Utf8View not yet implemented".to_string(),
+                })
+            }
+            _ => {
+                return Err(crate::Error::Runtime {
+                    message: "invalid type".to_string(),
+                })
+            }
+        };
+
+        let embeddings = Tensor::stack(&tokens, 0)
+            .and_then(|tokens| {
+                let token_type_ids = tokens.zeros_like()?;
+                self.model.forward(&tokens, &token_type_ids)
+            })
+            // TODO: it'd be nice to support other devices
+            .and_then(|tokens| tokens.to_device(&Device::Cpu))
+            .map_err(|e| crate::Error::Runtime {
+                message: format!("failed to compute embeddings: {}", e),
+            })?;
+        let (_, n_tokens, _) = embeddings.dims3().map_err(|e| crate::Error::Runtime {
+            message: format!("failed to get embeddings dimensions: {}", e),
+        })?;
+
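+        // Mean pooling: the model output has shape (batch, n_tokens, hidden),
+        // so averaging over dimension 1 collapses the per-token vectors into a
+        // single fixed-size embedding per input string.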
+        let embeddings = (embeddings.sum(1).unwrap() / (n_tokens as f64)).map_err(|e| {
+            crate::Error::Runtime {
+                message: format!("failed to compute mean embeddings: {}", e),
+            }
+        })?;
+        let dims = embeddings.shape().dims().len();
+        let (arr, dtype): (Arc<dyn Array>, DataType) = match dims {
+            2 => {
+                let (d1, d2) = embeddings.dims2().map_err(|e| crate::Error::Runtime {
+                    message: format!("failed to get embeddings dimensions: {}", e),
+                })?;
+                let (storage, layout) = embeddings.storage_and_layout();
+                match &*storage {
+                    Storage::Cpu(CpuStorage::U8(data)) => {
+                        let data: &[u8] = data.as_slice();
+                        let arr = from_cpu_storage::<UInt8Type>(data, layout, &embeddings, d1, d2);
+
+                        (Arc::new(arr), DataType::UInt8)
+                    }
+                    Storage::Cpu(CpuStorage::U32(data)) => (
+                        Arc::new(from_cpu_storage::<UInt32Type>(
+                            data,
+                            layout,
+                            &embeddings,
+                            d1,
+                            d2,
+                        )),
+                        DataType::UInt32,
+                    ),
+                    Storage::Cpu(CpuStorage::I64(data)) => (
+                        Arc::new(from_cpu_storage::<Int64Type>(
+                            data,
+                            layout,
+                            &embeddings,
+                            d1,
+                            d2,
+                        )),
+                        DataType::Int64,
+                    ),
+                    Storage::Cpu(CpuStorage::F16(data)) => (
+                        Arc::new(from_cpu_storage::<Float16Type>(
+                            data,
+                            layout,
+                            &embeddings,
+                            d1,
+                            d2,
+                        )),
+                        DataType::Float16,
+                    ),
+                    Storage::Cpu(CpuStorage::F32(data)) => (
+                        Arc::new(from_cpu_storage::<Float32Type>(
+                            data,
+                            layout,
+                            &embeddings,
+                            d1,
+                            d2,
+                        )),
+                        DataType::Float32,
+                    ),
+                    Storage::Cpu(CpuStorage::F64(data)) => (
+                        Arc::new(from_cpu_storage::<Float64Type>(
+                            data,
+                            layout,
+                            &embeddings,
+                            d1,
+                            d2,
+                        )),
+                        DataType::Float64,
+                    ),
+                    Storage::Cpu(CpuStorage::BF16(_)) => {
+                        panic!("Unsupported storage type: BF16")
+                    }
+                    _ => unreachable!("Only CPU storage currently supported"),
+                }
+            }
+            n_dims => todo!("Only 2 dimensions supported, got {}", n_dims),
+        };
+        Ok((arr, dtype))
+    }
+}
+
+impl EmbeddingFunction for SentenceTransformersEmbeddings {
+    fn name(&self) -> &str {
+        "sentence-transformers"
+    }
+
+    fn source_type(&self) -> crate::Result<Cow<DataType>> {
+        Ok(Cow::Owned(DataType::Utf8))
+    }
+
+    fn dest_type(&self) -> crate::Result<Cow<DataType>> {
+        let (n_dims, dtype) = self.compute_ndims_and_dtype()?;
+        Ok(Cow::Owned(DataType::new_fixed_size_list(
+            dtype,
+            n_dims as i32,
+            false,
+        )))
+    }
+
+    fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> crate::Result<Arc<dyn Array>> {
+        let len = source.len();
+        let n_dims = self.ndims()?;
+        let (inner, dtype) = self.compute_inner(source)?;
+
+        let fsl = DataType::new_fixed_size_list(dtype, n_dims as i32, false);
+
+        // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
+        // and we want to explicitly work with non-nullable arrays.
+        let array_data = ArrayData::builder(fsl)
+            .len(len)
+            .add_child_data(inner.into_data())
+            .build()?;
+
+        Ok(Arc::new(FixedSizeListArray::from(array_data)))
+    }
+
+    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> crate::Result<Arc<dyn Array>> {
+        let (arr, _) = self.compute_inner(input)?;
+        Ok(arr)
+    }
+}
+
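+/// Copy a candle CPU tensor buffer into an Arrow `PrimitiveArray`.
+///
+/// If the tensor layout is contiguous, the backing slice between the layout's
+/// offsets is appended wholesale; otherwise the tensor's strided index is
+/// walked to gather the elements row by row.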
+fn from_cpu_storage<T: ArrowPrimitiveType>(
+    buffer: &[T::Native],
+    layout: &Layout,
+    embeddings: &Tensor,
+    dim1: usize,
+    dim2: usize,
+) -> PrimitiveArray<T> {
+    let mut builder = PrimitiveBuilder::<T>::with_capacity(dim1 * dim2);
+
+    match layout.contiguous_offsets() {
+        Some((o1, o2)) => {
+            let data = &buffer[o1..o2];
+            builder.append_slice(data);
+            builder.finish()
+        }
+        None => {
+            let mut src_index = embeddings.strided_index();
+
+            for _idx_row in 0..dim1 {
+                let row = (0..dim2)
+                    .map(|_| buffer[src_index.next().unwrap()])
+                    .collect::<Vec<_>>();
+                builder.append_slice(&row);
+            }
+            builder.finish()
+        }
+    }
+}
diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs
index 1e0d1511..77f2373a 100644
--- a/rust/lancedb/src/error.rs
+++ b/rust/lancedb/src/error.rs
@@ -125,3 +125,22 @@
         }
     }
 }
+
+#[cfg(feature = "sentence-transformers")]
+impl From<hf_hub::api::sync::ApiError> for Error {
+    fn from(source: hf_hub::api::sync::ApiError) -> Self {
+        Self::Other {
+            message: "Error in Sentence Transformers integration.".to_string(),
+            source: Some(Box::new(source)),
+        }
+    }
+}
+#[cfg(feature = "sentence-transformers")]
+impl From<candle_core::Error> for Error {
+    fn from(source: candle_core::Error) -> Self {
+        Self::Other {
+            message: "Error in 'candle_core'.".to_string(),
+            source: Some(Box::new(source)),
+        }
+    }
+}
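
To try the new example locally (from `rust/lancedb`, assuming network access to the Hugging Face Hub for the model download):

    cargo run --example sentence_transformers --features sentence-transformers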