feat(rust): huggingface sentence-transformers (#1447)

Co-authored-by: Will Jones <willjones127@gmail.com>
Cory Grinstead
2024-07-22 13:47:57 -05:00
committed by GitHub
parent c9c61eb060
commit 391fa26175
7 changed files with 599 additions and 3 deletions

View File

@@ -53,7 +53,10 @@ jobs:
         run: cargo clippy --all --all-features -- -D warnings
   linux:
     timeout-minutes: 30
-    runs-on: ubuntu-22.04
+    # To build all features, we need more disk space than is available
+    # on the GitHub-provided runner. This is mostly due to the
+    # sentence-transformers feature.
+    runs-on: warp-ubuntu-latest-x64-4x
     defaults:
       run:
         shell: bash
@@ -131,4 +134,3 @@ jobs:
$env:VCPKG_ROOT = $env:VCPKG_INSTALLATION_ROOT
cargo build
cargo test

View File

@@ -46,6 +46,11 @@ serde_with = { version = "3.8.1" }
 reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
 polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
 polars = { version = ">=0.37,<0.40.0", optional = true }
+hf-hub = { version = "0.3.2", optional = true }
+candle-core = { version = "0.6.0", optional = true }
+candle-transformers = { version = "0.6.0", optional = true }
+candle-nn = { version = "0.6.0", optional = true }
+tokenizers = { version = "0.19.1", optional = true }

 [dev-dependencies]
 tempfile = "3.5.0"
@@ -68,8 +73,12 @@ fp16kernels = ["lance-linalg/fp16kernels"]
 s3-test = []
 openai = ["dep:async-openai", "dep:reqwest"]
 polars = ["dep:polars-arrow", "dep:polars"]
+sentence-transformers = ["dep:hf-hub", "dep:candle-core", "dep:candle-transformers", "dep:candle-nn", "dep:tokenizers"]

 [[example]]
 name = "openai"
 required-features = ["openai"]

+[[example]]
+name = "sentence_transformers"
+required-features = ["sentence-transformers"]
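Because the new example declares `required-features`, it only builds with the feature enabled, e.g. (run from the crate root):

    cargo run --example sentence_transformers --features sentence-transformers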

View File

@@ -0,0 +1,92 @@
use std::{iter::once, sync::Arc};
use arrow_array::{RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
arrow::IntoArrow,
connect,
embeddings::{
sentence_transformers::SentenceTransformersEmbeddings, EmbeddingDefinition,
EmbeddingFunction,
},
query::{ExecutableQuery, QueryBase},
Result,
};
#[tokio::main]
async fn main() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let embedding = SentenceTransformersEmbeddings::builder().build()?;
let embedding = Arc::new(embedding);
let db = connect(tempdir).execute().await?;
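    // Register the embedding function so tables can reference it by name.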
db.embedding_registry()
.register("sentence-transformers", embedding.clone())?;
let table = db
.create_table("vectors", make_data())
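        // Embedding definition: source column "facts", registered function name,
        // and destination column "embeddings".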
.add_embedding(EmbeddingDefinition::new(
"facts",
"sentence-transformers",
Some("embeddings"),
))?
.execute()
.await?;
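    // Embed the query text with the same function, then run a nearest-neighbor search.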
let query = Arc::new(StringArray::from_iter_values(once(
"How many bones are in the human body?",
)));
let query_vector = embedding.compute_query_embeddings(query)?;
let mut results = table
.vector_search(query_vector)?
.limit(1)
.execute()
.await?;
let rb = results.next().await.unwrap()?;
let out = rb
.column_by_name("facts")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let text = out.iter().next().unwrap().unwrap();
println!("Answer: {}", text);
Ok(())
}
fn make_data() -> impl IntoArrow {
let schema = Schema::new(vec![Field::new("facts", DataType::Utf8, false)]);
let facts = StringArray::from_iter_values(vec![
"Albert Einstein was a theoretical physicist.",
"The capital of France is Paris.",
"The Great Wall of China is one of the Seven Wonders of the World.",
"Python is a popular programming language.",
"Mount Everest is the highest mountain in the world.",
"Leonardo da Vinci painted the Mona Lisa.",
"Shakespeare wrote Hamlet.",
"The human body has 206 bones.",
"The speed of light is approximately 299,792 kilometers per second.",
"Water boils at 100 degrees Celsius.",
"The Earth orbits the Sun.",
"The Pyramids of Giza are located in Egypt.",
"Coffee is one of the most popular beverages in the world.",
"Tokyo is the capital city of Japan.",
"Photosynthesis is the process by which plants make their food.",
"The Pacific Ocean is the largest ocean on Earth.",
"Mozart was a prolific composer of classical music.",
"The Internet is a global network of computers.",
"Basketball is a sport played with a ball and a hoop.",
"The first computer virus was created in 1983.",
"Artificial neural networks are inspired by the human brain.",
"Deep learning is a subset of machine learning.",
"IBM's Watson won Jeopardy! in 2011.",
"The first computer programmer was Ada Lovelace.",
"The first chatbot was ELIZA, created in the 1960s.",
]);
let schema = Arc::new(schema);
let rb = RecordBatch::try_new(schema.clone(), vec![Arc::new(facts)]).unwrap();
Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema))
}
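The example above uses the builder defaults (the `all-MiniLM-L6-v2` checkpoint on CPU). A minimal sketch of customizing the builder with the options this commit adds; the checkpoint name and the 384-dimension figure are illustrative assumptions, not values pinned by the API:

use lancedb::embeddings::sentence_transformers::SentenceTransformersEmbeddings;

fn custom_embedding() -> lancedb::Result<SentenceTransformersEmbeddings> {
    SentenceTransformersEmbeddings::builder()
        // Expanded to "sentence-transformers/paraphrase-MiniLM-L6-v2" on the Hub
        // (illustrative checkpoint, assumed to exist).
        .model("paraphrase-MiniLM-L6-v2")
        // Pin a specific revision instead of the repo's default branch.
        .revision("main")
        // Assumed output dimension for this checkpoint; setting it skips the
        // probe forward pass that ndims() would otherwise run.
        .ndims(384)
        .build()
}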

View File

@@ -14,6 +14,9 @@
 #[cfg(feature = "openai")]
 pub mod openai;
+#[cfg(feature = "sentence-transformers")]
+pub mod sentence_transformers;
 use lance::arrow::RecordBatchExt;
 use std::{
     borrow::Cow,

View File

@@ -176,6 +176,7 @@ impl EmbeddingFunction for OpenAIEmbeddingFunction {
Ok(Arc::new(arr))
}
}
impl OpenAIEmbeddingFunction {
fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
// OpenAI only supports non-nullable string arrays

View File

@@ -0,0 +1,470 @@
use std::{borrow::Cow, sync::Arc};
use super::EmbeddingFunction;
use arrow::{
array::{AsArray, PrimitiveBuilder},
datatypes::{
ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, Int64Type, UInt32Type, UInt8Type,
},
};
use arrow_array::{Array, FixedSizeListArray, PrimitiveArray};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use candle_core::{CpuStorage, Device, Layout, Storage, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, DTYPE};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{tokenizer::Tokenizer, PaddingParams};
/// Builder for [`SentenceTransformersEmbeddings`], which computes embeddings
/// locally with Hugging Face sentence-transformers models.
pub struct SentenceTransformersEmbeddingsBuilder {
/// The sentence-transformers model to use.
/// Defaults to 'all-MiniLM-L6-v2'
model: Option<String>,
/// The device to use for computation.
/// Defaults to 'cpu'
device: Option<Device>,
/// Whether to normalize the computed embeddings.
/// Defaults to true.
normalize: bool,
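/// Known embedding dimension. When set, avoids a probe forward pass
/// through the model to discover the output size.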
n_dims: Option<usize>,
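/// Model revision to fetch from the Hub.
/// Defaults to the repository's default branch.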
revision: Option<String>,
/// path to configuration file.
/// Defaults to `config.json`
config_path: Option<String>,
/// path to tokenizer file.
/// Defaults to `tokenizer.json`
tokenizer_path: Option<String>,
/// path to model file.
/// Defaults to `model.safetensors`
model_path: Option<String>,
/// Padding parameters for the tokenizer.
padding: Option<PaddingParams>,
}
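/// An [`EmbeddingFunction`] that runs a BERT-style sentence-transformers model
/// locally via candle.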
pub struct SentenceTransformersEmbeddings {
model: BertModel,
tokenizer: Tokenizer,
device: Device,
n_dims: Option<usize>,
}
impl std::fmt::Debug for SentenceTransformersEmbeddings {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
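// `model` is intentionally omitted: `BertModel` has no `Debug` impl.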
f.debug_struct("SentenceTransformersEmbeddings")
.field("tokenizer", &self.tokenizer)
.field("device", &self.device)
.field("n_dims", &self.n_dims)
.finish()
}
}
impl Default for SentenceTransformersEmbeddingsBuilder {
fn default() -> Self {
Self::new()
}
}
impl SentenceTransformersEmbeddingsBuilder {
pub fn new() -> Self {
Self {
model: None,
device: None,
normalize: true,
n_dims: None,
revision: None,
config_path: None,
tokenizer_path: None,
model_path: None,
padding: None,
}
}
pub fn model<S: Into<String>>(mut self, name: S) -> Self {
self.model = Some(name.into());
self
}
pub fn device<D: Into<Device>>(mut self, device: D) -> Self {
self.device = Some(device.into());
self
}
pub fn normalize(mut self, normalize: bool) -> Self {
self.normalize = normalize;
self
}
/// If you know the number of dimensions of the embeddings, you can set it here.
/// This will avoid a call to the model to determine the number of dimensions.
pub fn ndims(mut self, n_dims: usize) -> Self {
self.n_dims = Some(n_dims);
self
}
/// If you want to use a specific revision of the model, you can set it here.
pub fn revision<S: Into<String>>(mut self, revision: S) -> Self {
self.revision = Some(revision.into());
self
}
/// Set the path to the configuration file.
/// Defaults to `config.json`
///
/// Note: this is the path inside the huggingface repo, **NOT the path on disk**.
pub fn config_path<S: Into<String>>(mut self, config: S) -> Self {
self.config_path = Some(config.into());
self
}
/// Set the path to the tokenizer file.
/// Defaults to `tokenizer.json`
///
/// Note: this is the path inside the huggingface repo, **NOT the path on disk**.
pub fn tokenizer_path<S: Into<String>>(mut self, tokenizer: S) -> Self {
self.tokenizer_path = Some(tokenizer.into());
self
}
/// Set the path inside the huggingface repo to the model file.
/// Defaults to `model.safetensors`
///
/// Note: this is the path inside the huggingface repo, **NOT the path on disk**.
///
/// Note: we currently only support a single model file.
pub fn model_path<S: Into<String>>(mut self, model: S) -> Self {
self.model_path = Some(model.into());
self
}
pub fn build(mut self) -> crate::Result<SentenceTransformersEmbeddings> {
let model_id = self.model.as_deref().unwrap_or("all-MiniLM-L6-v2");
let model_id = format!("sentence-transformers/{}", model_id);
let config = self.config_path.as_deref().unwrap_or("config.json");
let tokenizer = self.tokenizer_path.as_deref().unwrap_or("tokenizer.json");
let model_path = self.model_path.as_deref().unwrap_or("model.safetensors");
let device = self.device.unwrap_or(Device::Cpu);
let repo = if let Some(revision) = self.revision {
    Repo::with_revision(model_id, RepoType::Model, revision)
} else {
    Repo::new(model_id, RepoType::Model)
};
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get(config)?;
let tokenizer = api.get(tokenizer)?;
let weights = api.get(model_path)?;
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)
.map_err(|e| crate::Error::Runtime {
message: format!("Error reading config file: {}", e),
})
.and_then(|s| {
serde_json::from_str(&s).map_err(|e| crate::Error::Runtime {
message: format!("Error deserializing config file: {}", e),
})
})?;
let mut tokenizer =
Tokenizer::from_file(tokenizer_filename).map_err(|e| crate::Error::Runtime {
message: format!("Error loading tokenizer: {}", e),
})?;
if self.padding.is_some() {
tokenizer.with_padding(self.padding.take());
}
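// Safety: memory-maps the safetensors file. The mapping assumes the file
// is not modified for the lifetime of the `VarBuilder`.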
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
let model = BertModel::load(vb, &config)?;
Ok(SentenceTransformersEmbeddings {
model,
tokenizer,
device,
n_dims: self.n_dims,
})
}
}
impl SentenceTransformersEmbeddings {
pub fn builder() -> SentenceTransformersEmbeddingsBuilder {
SentenceTransformersEmbeddingsBuilder::new()
}
fn ndims(&self) -> crate::Result<usize> {
if let Some(n_dims) = self.n_dims {
Ok(n_dims)
} else {
Ok(self.compute_ndims_and_dtype()?.0)
}
}
fn compute_ndims_and_dtype(&self) -> crate::Result<(usize, DataType)> {
    // Probe the model with a one-token input to discover the output
    // shape and dtype.
    let token = self
        .tokenizer
        .encode("hello", true)
        .map_err(|e| crate::Error::Runtime {
            message: format!("failed to encode probe input: {}", e),
        })?;
    let token = token.get_ids().to_vec();
    let input_ids = Tensor::new(vec![token], &self.device)?;
    let token_type_ids = input_ids.zeros_like()?;
    let embeddings = self
        .model
        .forward(&input_ids, &token_type_ids)
        // TODO: it'd be nice to support other devices
        .and_then(|output| output.to_device(&Device::Cpu))?;
    let (_, _, n_dims) = embeddings.dims3()?;
let (storage, _) = embeddings.storage_and_layout();
let dtype = match &*storage {
Storage::Cpu(CpuStorage::U8(_)) => DataType::UInt8,
Storage::Cpu(CpuStorage::U32(_)) => DataType::UInt32,
Storage::Cpu(CpuStorage::I64(_)) => DataType::Int64,
Storage::Cpu(CpuStorage::F16(_)) => DataType::Float16,
Storage::Cpu(CpuStorage::F32(_)) => DataType::Float32,
Storage::Cpu(CpuStorage::F64(_)) => DataType::Float64,
Storage::Cpu(CpuStorage::BF16(_)) => {
return Err(crate::Error::Runtime {
message: "unsupported data type".to_string(),
})
}
_ => unreachable!("we already moved the tensor to the CPU device"),
};
Ok((n_dims, dtype))
}
fn compute_inner(&self, source: Arc<dyn Array>) -> crate::Result<(Arc<dyn Array>, DataType)> {
if source.is_nullable() {
return Err(crate::Error::InvalidInput {
message: "Expected non-nullable data type".to_string(),
});
}
if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
return Err(crate::Error::InvalidInput {
message: "Expected Utf8 data type".to_string(),
});
}
let check_nulls = |source: &dyn Array| {
if source.null_count() > 0 {
return Err(crate::Error::Runtime {
message: "null values not supported".to_string(),
});
}
Ok(())
};
// Shared encoding step for both string widths.
let encode = |value: &str| -> crate::Result<Tensor> {
    let token = self
        .tokenizer
        .encode(value, true)
        .map_err(|e| crate::Error::Runtime {
            message: format!("failed to encode value: {}", e),
        })?;
    let token = token.get_ids().to_vec();
    Ok(Tensor::new(token.as_slice(), &self.device)?)
};
let tokens = match source.data_type() {
    DataType::Utf8 => {
        check_nulls(&*source)?;
        source
            .as_string::<i32>()
            // TODO: should we do this in parallel? (e.g. using rayon)
            .into_iter()
            .map(|v| encode(v.unwrap()))
            .collect::<crate::Result<Vec<_>>>()?
    }
    DataType::LargeUtf8 => {
        check_nulls(&*source)?;
        source
            .as_string::<i64>()
            .into_iter()
            .map(|v| encode(v.unwrap()))
            .collect::<crate::Result<Vec<_>>>()?
    }
DataType::Utf8View => {
return Err(crate::Error::Runtime {
message: "Utf8View not yet implemented".to_string(),
})
}
_ => {
return Err(crate::Error::Runtime {
message: "invalid type".to_string(),
})
}
};
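// Stack the per-row token tensors into one batch. Note: stacking requires
// equal-length rows, so variable-length inputs depend on the tokenizer's
// padding configuration.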
let embeddings = Tensor::stack(&tokens, 0)
.and_then(|tokens| {
let token_type_ids = tokens.zeros_like()?;
self.model.forward(&tokens, &token_type_ids)
})
// TODO: it'd be nice to support other devices
.and_then(|tokens| tokens.to_device(&Device::Cpu))
.map_err(|e| crate::Error::Runtime {
message: format!("failed to compute embeddings: {}", e),
})?;
let (_, n_tokens, _) = embeddings.dims3().map_err(|e| crate::Error::Runtime {
message: format!("failed to get embeddings dimensions: {}", e),
})?;
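// Mean pooling: average the token-level embeddings across the sequence
// dimension to get one fixed-size vector per input row.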
let embeddings = embeddings
    .sum(1)
    .and_then(|summed| summed / (n_tokens as f64))
    .map_err(|e| crate::Error::Runtime {
        message: format!("failed to compute mean embeddings: {}", e),
    })?;
let dims = embeddings.shape().dims().len();
let (arr, dtype): (Arc<dyn Array>, DataType) = match dims {
2 => {
let (d1, d2) = embeddings.dims2().map_err(|e| crate::Error::Runtime {
message: format!("failed to get embeddings dimensions: {}", e),
})?;
let (storage, layout) = embeddings.storage_and_layout();
match &*storage {
Storage::Cpu(CpuStorage::U8(data)) => (
    Arc::new(from_cpu_storage::<UInt8Type>(
        data,
        layout,
        &embeddings,
        d1,
        d2,
    )),
    DataType::UInt8,
),
Storage::Cpu(CpuStorage::U32(data)) => (
Arc::new(from_cpu_storage::<UInt32Type>(
data,
layout,
&embeddings,
d1,
d2,
)),
DataType::UInt32,
),
Storage::Cpu(CpuStorage::I64(data)) => (
Arc::new(from_cpu_storage::<Int64Type>(
data,
layout,
&embeddings,
d1,
d2,
)),
DataType::Int64,
),
Storage::Cpu(CpuStorage::F16(data)) => (
Arc::new(from_cpu_storage::<Float16Type>(
data,
layout,
&embeddings,
d1,
d2,
)),
DataType::Float16,
),
Storage::Cpu(CpuStorage::F32(data)) => (
Arc::new(from_cpu_storage::<Float32Type>(
data,
layout,
&embeddings,
d1,
d2,
)),
DataType::Float32,
),
Storage::Cpu(CpuStorage::F64(data)) => (
Arc::new(from_cpu_storage::<Float64Type>(
data,
layout,
&embeddings,
d1,
d2,
)),
DataType::Float64,
),
Storage::Cpu(CpuStorage::BF16(_)) => {
    return Err(crate::Error::Runtime {
        message: "unsupported data type: BF16".to_string(),
    })
}
_ => unreachable!("Only CPU storage currently supported"),
}
}
n_dims => {
    return Err(crate::Error::Runtime {
        message: format!("only 2-dimensional embeddings are supported, got {}", n_dims),
    })
}
};
Ok((arr, dtype))
}
}
impl EmbeddingFunction for SentenceTransformersEmbeddings {
fn name(&self) -> &str {
"sentence-transformers"
}
fn source_type(&self) -> crate::Result<std::borrow::Cow<arrow_schema::DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> crate::Result<std::borrow::Cow<arrow_schema::DataType>> {
let (n_dims, dtype) = self.compute_ndims_and_dtype()?;
Ok(Cow::Owned(DataType::new_fixed_size_list(
dtype,
n_dims as i32,
false,
)))
}
fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> crate::Result<Arc<dyn Array>> {
let len = source.len();
let n_dims = self.ndims()?;
let (inner, dtype) = self.compute_inner(source)?;
let fsl = DataType::new_fixed_size_list(dtype, n_dims as i32, false);
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let array_data = ArrayData::builder(fsl)
.len(len)
.add_child_data(inner.into_data())
.build()?;
Ok(Arc::new(FixedSizeListArray::from(array_data)))
}
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> crate::Result<Arc<dyn Array>> {
let (arr, _) = self.compute_inner(input)?;
Ok(arr)
}
}
fn from_cpu_storage<T: ArrowPrimitiveType>(
buffer: &[T::Native],
layout: &Layout,
embeddings: &Tensor,
dim1: usize,
dim2: usize,
) -> PrimitiveArray<T> {
let mut builder = PrimitiveBuilder::<T>::with_capacity(dim1 * dim2);
match layout.contiguous_offsets() {
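// Fast path: the tensor is contiguous, so copy the slice in one append.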
Some((o1, o2)) => {
let data = &buffer[o1..o2];
builder.append_slice(data);
builder.finish()
}
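// Slow path: non-contiguous layout; walk the strided index element by element.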
None => {
let mut src_index = embeddings.strided_index();
for _idx_row in 0..dim1 {
let row = (0..dim2)
.map(|_| buffer[src_index.next().unwrap()])
.collect::<Vec<_>>();
builder.append_slice(&row);
}
builder.finish()
}
}
}

View File

@@ -125,3 +125,22 @@ impl From<polars::prelude::PolarsError> for Error {
}
}
}
#[cfg(feature = "sentence-transformers")]
impl From<hf_hub::api::sync::ApiError> for Error {
fn from(source: hf_hub::api::sync::ApiError) -> Self {
Self::Other {
message: "Error in Sentence Transformers integration.".to_string(),
source: Some(Box::new(source)),
}
}
}
#[cfg(feature = "sentence-transformers")]
impl From<candle_core::Error> for Error {
fn from(source: candle_core::Error) -> Self {
Self::Other {
message: "Error in 'candle_core'.".to_string(),
source: Some(Box::new(source)),
}
}
}
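With these conversions, Hugging Face Hub failures and candle errors both surface as `Error::Other`, giving callers a single place to inspect them. A minimal sketch of doing so, assuming the variant's fields are publicly matchable as shown above (the `describe` helper is hypothetical):

fn describe(err: &lancedb::Error) -> String {
    match err {
        // hf_hub::api::sync::ApiError and candle_core::Error both land here.
        lancedb::Error::Other {
            message,
            source: Some(source),
        } => format!("{} (caused by: {})", message, source),
        lancedb::Error::Other {
            message,
            source: None,
        } => message.clone(),
        // Fall back to the error's Display impl for every other variant.
        other => other.to_string(),
    }
}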