Files
lancedb/nodejs/src/rerankers.rs
Bert c9f248b058 feat: add hybrid search to node and rust SDKs (#1940)
Support hybrid search in both rust and node SDKs.

- Adds a new rerankers package to rust LanceDB, with the implementation
of the default RRF reranker
- Adds a new hybrid package to lancedb, with some helper methods related
to hybrid search, such as normalizing scores and converting score columns
to rank columns
- Adds the capability for LanceDB VectorQuery to perform hybrid search if it
has both nearest-vector and full-text search parameters.
- Adds wrappers for reranker implementations to nodejs SDK.

Additional rerankers will be added in follow-up PRs

https://github.com/lancedb/lancedb/issues/1921

---
Notes about how the rust rerankers are wrapped for calling from JS:

I wanted to keep the core reranker logic, and the invocation of the
reranker by the query code, in Rust. This aligns with the philosophy of
the new node SDK where it's just a thin wrapper around Rust. However, I
also wanted to have support for users who want to add custom rerankers
written in Javascript.

When we add a reranker to the query from Javascript, it adds a special
Rust reranker that has a callback to the Javascript code (which could
then turn around and call an underlying Rust reranker implementation if
desired). This adds a bit of complexity, but overall I think it moves us
in the right direction of having the majority of the query logic in the
underlying Rust SDK while keeping the option open to support custom
Javascript Rerankers.
2024-12-30 09:03:41 -05:00

148 lines
4.2 KiB
Rust

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
use arrow_array::RecordBatch;
use async_trait::async_trait;
use napi::{
bindgen_prelude::*,
threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
};
use napi_derive::napi;
use lancedb::ipc::batches_to_ipc_file;
use lancedb::rerankers::Reranker as LanceDBReranker;
use lancedb::{error::Error, ipc::ipc_file_to_batches};
use crate::error::NapiErrorExt;
/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
/// This contains references to the callbacks that can be used to invoke the
/// reranking methods on the NodeJS implementation and handles serializing the
/// record batches to Arrow IPC buffers.
#[napi]
pub struct Reranker {
/// Thread-safe handle to the JavaScript callback, which is expected to call
/// the `rerankHybrid` method of some Reranker implementation. It is invoked
/// with a `RerankHybridCallbackArgs` value.
rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
}
#[napi]
impl Reranker {
    /// Builds a `Reranker` from the callbacks supplied by JavaScript.
    ///
    /// The `rerank_hybrid` JS function is converted into a thread-safe
    /// function whose single call argument is the value passed at call time.
    ///
    /// # Panics
    ///
    /// Panics if the thread-safe function cannot be created.
    #[napi]
    pub fn new(callbacks: RerankerCallbacks) -> Self {
        let js_callback = callbacks.rerank_hybrid;
        Self {
            rerank_hybrid: js_callback
                // Queue size 0; forward the call value as the only JS argument.
                .create_threadsafe_function(0, move |call_ctx| Ok(vec![call_ctx.value]))
                .unwrap(),
        }
    }
}
#[async_trait]
impl lancedb::rerankers::Reranker for Reranker {
async fn rerank_hybrid(
&self,
query: &str,
vector_results: RecordBatch,
fts_results: RecordBatch,
) -> lancedb::error::Result<RecordBatch> {
let callback_args = RerankHybridCallbackArgs {
query: query.to_string(),
vec_results: batches_to_ipc_file(&[vector_results])?,
fts_results: batches_to_ipc_file(&[fts_results])?,
};
let promised_buffer: Promise<Buffer> = self
.rerank_hybrid
.call_async(Ok(callback_args))
.await
.map_err(|e| Error::Runtime {
message: format!("napi error status={}, reason={}", e.status, e.reason),
})?;
let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
message: format!("napi error status={}, reason={}", e.status, e.reason),
})?;
let mut reader = ipc_file_to_batches(buffer.to_vec())?;
let result = reader.next().ok_or(Error::Runtime {
message: "reranker result deserialization failed".to_string(),
})??;
return Ok(result);
}
}
/// Debug output is a fixed label; the wrapped JS callback is not printed.
impl std::fmt::Debug for Reranker {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "NodeJSRerankerWrapper")
    }
}
/// Callbacks supplied from JavaScript when constructing a `Reranker`.
#[napi(object)]
pub struct RerankerCallbacks {
/// JavaScript function that `Reranker::new` converts into a thread-safe
/// function and that is later invoked with a `RerankHybridCallbackArgs`.
pub rerank_hybrid: JsFunction,
}
/// Arguments handed to the JavaScript `rerank_hybrid` callback.
#[napi(object)]
pub struct RerankHybridCallbackArgs {
/// The query string passed to the reranker.
pub query: String,
/// Vector search results, serialized with `batches_to_ipc_file`.
pub vec_results: Vec<u8>,
/// Full-text search results, serialized with `batches_to_ipc_file`.
pub fts_results: Vec<u8>,
}
fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
reader
.next()
.ok_or(Error::InvalidInput {
message: "expected buffer containing record batch".to_string(),
})
.default_error()?
.map_err(Error::from)
.default_error()
}
/// Wrapper around the Rust `RRFReranker`, exposed to NodeJS via `#[napi]`.
#[napi]
pub struct RRFReranker {
/// The wrapped LanceDB reranker that the methods of this type delegate to.
inner: lancedb::rerankers::rrf::RRFReranker,
}
#[napi]
impl RRFReranker {
#[napi]
pub async fn try_new(k: &[f32]) -> Result<Self> {
let k = k
.first()
.copied()
.ok_or(Error::InvalidInput {
message: "must supply RRF Reranker constructor arg 'k'".to_string(),
})
.default_error()?;
Ok(Self {
inner: lancedb::rerankers::rrf::RRFReranker::new(k),
})
}
#[napi]
pub async fn rerank_hybrid(
&self,
query: String,
vec_results: Buffer,
fts_results: Buffer,
) -> Result<Buffer> {
let vec_results = buffer_to_record_batch(vec_results)?;
let fts_results = buffer_to_record_batch(fts_results)?;
let result = self
.inner
.rerank_hybrid(&query, vec_results, fts_results)
.await
.unwrap();
let result_buff = batches_to_ipc_file(&[result]).default_error()?;
Ok(Buffer::from(result_buff.as_ref()))
}
}