diff --git a/nodejs/Cargo.toml b/nodejs/Cargo.toml
index cb9304a0..ad5aa67e 100644
--- a/nodejs/Cargo.toml
+++ b/nodejs/Cargo.toml
@@ -12,7 +12,10 @@ categories.workspace = true
 crate-type = ["cdylib"]

 [dependencies]
+async-trait.workspace = true
 arrow-ipc.workspace = true
+arrow-array.workspace = true
+arrow-schema.workspace = true
 env_logger.workspace = true
 futures.workspace = true
 lancedb = { path = "../rust/lancedb", features = ["remote"] }
diff --git a/nodejs/__test__/arrow.test.ts b/nodejs/__test__/arrow.test.ts
index c7cb3a20..4907063d 100644
--- a/nodejs/__test__/arrow.test.ts
+++ b/nodejs/__test__/arrow.test.ts
@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";

 import {
   convertToTable,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   fromTableToBuffer,
   makeArrowTable,
   makeEmptyTable,
@@ -553,5 +555,28 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
        });
      });
    });
+
+    describe("converting record batches to buffers", function () {
+      it("can convert to buffered record batch and back again", async function () {
+        const records = [
+          { text: "dog", vector: [0.1, 0.2] },
+          { text: "cat", vector: [0.3, 0.4] },
+        ];
+        const table = await convertToTable(records);
+        const batch = table.batches[0];
+
+        const buffer = await fromRecordBatchToBuffer(batch);
+        const result = await fromBufferToRecordBatch(buffer);
+
+        expect(JSON.stringify(batch.toArray())).toEqual(
+          JSON.stringify(result?.toArray()),
+        );
+      });
+
+      it("converting from buffer returns null if buffer has no record batches", async function () {
+        const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
+        expect(result).toEqual(null);
+      });
+    });
  },
);
diff --git a/nodejs/__test__/rerankers.test.ts b/nodejs/__test__/rerankers.test.ts
new file mode 100644
index 00000000..6a742ca2
--- /dev/null
+++ b/nodejs/__test__/rerankers.test.ts
@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import * as tmp from "tmp";
+import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
+import { RRFReranker } from "../lancedb/rerankers";
+
+describe("rerankers", function () {
+  let tmpDir: tmp.DirResult;
+  let conn: Connection;
+  let table: Table;
+
+  beforeEach(async () => {
+    tmpDir = tmp.dirSync({ unsafeCleanup: true });
+    conn = await connect(tmpDir.name);
+    table = await conn.createTable("mytable", [
+      { vector: [0.1, 0.1], text: "dog" },
+      { vector: [0.2, 0.2], text: "cat" },
+    ]);
+    await table.createIndex("text", {
+      config: Index.fts(),
+      replace: true,
+    });
+  });
+
+  it("will query with the custom reranker", async function () {
+    const expectedResult = [
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ];
+    class MyCustomReranker {
+      async rerankHybrid(
+        _query: string,
+        _vecResults: RecordBatch,
+        _ftsResults: RecordBatch,
+      ): Promise<RecordBatch> {
+        // no reranker logic, just return some static data
+        const table = makeArrowTable(expectedResult);
+        return table.batches[0];
+      }
+    }
+
+    let result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(new MyCustomReranker())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+
+    result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
+    expect(result).toEqual([
+      {
+        text: "albert",
+        // biome-ignore lint/style/useNamingConvention: this is the lance field name
+        _relevance_score: 0.99,
+      },
+    ]);
+  });
+
+  it("will query with RRFReranker", async function () {
+    // smoke test to verify the TypeScript wrapper around the Rust reranker is wired up correctly
+    const result = await table
+      .query()
+      .nearestTo([0.1, 0.1])
+      .fullTextSearch("dog")
+      .rerank(await RRFReranker.create())
+      .select(["text"])
+      .limit(5)
+      .toArray();
+
+    expect(result).toHaveLength(2);
+  });
+});
diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts
index cd015ca4..6de51ca8 100644
--- a/nodejs/lancedb/arrow.ts
+++ b/nodejs/lancedb/arrow.ts
@@ -27,7 +27,9 @@ import {
   List,
   Null,
   RecordBatch,
+  RecordBatchFileReader,
   RecordBatchFileWriter,
+  RecordBatchReader,
   RecordBatchStreamWriter,
   Schema,
   Struct,
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
   }
 }

+/**
+ * Read a single record batch from a buffer.
+ *
+ * Returns null if the buffer does not contain a record batch.
+ */
+export async function fromBufferToRecordBatch(
+  data: Buffer,
+): Promise<RecordBatch | null> {
+  const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
+    .value;
+  const recordBatch = iter?.next().value;
+  return recordBatch || null;
+}
+
+/**
+ * Create a buffer containing a single record batch.
+ */
+export async function fromRecordBatchToBuffer(
+  batch: RecordBatch,
+): Promise<Buffer> {
+  const writer = new RecordBatchFileWriter().writeAll([batch]);
+  return Buffer.from(await writer.toUint8Array());
+}
+
 /**
  * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
  *
diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts
index 74da915f..e9a0abcb 100644
--- a/nodejs/lancedb/index.ts
+++ b/nodejs/lancedb/index.ts
@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
 export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";

 export * as embedding from "./embedding";
+export * as rerankers from "./rerankers";

 /**
  * Connect to a LanceDB instance at the given URI.
diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts
index 25fabf70..aa4b560f 100644
--- a/nodejs/lancedb/query.ts
+++ b/nodejs/lancedb/query.ts
@@ -16,6 +16,8 @@ import {
   Table as ArrowTable,
   type IntoVector,
   RecordBatch,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   tableFromIPC,
 } from "./arrow";
 import { type IvfPqOptions } from "./indices";
@@ -25,6 +27,7 @@ import {
   Table as NativeTable,
   VectorQuery as NativeVectorQuery,
 } from "./native";
+import { Reranker } from "./rerankers";
 export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
   private promisedInner?: Promise<NativeBatchIterator>;
   private inner?: NativeBatchIterator;
@@ -542,6 +545,27 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
       return this;
     }
   }
+
+  rerank(reranker: Reranker): VectorQuery {
+    super.doCall((inner) =>
+      inner.rerank({
+        rerankHybrid: async (_, args) => {
+          const vecResults = await fromBufferToRecordBatch(args.vecResults);
+          const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
+          const result = await reranker.rerankHybrid(
+            args.query,
+            vecResults as RecordBatch,
+            ftsResults as RecordBatch,
+          );
+
+          const buffer = fromRecordBatchToBuffer(result);
+          return buffer;
+        },
+      }),
+    );
+
+    return this;
+  }
 }

 /** A builder for LanceDB queries.
  */
diff --git a/nodejs/lancedb/rerankers/index.ts b/nodejs/lancedb/rerankers/index.ts
new file mode 100644
index 00000000..653499e3
--- /dev/null
+++ b/nodejs/lancedb/rerankers/index.ts
@@ -0,0 +1,17 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+
+export * from "./rrf";
+
+// Interface for a reranker. A reranker is used to rerank the results from a
+// vector and FTS search. This is useful for combining the results from both
+// search methods.
+export interface Reranker {
+  rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch>;
+}
diff --git a/nodejs/lancedb/rerankers/rrf.ts b/nodejs/lancedb/rerankers/rrf.ts
new file mode 100644
index 00000000..1d89c076
--- /dev/null
+++ b/nodejs/lancedb/rerankers/rrf.ts
@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+import { RecordBatch } from "apache-arrow";
+import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
+import { RrfReranker as NativeRRFReranker } from "../native";
+
+/**
+ * Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
+ *
+ * Internally this uses the Rust implementation.
+ */
+export class RRFReranker {
+  private inner: NativeRRFReranker;
+
+  constructor(inner: NativeRRFReranker) {
+    this.inner = inner;
+  }
+
+  public static async create(k: number = 60) {
+    return new RRFReranker(
+      await NativeRRFReranker.tryNew(new Float32Array([k])),
+    );
+  }
+
+  async rerankHybrid(
+    query: string,
+    vecResults: RecordBatch,
+    ftsResults: RecordBatch,
+  ): Promise<RecordBatch> {
+    const buffer = await this.inner.rerankHybrid(
+      query,
+      await fromRecordBatchToBuffer(vecResults),
+      await fromRecordBatchToBuffer(ftsResults),
+    );
+    const recordBatch = await fromBufferToRecordBatch(buffer);
+
+    return recordBatch as RecordBatch;
+  }
+}
diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs
index d0a02ee4..0ac75b7c 100644
--- a/nodejs/src/lib.rs
+++ b/nodejs/src/lib.rs
@@ -24,6 +24,7 @@ mod iterator;
 pub mod merge;
 mod query;
 pub mod remote;
+mod rerankers;
 mod table;
 mod util;
diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs
index fd8e3b48..321e4052 100644
--- a/nodejs/src/query.rs
+++ b/nodejs/src/query.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+use std::sync::Arc;
+
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::ExecutableQuery;
 use lancedb::query::Query as LanceDbQuery;
@@ -25,6 +27,8 @@ use napi_derive::napi;
 use crate::error::convert_error;
 use crate::error::NapiErrorExt;
 use crate::iterator::RecordBatchIterator;
+use crate::rerankers::Reranker;
+use crate::rerankers::RerankerCallbacks;
 use crate::util::parse_distance_type;

 #[napi]
@@ -218,6 +222,14 @@ impl VectorQuery {
         self.inner = self.inner.clone().with_row_id();
     }

+    #[napi]
+    pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
+        self.inner = self
+            .inner
+            .clone()
+            .rerank(Arc::new(Reranker::new(callbacks)));
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
diff --git a/nodejs/src/rerankers.rs b/nodejs/src/rerankers.rs
new file mode 100644
index 00000000..bf8a4d96
--- /dev/null
+++ b/nodejs/src/rerankers.rs
@@ -0,0 +1,147 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow_array::RecordBatch;
+use async_trait::async_trait;
+use napi::{
+    bindgen_prelude::*,
+    threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
+};
+use napi_derive::napi;
+
+use lancedb::ipc::batches_to_ipc_file;
+use lancedb::rerankers::Reranker as LanceDBReranker;
+use lancedb::{error::Error, ipc::ipc_file_to_batches};
+
+use crate::error::NapiErrorExt;
+
+/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
+/// This holds references to the callbacks that are used to invoke the
+/// reranking methods on the NodeJS implementation, and it handles serializing
+/// the record batches to Arrow IPC buffers.
+#[napi]
+pub struct Reranker {
+    /// callback into JavaScript which will invoke the rerankHybrid method of
+    /// some Reranker implementation
+    rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
+}
+
+#[napi]
+impl Reranker {
+    #[napi]
+    pub fn new(callbacks: RerankerCallbacks) -> Self {
+        let rerank_hybrid = callbacks
+            .rerank_hybrid
+            .create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
+            .unwrap();
+
+        Self { rerank_hybrid }
+    }
+}
+
+#[async_trait]
+impl LanceDBReranker for Reranker {
+    async fn rerank_hybrid(
+        &self,
+        query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> lancedb::error::Result<RecordBatch> {
+        let callback_args = RerankHybridCallbackArgs {
+            query: query.to_string(),
+            vec_results: batches_to_ipc_file(&[vector_results])?,
+            fts_results: batches_to_ipc_file(&[fts_results])?,
+        };
+        let promised_buffer: Promise<Buffer> = self
+            .rerank_hybrid
+            .call_async(Ok(callback_args))
+            .await
+            .map_err(|e| Error::Runtime {
+                message: format!("napi error status={}, reason={}", e.status, e.reason),
+            })?;
+        let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
+            message: format!("napi error status={}, reason={}", e.status, e.reason),
+        })?;
+        let mut reader = ipc_file_to_batches(buffer.to_vec())?;
+        let result = reader.next().ok_or(Error::Runtime {
+            message: "reranker result deserialization failed".to_string(),
+        })??;
+
+        Ok(result)
+    }
+}
+
+impl std::fmt::Debug for Reranker {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("NodeJSRerankerWrapper")
+    }
+}
+
+#[napi(object)]
+pub struct RerankerCallbacks {
+    pub rerank_hybrid: JsFunction,
+}
+
+#[napi(object)]
+pub struct RerankHybridCallbackArgs {
+    pub query: String,
+    pub vec_results: Vec<u8>,
+    pub fts_results: Vec<u8>,
+}
+
+fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
+    let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
+    reader
+        .next()
+        .ok_or(Error::InvalidInput {
+            message: "expected buffer containing record batch".to_string(),
+        })
+        .default_error()?
+        .map_err(Error::from)
+        .default_error()
+}
+
+/// Wrapper around the Rust RRFReranker
+#[napi]
+pub struct RRFReranker {
+    inner: lancedb::rerankers::rrf::RRFReranker,
+}
+
+#[napi]
+impl RRFReranker {
+    #[napi]
+    pub async fn try_new(k: &[f32]) -> Result<Self> {
+        let k = k
+            .first()
+            .copied()
+            .ok_or(Error::InvalidInput {
+                message: "must supply RRF Reranker constructor arg 'k'".to_string(),
+            })
+            .default_error()?;
+
+        Ok(Self {
+            inner: lancedb::rerankers::rrf::RRFReranker::new(k),
+        })
+    }
+
+    #[napi]
+    pub async fn rerank_hybrid(
+        &self,
+        query: String,
+        vec_results: Buffer,
+        fts_results: Buffer,
+    ) -> Result<Buffer> {
+        let vec_results = buffer_to_record_batch(vec_results)?;
+        let fts_results = buffer_to_record_batch(fts_results)?;
+
+        let result = self
+            .inner
+            .rerank_hybrid(&query, vec_results, fts_results)
+            .await
+            .default_error()?;
+
+        let result_buff = batches_to_ipc_file(&[result]).default_error()?;
+
+        Ok(Buffer::from(result_buff.as_ref()))
+    }
+}
diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs
index 8edca5a8..9ff37ccb 100644
--- a/rust/lancedb/src/lib.rs
+++ b/rust/lancedb/src/lib.rs
@@ -214,6 +214,7 @@ mod polars_arrow_convertors;
 pub mod query;
 #[cfg(feature = "remote")]
 pub mod remote;
+pub mod rerankers;
 pub mod table;
 pub mod utils;
diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs
index e9add758..86d39392 100644
--- a/rust/lancedb/src/query.rs
+++ b/rust/lancedb/src/query.rs
@@ -15,19 +15,31 @@
 use std::future::Future;
 use std::sync::Arc;

+use arrow::compute::concat_batches;
 use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
 use arrow_schema::DataType;
 use datafusion_physical_plan::ExecutionPlan;
+use futures::{stream, try_join, FutureExt, TryStreamExt};
 use half::f16;
-use lance::dataset::scanner::DatasetRecordBatchStream;
+use lance::{
+    arrow::RecordBatchExt,
+    dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
+};
 use lance_datafusion::exec::execute_plan;
+use lance_index::scalar::inverted::SCORE_COL;
 use lance_index::scalar::FullTextSearchQuery;
+use lance_index::vector::DIST_COL;
+use lance_io::stream::RecordBatchStreamAdapter;

 use crate::arrow::SendableRecordBatchStream;
 use crate::error::{Error, Result};
+use crate::rerankers::rrf::RRFReranker;
+use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
 use crate::table::TableInternal;
 use crate::DistanceType;

+mod hybrid;
+
 pub(crate) const DEFAULT_TOP_K: usize = 10;

 /// Which columns should be retrieved from the database
@@ -435,6 +447,16 @@ pub trait QueryBase {
     /// Return the `_rowid` meta column from the Table.
     fn with_row_id(self) -> Self;
+
+    /// Rerank the results using the specified reranker.
+    ///
+    /// This is currently only supported for hybrid search.
+    fn rerank(self, reranker: Arc<dyn Reranker>) -> Self;
+
+    /// The method used to normalize the scores. Can be `NormalizeMethod::Rank`
+    /// or `NormalizeMethod::Score`. If `Rank`, the scores are converted to
+    /// ranks and then normalized. If `Score`, the scores are normalized
+    /// directly.
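+    ///
+    /// A sketch of setting the normalization on a hybrid query (hypothetical
+    /// `table` handle; assumes an FTS index exists so the query runs as a
+    /// hybrid search):
+    ///
+    /// ```ignore
+    /// let results = table
+    ///     .query()
+    ///     .full_text_search(FullTextSearchQuery::new("dog".to_string()))
+    ///     .norm(NormalizeMethod::Rank) // convert scores to ranks before normalizing
+    ///     .nearest_to(&[0.1, 0.2])?
+    ///     .execute()
+    ///     .await?
+    ///     .try_collect::<Vec<_>>()
+    ///     .await?;
+    /// ```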
+    fn norm(self, norm: NormalizeMethod) -> Self;
 }

 pub trait HasQuery {
@@ -481,6 +503,16 @@ impl<T: HasQuery> QueryBase for T {
         self.mut_query().with_row_id = true;
         self
     }
+
+    fn rerank(mut self, reranker: Arc<dyn Reranker>) -> Self {
+        self.mut_query().reranker = Some(reranker);
+        self
+    }
+
+    fn norm(mut self, norm: NormalizeMethod) -> Self {
+        self.mut_query().norm = Some(norm);
+        self
+    }
 }

 /// Options for controlling the execution of a query
@@ -600,6 +632,13 @@ pub struct Query {
     /// If set to false, the filter will be applied after the vector search.
     pub(crate) prefilter: bool,
+
+    /// Implementation of a reranker that can be used to reorder or combine
+    /// query results, especially when using hybrid search
+    pub(crate) reranker: Option<Arc<dyn Reranker>>,
+
+    /// Configure how query results are normalized when doing hybrid search
+    pub(crate) norm: Option<NormalizeMethod>,
 }

 impl Query {
@@ -614,6 +653,8 @@ impl Query {
             fast_search: false,
             with_row_id: false,
             prefilter: true,
+            reranker: None,
+            norm: None,
         }
     }

@@ -862,6 +903,65 @@ impl VectorQuery {
         self.use_index = false;
         self
     }
+
+    pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
+        // clone the queries and request row IDs, which may be needed for reranking
+        let fts_query = self.base.clone().with_row_id();
+        let mut vector_query = self.clone().with_row_id();
+
+        vector_query.base.full_text_search = None;
+        let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
+
+        let (fts_results, vec_results) = try_join!(
+            fts_results.try_collect::<Vec<_>>(),
+            vec_results.try_collect::<Vec<_>>()
+        )?;
+
+        // get the schemas to use when combining batches; if either result set
+        // is empty, its schema is synthesized from the other one
+        let (fts_schema, vec_schema) = hybrid::query_schemas(&fts_results, &vec_results);
+
+        // concatenate all the batches together
+        let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
+        let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
+
+        if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
+            vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
+            fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
+        }
+
+        vec_results = hybrid::normalize_scores(vec_results, DIST_COL, None)?;
+        fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
+
+        let reranker = self
+            .base
+            .reranker
+            .clone()
+            .unwrap_or(Arc::new(RRFReranker::default()));
+
+        let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
+            message: "there should be an FTS search".to_string(),
+        })?;
+
+        let mut results = reranker
+            .rerank_hybrid(&fts_query.query, vec_results, fts_results)
+            .await?;
+
+        check_reranker_result(&results)?;
+
+        let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
+        if results.num_rows() > limit {
+            results = results.slice(0, limit);
+        }
+
+        if !self.base.with_row_id {
+            results = results.drop_column(ROW_ID)?;
+        }
+
+        Ok(SendableRecordBatchStream::from(
+            RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
+        ))
+    }
 }

 impl ExecutableQuery for VectorQuery {
@@ -873,6 +973,11 @@ impl ExecutableQuery for VectorQuery {
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
+        if self.base.full_text_search.is_some() {
+            let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
+            return Ok(hybrid_result);
+        }
+
         Ok(SendableRecordBatchStream::from(
             DatasetRecordBatchStream::new(execute_plan(
                 self.create_plan(options).await?,
@@ -894,20 +999,20 @@ impl HasQuery for VectorQuery {

 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
+    use std::{collections::HashSet, sync::Arc};

     use super::*;
-    use arrow::{compute::concat_batches, datatypes::Int32Type};
+    use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
     use arrow_array::{
-        cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
-        RecordBatchReader,
+        cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
+        RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
     };
     use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
     use futures::{StreamExt, TryStreamExt};
     use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
     use tempfile::tempdir;

-    use crate::{connect, Table};
+    use crate::{connect, connection::CreateTableMode, Table};

     #[tokio::test]
     async fn test_setters_getters() {
@@ -1274,4 +1379,156 @@ mod tests {
         assert!(query_index.values().contains(&0));
         assert!(query_index.values().contains(&1));
     }
+
+    #[tokio::test]
+    async fn test_hybrid_search() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        let text = StringArray::from(vec!["dog", "cat", "a", "b"]);
+        let vectors = vec![
+            Some(vec![Some(0.0), Some(0.0)]),
+            Some(vec![Some(-2.0), Some(-2.0)]),
+            Some(vec![Some(50.0), Some(50.0)]),
+            Some(vec![Some(-30.0), Some(-30.0)]),
+        ];
+        let vector = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims);
+
+        let record_batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap();
+        let record_batch_iter =
+            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
+        let table = conn
+            .create_table("my_table", record_batch_iter)
+            .execute()
+            .await
+            .unwrap();
+
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+
+        let fts_query = FullTextSearchQuery::new("b".to_string());
+        let results = table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let batch = &results[0];
+
+        let texts: StringArray = downcast_array(batch.column_by_name("text").unwrap());
+        let texts = texts.iter().map(|e| e.unwrap()).collect::<HashSet<_>>();
+        assert!(texts.contains("cat")); // should be close by vector search
+        assert!(texts.contains("b")); // should be close by fts search
+
+        // ensure that this works correctly if there are no matching FTS results
+        let fts_query = FullTextSearchQuery::new("z".to_string());
+        table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_hybrid_search_empty_table() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        // ensure hybrid search is also supported on a fully empty table
+        let vectors: Vec<Option<Vec<Option<f32>>>> = Vec::new();
+        let record_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(Vec::<&str>::new())),
+                Arc::new(
+                    FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims),
+                ),
+            ],
+        )
+        .unwrap();
+        let record_batch_iter =
+            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
+        let table = conn
+            .create_table("my_table", record_batch_iter)
+            .mode(CreateTableMode::Overwrite)
+            .execute()
+            .await
+            .unwrap();
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+        let fts_query = FullTextSearchQuery::new("b".to_string());
+        let results = table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        let batch = &results[0];
+        assert_eq!(0, batch.num_rows());
+        assert_eq!(2, batch.num_columns());
+    }
 }
diff --git a/rust/lancedb/src/query/hybrid.rs b/rust/lancedb/src/query/hybrid.rs
new file mode 100644
index 00000000..9e85e870
--- /dev/null
+++ b/rust/lancedb/src/query/hybrid.rs
@@ -0,0 +1,346 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use arrow::compute::{
+    kernels::numeric::{div, sub},
+    max, min,
+};
+use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
+use arrow_schema::{DataType, Field, Schema, SortOptions};
+use lance::dataset::ROW_ID;
+use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
+use std::sync::Arc;
+
+use crate::error::{Error, Result};
+
+/// Converts a result's score column to ranks.
+///
+/// Expects the `column` argument to be of type Float32 and will panic if it's not
+pub fn rank(results: RecordBatch, column: &str, ascending: Option<bool>) -> Result<RecordBatch> {
+    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
+        message: format!(
+            "expected column {} not found in rank. found columns {:?}",
+            column,
+            results
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.name())
+                .collect::<Vec<_>>(),
+        ),
+    })?;
+
+    if results.num_rows() == 0 {
+        return Ok(results);
+    }
+
+    let scores: Float32Array = downcast_array(scores);
+    let ranks = Float32Array::from_iter_values(
+        arrow::compute::kernels::rank::rank(
+            &scores,
+            Some(SortOptions {
+                descending: !ascending.unwrap_or(true),
+                ..Default::default()
+            }),
+        )?
+        .iter()
+        .map(|i| *i as f32),
+    );
+
+    let schema = results.schema();
+    let (column_idx, _) = schema.column_with_name(column).unwrap();
+    let mut columns = results.columns().to_vec();
+    columns[column_idx] = Arc::new(ranks);
+
+    let results = RecordBatch::try_new(results.schema(), columns)?;
+
+    Ok(results)
+}
+
+/// Get the query schemas needed when combining the search results.
+///
+/// If either of the record batches is empty, we create its schema from the
+/// other record batch, replacing the score/distance column. If both record
+/// batches are empty, we create empty schemas.
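+///
+/// A sketch of the empty-FTS case (not a doctest; assumes `vec_results`
+/// holds at least one batch with the usual `_distance` column):
+///
+/// ```ignore
+/// let (fts_schema, vec_schema) = query_schemas(&[], &vec_results);
+/// // the synthesized FTS schema mirrors the vector schema, with the
+/// // DIST_COL (`_distance`) field renamed to SCORE_COL (`_score`)
+/// assert!(fts_schema.column_with_name(SCORE_COL).is_some());
+/// assert!(fts_schema.column_with_name(DIST_COL).is_none());
+/// ```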
+pub fn query_schemas(
+    fts_results: &[RecordBatch],
+    vec_results: &[RecordBatch],
+) -> (Arc<Schema>, Arc<Schema>) {
+    let (fts_schema, vec_schema) = match (
+        fts_results.first().map(|r| r.schema()),
+        vec_results.first().map(|r| r.schema()),
+    ) {
+        (Some(fts_schema), Some(vec_schema)) => (fts_schema, vec_schema),
+        (None, Some(vec_schema)) => {
+            let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
+            (Arc::new(fts_schema), vec_schema)
+        }
+        (Some(fts_schema), None) => {
+            let vec_schema = with_field_name_replaced(&fts_schema, SCORE_COL, DIST_COL);
+            (fts_schema, Arc::new(vec_schema))
+        }
+        (None, None) => (Arc::new(empty_fts_schema()), Arc::new(empty_vec_schema())),
+    };
+
+    (fts_schema, vec_schema)
+}
+
+pub fn empty_fts_schema() -> Schema {
+    Schema::new(vec![
+        Arc::new(Field::new(SCORE_COL, DataType::Float32, false)),
+        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
+    ])
+}
+
+pub fn empty_vec_schema() -> Schema {
+    Schema::new(vec![
+        Arc::new(Field::new(DIST_COL, DataType::Float32, false)),
+        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
+    ])
+}
+
+pub fn with_field_name_replaced(schema: &Schema, target: &str, replacement: &str) -> Schema {
+    let field_idx = schema.fields().iter().enumerate().find_map(|(i, field)| {
+        if field.name() == target {
+            Some(i)
+        } else {
+            None
+        }
+    });
+
+    let mut fields = schema.fields().to_vec();
+    if let Some(idx) = field_idx {
+        let new_field = (*fields[idx]).clone().with_name(replacement);
+        fields[idx] = Arc::new(new_field);
+    }
+
+    Schema::new(fields)
+}
+
+/// Normalize the scores column to have values between 0 and 1.
+///
+/// Expects the `column` argument to be of type Float32 and will panic if it's not
+pub fn normalize_scores(
+    results: RecordBatch,
+    column: &str,
+    invert: Option<bool>,
+) -> Result<RecordBatch> {
+    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
+        message: format!(
+            "expected column {} not found in normalize_scores. found columns {:?}",
+            column,
+            results
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.name())
+                .collect::<Vec<_>>(),
+        ),
+    })?;
+
+    if results.num_rows() == 0 {
+        return Ok(results);
+    }
+    let mut scores: Float32Array = downcast_array(scores);
+
+    let max = max(&scores).unwrap_or(0.0);
+    let min = min(&scores).unwrap_or(0.0);
+
+    // treat max and min as equal when they are very close; this mirrors
+    // np.isclose, which the Python implementation uses
+    let rng = if max - min < 10e-5 { max } else { max - min };
+
+    // if rng is 0, then min and max are both 0 so we just leave the scores as is
+    if rng != 0.0 {
+        let tmp = div(
+            &sub(&scores, &Float32Array::new_scalar(min))?,
+            &Float32Array::new_scalar(rng),
+        )?;
+        scores = downcast_array(&tmp);
+    }
+
+    if invert.unwrap_or(false) {
+        let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?;
+        scores = downcast_array(&tmp);
+    }
+
+    let schema = results.schema();
+    let (column_idx, _) = schema.column_with_name(column).unwrap();
+    let mut columns = results.columns().to_vec();
+    columns[column_idx] = Arc::new(scores);
+
+    let results = RecordBatch::try_new(results.schema(), columns)?;
+
+    Ok(results)
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use arrow_array::StringArray;
+    use arrow_schema::{DataType, Field, Schema};
+
+    #[test]
+    fn test_rank() {
+        let schema = Arc::new(Schema::new(vec![
+            Arc::new(Field::new("name", DataType::Utf8, false)),
+            Arc::new(Field::new("score", DataType::Float32, false)),
+        ]));
+
+        let names = StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]);
+        let scores = Float32Array::from(vec![0.2, 0.4, 0.1, 0.6, 0.45]);
+
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(names), Arc::new(scores)]).unwrap();
+
+        let result = rank(batch.clone(), "score", Some(false)).unwrap();
+        assert_eq!(2, result.schema().fields().len());
+        assert_eq!("name", result.schema().field(0).name());
+        assert_eq!("score", result.schema().field(1).name());
+
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["foo", "bar", "baz", "bean", "dog"]
+        );
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![4.0, 3.0, 5.0, 1.0, 2.0]
+        );
+
+        // check sort ascending
+        let result = rank(batch.clone(), "score", Some(true)).unwrap();
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["foo", "bar", "baz", "bean", "dog"]
+        );
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![2.0, 3.0, 1.0, 5.0, 4.0]
+        );
+
+        // ensure default sort is ascending
+        let result = rank(batch.clone(), "score", None).unwrap();
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["foo", "bar", "baz", "bean", "dog"]
+        );
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![2.0, 3.0, 1.0, 5.0, 4.0]
+        );
+
+        // check it can handle an empty batch
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(Vec::<&str>::new())),
+                Arc::new(Float32Array::from(Vec::<f32>::new())),
+            ],
+        )
+        .unwrap();
+        let result = rank(batch.clone(), "score", None).unwrap();
+        assert_eq!(0, result.num_rows());
+        assert_eq!(2, result.schema().fields().len());
+        assert_eq!("name", result.schema().field(0).name());
+        assert_eq!("score", result.schema().field(1).name());
+
+        // check it returns the expected error when there's no column
+        let result = rank(batch.clone(), "bad_col", None);
+        match result {
+            Err(Error::InvalidInput { message }) => {
+                assert_eq!("expected column bad_col not found in rank. found columns [\"name\", \"score\"]", message);
+            }
+            _ => {
+                panic!("expected invalid input error, received {:?}", result)
+            }
+        }
+    }
+
+    #[test]
+    fn test_normalize_scores() {
+        let schema = Arc::new(Schema::new(vec![
+            Arc::new(Field::new("name", DataType::Utf8, false)),
+            Arc::new(Field::new("score", DataType::Float32, false)),
+        ]));
+
+        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
+        let scores = Arc::new(Float32Array::from(vec![-4.0, 2.0, 0.0, 3.0, 6.0]));
+
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
+
+        let result = normalize_scores(batch.clone(), "score", Some(false)).unwrap();
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["foo", "bar", "baz", "bean", "dog"]
+        );
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![0.0, 0.6, 0.4, 0.7, 1.0]
+        );
+
+        // check it can invert the normalization
+        let result = normalize_scores(batch.clone(), "score", Some(true)).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![1.0, 1.0 - 0.6, 0.6, 0.3, 0.0]
+        );
+
+        // check that the default is not inverted
+        let result = normalize_scores(batch.clone(), "score", None).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![0.0, 0.6, 0.4, 0.7, 1.0]
+        );
+
+        // check that it will function correctly if all the values are the same
+        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
+        let scores = Arc::new(Float32Array::from(vec![2.1, 2.1, 2.1, 2.1, 2.1]));
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
+        let result = normalize_scores(batch.clone(), "score", None).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![0.0, 0.0, 0.0, 0.0, 0.0]
+        );
+
+        // check that floating point rounding keeps near-identical scores
+        // normalized the same way, consistent with the Python behaviour
+        let scores = Arc::new(Float32Array::from(vec![1.0, 1.0, 1.0, 1.0, 0.9999999]));
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
+        let result = normalize_scores(batch.clone(), "score", None).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![
+                1.0 - 0.9999999,
+                1.0 - 0.9999999,
+                1.0 - 0.9999999,
+                1.0 - 0.9999999,
+                0.0
+            ]
+        );
+
+        // check that it can handle if all the scores are 0
+        let scores = Arc::new(Float32Array::from(vec![0.0, 0.0, 0.0, 0.0, 0.0]));
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
+        let result = normalize_scores(batch.clone(), "score", None).unwrap();
+        let scores: Float32Array = downcast_array(result.column(1));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![0.0, 0.0, 0.0, 0.0, 0.0]
+        );
+    }
+}
diff --git a/rust/lancedb/src/rerankers.rs b/rust/lancedb/src/rerankers.rs
new file mode 100644
index 00000000..338230c0
--- /dev/null
+++ b/rust/lancedb/src/rerankers.rs
@@ -0,0 +1,87 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::collections::BTreeSet;
+
+use arrow::{
+    array::downcast_array,
+    compute::{concat_batches, filter_record_batch},
+};
+use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
+use async_trait::async_trait;
+use lance::dataset::ROW_ID;
+
+use crate::error::{Error, Result};
+
+pub mod rrf;
+
+/// column name for the reranker relevance score
+const RELEVANCE_SCORE: &str = "_relevance_score";
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum NormalizeMethod {
+    Score,
+    Rank,
+}
+
+/// Interface for a reranker. A reranker is used to rerank the results from a
+/// vector and FTS search. This is useful for combining the results from both
+/// search methods.
+#[async_trait]
+pub trait Reranker: std::fmt::Debug + Sync + Send {
+    // TODO support vector reranking and FTS reranking. Currently only hybrid reranking is supported.
+
+    /// The rerank function receives the individual result sets from the vector
+    /// and FTS searches. You can choose to use either or both of them to
+    /// generate the final results, allowing maximum flexibility.
+    async fn rerank_hybrid(
+        &self,
+        query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> Result<RecordBatch>;
+
+    /// Concatenate the vector and FTS results, dropping duplicate rows by
+    /// `_rowid` (the first occurrence wins)
+    fn merge_results(
+        &self,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> Result<RecordBatch> {
+        let combined = concat_batches(&fts_results.schema(), [vector_results, fts_results].iter())?;
+
+        let mut mask = BooleanArray::builder(combined.num_rows());
+        let mut unique_ids = BTreeSet::new();
+        let row_ids = combined.column_by_name(ROW_ID).ok_or(Error::InvalidInput {
+            message: format!(
+                "could not find expected column {} while merging results. found columns {:?}",
+                ROW_ID,
+                combined
+                    .schema()
+                    .fields()
+                    .iter()
+                    .map(|f| f.name())
+                    .collect::<Vec<_>>()
+            ),
+        })?;
+        let row_ids: UInt64Array = downcast_array(row_ids);
+        row_ids.values().iter().for_each(|id| {
+            mask.append_value(unique_ids.insert(id));
+        });
+
+        let combined = filter_record_batch(&combined, &mask.finish())?;
+
+        Ok(combined)
+    }
+}
+
+pub fn check_reranker_result(result: &RecordBatch) -> Result<()> {
+    if result.schema().column_with_name(RELEVANCE_SCORE).is_none() {
+        return Err(Error::Schema {
+            message: format!(
+                "rerank_hybrid must return a RecordBatch with a column named {}",
+                RELEVANCE_SCORE
+            ),
+        });
+    }
+
+    Ok(())
+}
diff --git a/rust/lancedb/src/rerankers/rrf.rs b/rust/lancedb/src/rerankers/rrf.rs
new file mode 100644
index 00000000..7b5d8361
--- /dev/null
+++ b/rust/lancedb/src/rerankers/rrf.rs
@@ -0,0 +1,223 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
+use std::collections::BTreeMap;
+use std::sync::Arc;
+
+use arrow::{
+    array::downcast_array,
+    compute::{sort_to_indices, take},
+};
+use arrow_array::{Float32Array, RecordBatch, UInt64Array};
+use arrow_schema::{DataType, Field, Schema, SortOptions};
+use async_trait::async_trait;
+use lance::dataset::ROW_ID;
+
+use crate::error::{Error, Result};
+use crate::rerankers::{Reranker, RELEVANCE_SCORE};
+
+/// Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm based
+/// on the scores of vector and FTS search.
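+///
+/// Each row's relevance score is the sum of `1 / (k + rank)` over the result
+/// sets in which the row appears, where `rank` is the row's 0-based position
+/// in that result set. For example, with `k = 1`, a row ranked first by both
+/// searches scores `1/1 + 1/1 = 2.0`, while a row ranked first by vector
+/// search alone scores `1.0`, so rows found by both searches rise to the top.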
+///
+#[derive(Debug)]
+pub struct RRFReranker {
+    k: f32,
+}
+
+impl RRFReranker {
+    /// Create a new RRFReranker
+    ///
+    /// The parameter `k` is a constant used in the RRF formula (default is 60).
+    /// Experiments indicate that k = 60 is near-optimal, but that the choice
+    /// is not critical. See the paper:
+    /// <https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf>
+    pub fn new(k: f32) -> Self {
+        Self { k }
+    }
+}
+
+impl Default for RRFReranker {
+    fn default() -> Self {
+        Self { k: 60.0 }
+    }
+}
+
+#[async_trait]
+impl Reranker for RRFReranker {
+    async fn rerank_hybrid(
+        &self,
+        _query: &str,
+        vector_results: RecordBatch,
+        fts_results: RecordBatch,
+    ) -> Result<RecordBatch> {
+        let vector_ids = vector_results
+            .column_by_name(ROW_ID)
+            .ok_or(Error::InvalidInput {
+                message: format!(
+                    "expected column {} not found in vector_results. found columns {:?}",
+                    ROW_ID,
+                    vector_results
+                        .schema()
+                        .fields()
+                        .iter()
+                        .map(|f| f.name())
+                        .collect::<Vec<_>>()
+                ),
+            })?;
+        let fts_ids = fts_results
+            .column_by_name(ROW_ID)
+            .ok_or(Error::InvalidInput {
+                message: format!(
+                    "expected column {} not found in fts_results. found columns {:?}",
+                    ROW_ID,
+                    fts_results
+                        .schema()
+                        .fields()
+                        .iter()
+                        .map(|f| f.name())
+                        .collect::<Vec<_>>()
+                ),
+            })?;
+
+        let vector_ids: UInt64Array = downcast_array(&vector_ids);
+        let fts_ids: UInt64Array = downcast_array(&fts_ids);
+
+        let mut rrf_score_map = BTreeMap::new();
+        let mut update_score_map = |(i, result_id)| {
+            let score = 1.0 / (i as f32 + self.k);
+            rrf_score_map
+                .entry(result_id)
+                .and_modify(|e| *e += score)
+                .or_insert(score);
+        };
+        vector_ids
+            .values()
+            .iter()
+            .enumerate()
+            .for_each(&mut update_score_map);
+        fts_ids
+            .values()
+            .iter()
+            .enumerate()
+            .for_each(&mut update_score_map);
+
+        let combined_results = self.merge_results(vector_results, fts_results)?;
+
+        let combined_row_ids: UInt64Array =
+            downcast_array(combined_results.column_by_name(ROW_ID).unwrap());
+        let relevance_scores = Float32Array::from_iter_values(
+            combined_row_ids
+                .values()
+                .iter()
+                .map(|row_id| rrf_score_map.get(row_id).unwrap())
+                .copied(),
+        );
+
+        // keep track of indices sorted by the relevance column
+        let sort_indices = sort_to_indices(
+            &relevance_scores,
+            Some(SortOptions {
+                descending: true,
+                ..Default::default()
+            }),
+            None,
+        )
+        .unwrap();
+
+        // add relevance scores to columns
+        let mut columns = combined_results.columns().to_vec();
+        columns.push(Arc::new(relevance_scores));
+
+        // sort by the relevance scores
+        let columns = columns
+            .iter()
+            .map(|c| take(c, &sort_indices, None).unwrap())
+            .collect();
+
+        // add relevance score to schema
+        let mut fields = combined_results.schema().fields().to_vec();
+        fields.push(Arc::new(Field::new(
+            RELEVANCE_SCORE,
+            DataType::Float32,
+            false,
+        )));
+        let schema = Schema::new(fields);
+
+        let combined_results = RecordBatch::try_new(Arc::new(schema), columns)?;
+
+        Ok(combined_results)
+    }
+}
+
+#[cfg(test)]
+pub mod test {
+    use super::*;
+    use arrow_array::StringArray;
+
+    #[tokio::test]
+    async fn test_rrf_reranker() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new(ROW_ID, DataType::UInt64, false),
+        ]));
+
+        let vec_results = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"])),
+                Arc::new(UInt64Array::from(vec![1, 4, 2, 5, 3])),
+            ],
+        )
+        .unwrap();
+
+        let fts_results = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(vec!["bar", "bean", "dog"])),
+                Arc::new(UInt64Array::from(vec![4, 5, 3])),
+            ],
+        )
+        .unwrap();
+
+        // scores should be calculated as:
+        // - foo  = 1/1       = 1.0
+        // - bar  = 1/2 + 1/1 = 1.5
+        // - baz  = 1/3       = 0.333
+        // - bean = 1/4 + 1/2 = 0.75
+        // - dog  = 1/5 + 1/3 = 0.533
+        // then we should get the results ranked in descending order
+
+        let reranker = RRFReranker::new(1.0);
+
+        let result = reranker
+            .rerank_hybrid("", vec_results, fts_results)
+            .await
+            .unwrap();
+
+        assert_eq!(3, result.schema().fields().len());
+        assert_eq!("name", result.schema().fields().first().unwrap().name());
+        assert_eq!(ROW_ID, result.schema().fields().get(1).unwrap().name());
+        assert_eq!(
+            RELEVANCE_SCORE,
+            result.schema().fields().get(2).unwrap().name()
+        );
+
+        let names: StringArray = downcast_array(result.column(0));
+        assert_eq!(
+            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec!["bar", "foo", "bean", "dog", "baz"]
+        );
+
+        let ids: UInt64Array = downcast_array(result.column(1));
+        assert_eq!(
+            ids.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![4, 1, 5, 3, 2]
+        );
+
+        let scores: Float32Array = downcast_array(result.column(2));
+        assert_eq!(
+            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
+            vec![1.5, 1.0, 0.75, 1.0 / 5.0 + 1.0 / 3.0, 1.0 / 3.0]
+        );
+    }
+}