From c9f248b0588ad0a2492365bd74149662bdd45fb3 Mon Sep 17 00:00:00 2001
From: Bert
Date: Mon, 30 Dec 2024 09:03:41 -0500
Subject: [PATCH] feat: add hybrid search to node and rust SDKs (#1940)

Support hybrid search in both the Rust and Node SDKs.

- Adds a new rerankers package to the Rust LanceDB crate, with the implementation of the default RRF reranker
- Adds a new hybrid package to lancedb, with helper methods related to hybrid search, such as normalizing scores and converting score columns to rank columns
- Adds the capability for the LanceDB VectorQuery to perform hybrid search when it has both nearest-vector and full-text-search parameters
- Adds wrappers for reranker implementations to the nodejs SDK. Additional rerankers will be added in follow-up PRs

https://github.com/lancedb/lancedb/issues/1921

---

Notes about how the Rust rerankers are wrapped for calling from JS:

I wanted to keep the core reranker logic, and the invocation of the reranker by the query code, in Rust. This aligns with the philosophy of the new Node SDK, where it is just a thin wrapper around Rust. However, I also wanted to support users who want to add custom rerankers written in JavaScript.

When we add a reranker to the query from JavaScript, it adds a special Rust reranker that has a callback into the JavaScript code (which could then turn around and call an underlying Rust reranker implementation if desired). This adds a bit of complexity, but overall I think it moves us in the right direction of keeping the majority of the query logic in the underlying Rust SDK while leaving the option open to support custom JavaScript rerankers.

---
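For reviewers who want to try this end to end, here is a minimal sketch of the new Node API. The connection path, table name, and data are illustrative; the call chain mirrors the tests in this PR:

```ts
import { connect, Index, rerankers } from "@lancedb/lancedb";

async function main() {
  // Illustrative setup: a tiny table with a vector column and a text column.
  const db = await connect("/tmp/lancedb-hybrid-demo");
  const table = await db.createTable("docs", [
    { vector: [0.1, 0.1], text: "dog" },
    { vector: [0.2, 0.2], text: "cat" },
  ]);
  // Hybrid search needs an FTS index on the text column.
  await table.createIndex("text", { config: Index.fts() });

  // A query with both nearestTo() and fullTextSearch() runs as a hybrid
  // search; the reranker fuses the two result sets and adds a
  // `_relevance_score` column.
  const results = await table
    .query()
    .nearestTo([0.1, 0.1])
    .fullTextSearch("dog")
    .rerank(await rerankers.RRFReranker.create())
    .select(["text"])
    .limit(5)
    .toArray();

  console.log(results);
}

main().catch(console.error);
```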
 nodejs/Cargo.toml                 |   3 +
 nodejs/__test__/arrow.test.ts     |  25 +++
 nodejs/__test__/rerankers.test.ts |  79 +++++++
 nodejs/lancedb/arrow.ts           |  26 +++
 nodejs/lancedb/index.ts           |   1 +
 nodejs/lancedb/query.ts           |  24 +++
 nodejs/lancedb/rerankers/index.ts |  17 ++
 nodejs/lancedb/rerankers/rrf.ts   |  40 ++++
 nodejs/src/lib.rs                 |   1 +
 nodejs/src/query.rs               |  12 ++
 nodejs/src/rerankers.rs           | 147 +++++++++++++
 rust/lancedb/src/lib.rs           |   1 +
 rust/lancedb/src/query.rs         | 269 ++++++++++++++++++++++-
 rust/lancedb/src/query/hybrid.rs  | 346 ++++++++++++++++++++++++++++++
 rust/lancedb/src/rerankers.rs     |  87 ++++++++
 rust/lancedb/src/rerankers/rrf.rs | 223 +++++++++++++++++++
 16 files changed, 1295 insertions(+), 6 deletions(-)
 create mode 100644 nodejs/__test__/rerankers.test.ts
 create mode 100644 nodejs/lancedb/rerankers/index.ts
 create mode 100644 nodejs/lancedb/rerankers/rrf.ts
 create mode 100644 nodejs/src/rerankers.rs
 create mode 100644 rust/lancedb/src/query/hybrid.rs
 create mode 100644 rust/lancedb/src/rerankers.rs
 create mode 100644 rust/lancedb/src/rerankers/rrf.rs

diff --git a/nodejs/Cargo.toml b/nodejs/Cargo.toml
index cb9304a0..ad5aa67e 100644
--- a/nodejs/Cargo.toml
+++ b/nodejs/Cargo.toml
@@ -12,7 +12,10 @@ categories.workspace = true
 crate-type = ["cdylib"]

 [dependencies]
+async-trait.workspace = true
 arrow-ipc.workspace = true
+arrow-array.workspace = true
+arrow-schema.workspace = true
 env_logger.workspace = true
 futures.workspace = true
 lancedb = { path = "../rust/lancedb", features = ["remote"] }
diff --git a/nodejs/__test__/arrow.test.ts b/nodejs/__test__/arrow.test.ts
index c7cb3a20..4907063d 100644
--- a/nodejs/__test__/arrow.test.ts
+++ b/nodejs/__test__/arrow.test.ts
@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";
 import {
   convertToTable,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   fromTableToBuffer,
   makeArrowTable,
   makeEmptyTable,
@@ -553,5 +555,28 @@ describe.each([arrow15, arrow16,
 arrow17, arrow18])(
       });
     });
   });
+
+  describe("converting record batches to buffers", function () {
+    it("can convert to buffered record batch and back again", async function () {
+      const records = [
+        { text: "dog", vector: [0.1, 0.2] },
+        { text: "cat", vector: [0.3, 0.4] },
+      ];
+      const table = await convertToTable(records);
+      const batch = table.batches[0];
+
+      const buffer = await fromRecordBatchToBuffer(batch);
+      const result = await fromBufferToRecordBatch(buffer);
+
+      expect(JSON.stringify(batch.toArray())).toEqual(
+        JSON.stringify(result?.toArray()),
+      );
+    });
+
+    it("converting from buffer returns null if buffer has no record batches", async function () {
+      const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
+      expect(result).toEqual(null);
+    });
+  });
 },
 );
diff --git a/nodejs/__test__/rerankers.test.ts b/nodejs/__test__/rerankers.test.ts
new file mode 100644
index 00000000..6a742ca2
--- /dev/null
+++ b/nodejs/__test__/rerankers.test.ts
@@ -0,0 +1,79 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

import { RecordBatch } from "apache-arrow";
import * as tmp from "tmp";
import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
import { RRFReranker } from "../lancedb/rerankers";

describe("rerankers", function () {
  let tmpDir: tmp.DirResult;
  let conn: Connection;
  let table: Table;

  beforeEach(async () => {
    tmpDir = tmp.dirSync({ unsafeCleanup: true });
    conn = await connect(tmpDir.name);
    table = await conn.createTable("mytable", [
      { vector: [0.1, 0.1], text: "dog" },
      { vector: [0.2, 0.2], text: "cat" },
    ]);
    await table.createIndex("text", {
      config: Index.fts(),
      replace: true,
    });
  });

  it("will query with the custom reranker", async function () {
    const expectedResult = [
      {
        text: "albert",
        // biome-ignore lint/style/useNamingConvention: this is the lance field name
        _relevance_score: 0.99,
      },
    ];
    class MyCustomReranker {
      async rerankHybrid(
        _query: string,
        _vecResults: RecordBatch,
        _ftsResults: RecordBatch,
      ): Promise<RecordBatch> {
        // no reranker logic, just return some static data
        const table = makeArrowTable(expectedResult);
        return table.batches[0];
      }
    }

    let result = await table
      .query()
      .nearestTo([0.1, 0.1])
      .fullTextSearch("dog")
      .rerank(new MyCustomReranker())
      .select(["text"])
      .limit(5)
      .toArray();

    result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
    expect(result).toEqual(expectedResult);
  });

  it("will query with RRFReranker", async function () {
    // smoke test to see if the Rust-wrapping TypeScript is wired up correctly
    const result = await table
      .query()
      .nearestTo([0.1, 0.1])
      .fullTextSearch("dog")
      .rerank(await RRFReranker.create())
      .select(["text"])
      .limit(5)
      .toArray();

    expect(result).toHaveLength(2);
  });
});
diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts
index cd015ca4..6de51ca8 100644
--- a/nodejs/lancedb/arrow.ts
+++ b/nodejs/lancedb/arrow.ts
@@ -27,7 +27,9 @@ import {
   List,
   Null,
   RecordBatch,
+  RecordBatchFileReader,
   RecordBatchFileWriter,
+  RecordBatchReader,
   RecordBatchStreamWriter,
   Schema,
   Struct,
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
   }
 }

+/**
+ * Read a single record batch from a buffer.
+ *
+ * Returns null if the buffer does not contain a record batch.
+ */
+export async function fromBufferToRecordBatch(
+  data: Buffer,
+): Promise<RecordBatch | null> {
+  const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
+    .value;
+  const recordBatch = iter?.next().value;
+  return recordBatch || null;
+}
+
+/**
+ * Create a buffer containing a single record batch.
+ */
+export async function fromRecordBatchToBuffer(
+  batch: RecordBatch,
+): Promise<Buffer> {
+  const writer = new RecordBatchFileWriter().writeAll([batch]);
+  return Buffer.from(await writer.toUint8Array());
+}
+
 /**
  * Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
  *
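The intended round trip for these two helpers, sketched from the test above (all names come from this module):

```ts
import {
  convertToTable,
  fromBufferToRecordBatch,
  fromRecordBatchToBuffer,
} from "./arrow";

async function roundTrip(): Promise<void> {
  const table = await convertToTable([{ text: "dog", vector: [0.1, 0.2] }]);
  const batch = table.batches[0];

  // Serialize the batch into an Arrow IPC file buffer...
  const buffer = await fromRecordBatchToBuffer(batch);
  // ...and read it back; the helper yields null if no batch is present.
  const restored = await fromBufferToRecordBatch(buffer);

  console.log(restored?.numRows); // 1
}
```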
diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts
index 74da915f..e9a0abcb 100644
--- a/nodejs/lancedb/index.ts
+++ b/nodejs/lancedb/index.ts
@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
 export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";

 export * as embedding from "./embedding";
+export * as rerankers from "./rerankers";

 /**
  * Connect to a LanceDB instance at the given URI.
diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts
index 25fabf70..aa4b560f 100644
--- a/nodejs/lancedb/query.ts
+++ b/nodejs/lancedb/query.ts
@@ -16,6 +16,8 @@ import {
   Table as ArrowTable,
   type IntoVector,
   RecordBatch,
+  fromBufferToRecordBatch,
+  fromRecordBatchToBuffer,
   tableFromIPC,
 } from "./arrow";
 import { type IvfPqOptions } from "./indices";
@@ -25,6 +27,7 @@ import {
   Table as NativeTable,
   VectorQuery as NativeVectorQuery,
 } from "./native";
+import { Reranker } from "./rerankers";

 export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
   private promisedInner?: Promise<NativeBatchIterator>;
   private inner?: NativeBatchIterator;
@@ -542,6 +545,27 @@
       return this;
     }
   }
+
+  rerank(reranker: Reranker): VectorQuery {
+    super.doCall((inner) =>
+      inner.rerank({
+        rerankHybrid: async (_, args) => {
+          const vecResults = await fromBufferToRecordBatch(args.vecResults);
+          const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
+          const result = await reranker.rerankHybrid(
+            args.query,
+            vecResults as RecordBatch,
+            ftsResults as RecordBatch,
+          );
+
+          const buffer = fromRecordBatchToBuffer(result);
+          return buffer;
+        },
+      }),
+    );
+
+    return this;
+  }
 }

 /** A builder for LanceDB queries. */
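Since `rerank()` accepts any object with a `rerankHybrid` method, a custom JavaScript reranker only needs to return a batch carrying a `_relevance_score` column. Below is an illustrative sketch (not part of this PR) that blends the two result sets with a fixed weight; it assumes the `_rowid`, `_distance`, and `_score` columns that the Rust side includes, and that scores arrive already normalized to [0, 1]:

```ts
import { RecordBatch } from "apache-arrow";
import { makeArrowTable } from "./arrow";
import { Reranker } from "./rerankers";

// Hypothetical reranker: weighted blend of normalized vector distance
// (lower is better, hence 1 - distance) and normalized FTS score.
export class WeightedReranker implements Reranker {
  constructor(private readonly weight = 0.5) {}

  async rerankHybrid(
    _query: string,
    vecResults: RecordBatch,
    ftsResults: RecordBatch,
  ): Promise<RecordBatch> {
    const byRowId = new Map<bigint, { row: Record<string, unknown>; score: number }>();
    for (const structRow of vecResults.toArray()) {
      const { _distance, ...row } = structRow.toJSON();
      byRowId.set(row._rowid, { row, score: this.weight * (1 - _distance) });
    }
    for (const structRow of ftsResults.toArray()) {
      const { _score, ...row } = structRow.toJSON();
      const entry = byRowId.get(row._rowid) ?? { row, score: 0 };
      entry.score += (1 - this.weight) * _score;
      byRowId.set(row._rowid, entry);
    }

    // Sort descending and attach the required `_relevance_score` column.
    const rows = [...byRowId.values()]
      .sort((a, b) => b.score - a.score)
      .map(({ row, score }) => ({ ...row, _relevance_score: score }));
    return makeArrowTable(rows).batches[0];
  }
}
```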
diff --git a/nodejs/lancedb/rerankers/index.ts b/nodejs/lancedb/rerankers/index.ts
new file mode 100644
index 00000000..653499e3
--- /dev/null
+++ b/nodejs/lancedb/rerankers/index.ts
@@ -0,0 +1,17 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

import { RecordBatch } from "apache-arrow";

export * from "./rrf";

// Interface for a reranker. A reranker is used to rerank the results from a
// vector and FTS search. This is useful for combining the results from both
// search methods.
export interface Reranker {
  rerankHybrid(
    query: string,
    vecResults: RecordBatch,
    ftsResults: RecordBatch,
  ): Promise<RecordBatch>;
}
diff --git a/nodejs/lancedb/rerankers/rrf.ts b/nodejs/lancedb/rerankers/rrf.ts
new file mode 100644
index 00000000..1d89c076
--- /dev/null
+++ b/nodejs/lancedb/rerankers/rrf.ts
@@ -0,0 +1,40 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

import { RecordBatch } from "apache-arrow";
import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
import { RrfReranker as NativeRRFReranker } from "../native";

/**
 * Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
 *
 * Internally this uses the Rust implementation.
 */
export class RRFReranker {
  private inner: NativeRRFReranker;

  constructor(inner: NativeRRFReranker) {
    this.inner = inner;
  }

  public static async create(k: number = 60) {
    return new RRFReranker(
      await NativeRRFReranker.tryNew(new Float32Array([k])),
    );
  }

  async rerankHybrid(
    query: string,
    vecResults: RecordBatch,
    ftsResults: RecordBatch,
  ): Promise<RecordBatch> {
    const buffer = await this.inner.rerankHybrid(
      query,
      await fromRecordBatchToBuffer(vecResults),
      await fromRecordBatchToBuffer(ftsResults),
    );
    const recordBatch = await fromBufferToRecordBatch(buffer);

    return recordBatch as RecordBatch;
  }
}
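A note on the `k` passed to `create()`: each result list contributes `1 / (position + k)` to a row's fused score, so smaller values of `k` weight the top of each list more aggressively. A quick sketch (the value 10 below is arbitrary):

```ts
import { RRFReranker } from "./rerankers";

async function makeRerankers() {
  const standard = await RRFReranker.create(); // default k = 60, per Cormack et al.
  const topHeavy = await RRFReranker.create(10); // emphasizes the best-ranked hits
  return { standard, topHeavy };
}
```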
diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs
index d0a02ee4..0ac75b7c 100644
--- a/nodejs/src/lib.rs
+++ b/nodejs/src/lib.rs
@@ -24,6 +24,7 @@ mod iterator;
 pub mod merge;
 mod query;
 pub mod remote;
+mod rerankers;
 mod table;
 mod util;
diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs
index fd8e3b48..321e4052 100644
--- a/nodejs/src/query.rs
+++ b/nodejs/src/query.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use std::sync::Arc;
+
 use lancedb::index::scalar::FullTextSearchQuery;
 use lancedb::query::ExecutableQuery;
 use lancedb::query::Query as LanceDbQuery;
@@ -25,6 +27,8 @@ use napi_derive::napi;
 use crate::error::convert_error;
 use crate::error::NapiErrorExt;
 use crate::iterator::RecordBatchIterator;
+use crate::rerankers::Reranker;
+use crate::rerankers::RerankerCallbacks;
 use crate::util::parse_distance_type;

 #[napi]
@@ -218,6 +222,14 @@ impl VectorQuery {
         self.inner = self.inner.clone().with_row_id();
     }

+    #[napi]
+    pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
+        self.inner = self
+            .inner
+            .clone()
+            .rerank(Arc::new(Reranker::new(callbacks)));
+    }
+
     #[napi(catch_unwind)]
     pub async fn execute(
         &self,
diff --git a/nodejs/src/rerankers.rs b/nodejs/src/rerankers.rs
new file mode 100644
index 00000000..bf8a4d96
--- /dev/null
+++ b/nodejs/src/rerankers.rs
@@ -0,0 +1,147 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use arrow_array::RecordBatch;
use async_trait::async_trait;
use napi::{
    bindgen_prelude::*,
    threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
};
use napi_derive::napi;

use lancedb::ipc::batches_to_ipc_file;
use lancedb::rerankers::Reranker as LanceDBReranker;
use lancedb::{error::Error, ipc::ipc_file_to_batches};

use crate::error::NapiErrorExt;

/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
/// This contains references to the callbacks that can be used to invoke the
/// reranking methods on the NodeJS implementation, and handles serializing the
/// record batches to Arrow IPC buffers.
#[napi]
pub struct Reranker {
    /// Callback into JavaScript which will call the rerankHybrid method of
    /// some Reranker implementation.
    rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
}

#[napi]
impl Reranker {
    #[napi]
    pub fn new(callbacks: RerankerCallbacks) -> Self {
        let rerank_hybrid = callbacks
            .rerank_hybrid
            .create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
            .unwrap();

        Self { rerank_hybrid }
    }
}

#[async_trait]
impl lancedb::rerankers::Reranker for Reranker {
    async fn rerank_hybrid(
        &self,
        query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> lancedb::error::Result<RecordBatch> {
        let callback_args = RerankHybridCallbackArgs {
            query: query.to_string(),
            vec_results: batches_to_ipc_file(&[vector_results])?,
            fts_results: batches_to_ipc_file(&[fts_results])?,
        };
        let promised_buffer: Promise<Buffer> = self
            .rerank_hybrid
            .call_async(Ok(callback_args))
            .await
            .map_err(|e| Error::Runtime {
                message: format!("napi error status={}, reason={}", e.status, e.reason),
            })?;
        let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
            message: format!("napi error status={}, reason={}", e.status, e.reason),
        })?;
        let mut reader = ipc_file_to_batches(buffer.to_vec())?;
        let result = reader.next().ok_or(Error::Runtime {
            message: "reranker result deserialization failed".to_string(),
        })??;

        Ok(result)
    }
}

impl std::fmt::Debug for Reranker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NodeJSRerankerWrapper")
    }
}

#[napi(object)]
pub struct RerankerCallbacks {
    pub rerank_hybrid: JsFunction,
}

#[napi(object)]
pub struct RerankHybridCallbackArgs {
    pub query: String,
    pub vec_results: Vec<u8>,
    pub fts_results: Vec<u8>,
}

fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
    let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
    reader
        .next()
        .ok_or(Error::InvalidInput {
            message: "expected buffer containing record batch".to_string(),
        })
        .default_error()?
        .map_err(Error::from)
        .default_error()
}

/// Wrapper around the Rust RRFReranker.
#[napi]
pub struct RRFReranker {
    inner: lancedb::rerankers::rrf::RRFReranker,
}

#[napi]
impl RRFReranker {
    #[napi]
    pub async fn try_new(k: &[f32]) -> Result<Self> {
        let k = k
            .first()
            .copied()
            .ok_or(Error::InvalidInput {
                message: "must supply RRF Reranker constructor arg 'k'".to_string(),
            })
            .default_error()?;

        Ok(Self {
            inner: lancedb::rerankers::rrf::RRFReranker::new(k),
        })
    }

    #[napi]
    pub async fn rerank_hybrid(
        &self,
        query: String,
        vec_results: Buffer,
        fts_results: Buffer,
    ) -> Result<Buffer> {
        let vec_results = buffer_to_record_batch(vec_results)?;
        let fts_results = buffer_to_record_batch(fts_results)?;

        let result = self
            .inner
            .rerank_hybrid(&query, vec_results, fts_results)
            .await
            .default_error()?;

        let result_buff = batches_to_ipc_file(&[result]).default_error()?;

        Ok(Buffer::from(result_buff.as_ref()))
    }
}
diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs
index 8edca5a8..9ff37ccb 100644
--- a/rust/lancedb/src/lib.rs
+++ b/rust/lancedb/src/lib.rs
@@ -214,6 +214,7 @@ mod polars_arrow_convertors;
 pub mod query;
 #[cfg(feature = "remote")]
 pub mod remote;
+pub mod rerankers;
 pub mod table;
 pub mod utils;
diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs
index e9add758..86d39392 100644
--- a/rust/lancedb/src/query.rs
+++ b/rust/lancedb/src/query.rs
@@ -15,19 +15,31 @@
 use std::future::Future;
 use std::sync::Arc;

+use arrow::compute::concat_batches;
 use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
 use arrow_schema::DataType;
 use datafusion_physical_plan::ExecutionPlan;
+use futures::{stream, try_join, FutureExt, TryStreamExt};
 use half::f16;
-use lance::dataset::scanner::DatasetRecordBatchStream;
+use lance::{
+    arrow::RecordBatchExt,
+    dataset::{scanner::DatasetRecordBatchStream, ROW_ID},
+};
 use lance_datafusion::exec::execute_plan;
+use lance_index::scalar::inverted::SCORE_COL;
 use lance_index::scalar::FullTextSearchQuery;
+use lance_index::vector::DIST_COL;
+use lance_io::stream::RecordBatchStreamAdapter;

 use crate::arrow::SendableRecordBatchStream;
 use crate::error::{Error, Result};
+use crate::rerankers::rrf::RRFReranker;
+use crate::rerankers::{check_reranker_result, NormalizeMethod, Reranker};
 use crate::table::TableInternal;
 use crate::DistanceType;

+mod hybrid;
+
 pub(crate) const DEFAULT_TOP_K: usize = 10;

 /// Which columns should be retrieved from the database
@@ -435,6 +447,16 @@ pub trait QueryBase {
     /// Return the `_rowid` meta column from the Table.
     fn with_row_id(self) -> Self;
+
+    /// Rerank the results using the specified reranker.
+    ///
+    /// This is currently only supported for hybrid search.
+    fn rerank(self, reranker: Arc<dyn Reranker>) -> Self;
+
+    /// The method used to normalize the scores. Can be `Rank` or `Score`.
+    /// If `Rank`, the scores are converted to ranks and then normalized.
+    /// If `Score`, the scores are normalized directly.
+    fn norm(self, norm: NormalizeMethod) -> Self;
 }

 pub trait HasQuery {
@@ -481,6 +503,16 @@ impl<T: HasQuery> QueryBase for T {
         self.mut_query().with_row_id = true;
         self
     }
+
+    fn rerank(mut self, reranker: Arc<dyn Reranker>) -> Self {
+        self.mut_query().reranker = Some(reranker);
+        self
+    }
+
+    fn norm(mut self, norm: NormalizeMethod) -> Self {
+        self.mut_query().norm = Some(norm);
+        self
+    }
 }
 /// Options for controlling the execution of a query
@@ -600,6 +632,13 @@ pub struct Query {
     /// If set to false, the filter will be applied after the vector search.
     pub(crate) prefilter: bool,
+
+    /// Implementation of a reranker that can be used to reorder or combine query
+    /// results, especially when doing hybrid search.
+    pub(crate) reranker: Option<Arc<dyn Reranker>>,
+
+    /// Configures how query results are normalized when doing hybrid search.
+    pub(crate) norm: Option<NormalizeMethod>,
 }

 impl Query {
@@ -614,6 +653,8 @@ impl Query {
             fast_search: false,
             with_row_id: false,
             prefilter: true,
+            reranker: None,
+            norm: None,
         }
     }
@@ -862,6 +903,65 @@ impl VectorQuery {
         self.use_index = false;
         self
     }
+
+    pub async fn execute_hybrid(&self) -> Result<SendableRecordBatchStream> {
+        // Clone the query and specify that we want to include row IDs, which
+        // can be needed for reranking.
+        let fts_query = self.base.clone().with_row_id();
+        let mut vector_query = self.clone().with_row_id();
+
+        vector_query.base.full_text_search = None;
+        let (fts_results, vec_results) = try_join!(fts_query.execute(), vector_query.execute())?;
+
+        let (fts_results, vec_results) = try_join!(
+            fts_results.try_collect::<Vec<_>>(),
+            vec_results.try_collect::<Vec<_>>()
+        )?;
+
+        // Get the schemas to use when combining batches. If either result set
+        // is empty, its schema is derived from the other (or from a default
+        // schema when both are empty).
+        let (fts_schema, vec_schema) = hybrid::query_schemas(&fts_results, &vec_results);
+
+        // concatenate all the batches together
+        let mut fts_results = concat_batches(&fts_schema, fts_results.iter())?;
+        let mut vec_results = concat_batches(&vec_schema, vec_results.iter())?;
+
+        if matches!(self.base.norm, Some(NormalizeMethod::Rank)) {
+            vec_results = hybrid::rank(vec_results, DIST_COL, None)?;
+            fts_results = hybrid::rank(fts_results, SCORE_COL, None)?;
+        }
+
+        vec_results = hybrid::normalize_scores(vec_results, DIST_COL, None)?;
+        fts_results = hybrid::normalize_scores(fts_results, SCORE_COL, None)?;
+
+        let reranker = self
+            .base
+            .reranker
+            .clone()
+            .unwrap_or(Arc::new(RRFReranker::default()));
+
+        let fts_query = self.base.full_text_search.as_ref().ok_or(Error::Runtime {
+            message: "there should be an FTS search".to_string(),
+        })?;
+
+        let mut results = reranker
+            .rerank_hybrid(&fts_query.query, vec_results, fts_results)
+            .await?;
+
+        check_reranker_result(&results)?;
+
+        let limit = self.base.limit.unwrap_or(DEFAULT_TOP_K);
+        if results.num_rows() > limit {
+            results = results.slice(0, limit);
+        }
+
+        if !self.base.with_row_id {
+            results = results.drop_column(ROW_ID)?;
+        }
+
+        Ok(SendableRecordBatchStream::from(
+            RecordBatchStreamAdapter::new(results.schema(), stream::iter([Ok(results)])),
+        ))
+    }
 }

 impl ExecutableQuery for VectorQuery {
@@ -873,6 +973,11 @@
     async fn execute(
         &self,
         options: QueryExecutionOptions,
     ) -> Result<SendableRecordBatchStream> {
+        if self.base.full_text_search.is_some() {
+            let hybrid_result = async move { self.execute_hybrid().await }.boxed().await?;
+            return Ok(hybrid_result);
+        }
+
         Ok(SendableRecordBatchStream::from(
             DatasetRecordBatchStream::new(execute_plan(
                 self.create_plan(options).await?,
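The shape of `execute_hybrid` above, restated as an illustrative TypeScript outline (the function and type names here are hypothetical; the Rust is authoritative):

```ts
type Row = { _rowid: bigint; [column: string]: unknown };

async function executeHybrid(
  runVectorLeg: () => Promise<Row[]>, // vector query with row IDs enabled
  runFtsLeg: () => Promise<Row[]>, // FTS query with row IDs enabled
  rerank: (vec: Row[], fts: Row[]) => Promise<Row[]>,
  limit: number,
): Promise<Row[]> {
  // 1. Run both legs concurrently (try_join! in the Rust above).
  const [vec, fts] = await Promise.all([runVectorLeg(), runFtsLeg()]);
  // 2. (On the Rust side) optionally convert scores to ranks, then min-max
  //    normalize `_distance` and `_score` into [0, 1].
  // 3. Hand both sets to the reranker, which must add `_relevance_score`.
  const reranked = await rerank(vec, fts);
  // 4. Truncate to the requested limit (DEFAULT_TOP_K when unset).
  return reranked.slice(0, limit);
}
```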
 impl HasQuery for VectorQuery {

 #[cfg(test)]
 mod tests {
-    use std::sync::Arc;
+    use std::{collections::HashSet, sync::Arc};

     use super::*;
-    use arrow::{compute::concat_batches, datatypes::Int32Type};
+    use arrow::{array::downcast_array, compute::concat_batches, datatypes::Int32Type};
     use arrow_array::{
-        cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
-        RecordBatchReader,
+        cast::AsArray, types::Float32Type, FixedSizeListArray, Float32Array, Int32Array,
+        RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
     };
     use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
     use futures::{StreamExt, TryStreamExt};
     use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
     use tempfile::tempdir;

-    use crate::{connect, Table};
+    use crate::{connect, connection::CreateTableMode, Table};

     #[tokio::test]
     async fn test_setters_getters() {
@@ -1274,4 +1379,156 @@
     assert!(query_index.values().contains(&0));
     assert!(query_index.values().contains(&1));
 }
+
+    #[tokio::test]
+    async fn test_hybrid_search() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        let text = StringArray::from(vec!["dog", "cat", "a", "b"]);
+        let vectors = vec![
+            Some(vec![Some(0.0), Some(0.0)]),
+            Some(vec![Some(-2.0), Some(-2.0)]),
+            Some(vec![Some(50.0), Some(50.0)]),
+            Some(vec![Some(-30.0), Some(-30.0)]),
+        ];
+        let vector = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims);
+
+        let record_batch =
+            RecordBatch::try_new(schema.clone(), vec![Arc::new(text), Arc::new(vector)]).unwrap();
+        let record_batch_iter =
+            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
+        let table = conn
+            .create_table("my_table", record_batch_iter)
+            .execute()
+            .await
+            .unwrap();
+
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+
+        let fts_query = FullTextSearchQuery::new("b".to_string());
+        let results = table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let batch = &results[0];
+
+        let texts: StringArray = downcast_array(batch.column_by_name("text").unwrap());
+        let texts = texts.iter().map(|e| e.unwrap()).collect::<HashSet<_>>();
+        assert!(texts.contains("cat")); // should be close by vector search
+        assert!(texts.contains("b")); // should be close by fts search
+
+        // ensure that this works correctly if there are no matching FTS results
+        let fts_query = FullTextSearchQuery::new("z".to_string());
+        table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+    }
+    #[tokio::test]
+    async fn test_hybrid_search_empty_table() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path();
+        let conn = connect(dataset_path.to_str().unwrap())
+            .execute()
+            .await
+            .unwrap();
+
+        let dims = 2;
+
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("text", DataType::Utf8, false),
+            ArrowField::new(
+                "vector",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::Float32, true)),
+                    dims,
+                ),
+                false,
+            ),
+        ]));
+
+        // ensure hybrid search is also supported on a fully empty table
+        let vectors: Vec<Option<Vec<Option<f32>>>> = Vec::new();
+        let record_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(StringArray::from(Vec::<&str>::new())),
+                Arc::new(
+                    FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(vectors, dims),
+                ),
+            ],
+        )
+        .unwrap();
+        let record_batch_iter =
+            RecordBatchIterator::new(vec![record_batch].into_iter().map(Ok), schema.clone());
+        let table = conn
+            .create_table("my_table", record_batch_iter)
+            .mode(CreateTableMode::Overwrite)
+            .execute()
+            .await
+            .unwrap();
+        table
+            .create_index(&["text"], crate::index::Index::FTS(Default::default()))
+            .replace(true)
+            .execute()
+            .await
+            .unwrap();
+        let fts_query = FullTextSearchQuery::new("b".to_string());
+        let results = table
+            .query()
+            .full_text_search(fts_query)
+            .limit(2)
+            .nearest_to(&[-10.0, -10.0])
+            .unwrap()
+            .execute()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        let batch = &results[0];
+        assert_eq!(0, batch.num_rows());
+        assert_eq!(2, batch.num_columns());
+    }
 }
diff --git a/rust/lancedb/src/query/hybrid.rs b/rust/lancedb/src/query/hybrid.rs
new file mode 100644
index 00000000..9e85e870
--- /dev/null
+++ b/rust/lancedb/src/query/hybrid.rs
@@ -0,0 +1,346 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use arrow::compute::{
    kernels::numeric::{div, sub},
    max, min,
};
use arrow_array::{cast::downcast_array, Float32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use lance::dataset::ROW_ID;
use lance_index::{scalar::inverted::SCORE_COL, vector::DIST_COL};
use std::sync::Arc;

use crate::error::{Error, Result};

/// Converts a result's score column to a rank.
///
/// Expects the `column` argument to be of type Float32 and will panic if it is not.
pub fn rank(results: RecordBatch, column: &str, ascending: Option<bool>) -> Result<RecordBatch> {
    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
        message: format!(
            "expected column {} not found in rank. found columns {:?}",
            column,
            results
                .schema()
                .fields()
                .iter()
                .map(|f| f.name())
                .collect::<Vec<_>>(),
        ),
    })?;

    if results.num_rows() == 0 {
        return Ok(results);
    }

    let scores: Float32Array = downcast_array(scores);
    let ranks = Float32Array::from_iter_values(
        arrow::compute::kernels::rank::rank(
            &scores,
            Some(SortOptions {
                descending: !ascending.unwrap_or(true),
                ..Default::default()
            }),
        )?
        .iter()
        .map(|i| *i as f32),
    );

    let schema = results.schema();
    let (column_idx, _) = schema.column_with_name(column).unwrap();
    let mut columns = results.columns().to_vec();
    columns[column_idx] = Arc::new(ranks);

    let results = RecordBatch::try_new(results.schema(), columns)?;

    Ok(results)
}
/// Get the query schemas needed when combining the search results.
///
/// If either of the record batches is empty, we create a schema from the
/// other record batch and replace the score/distance column. If both record
/// batches are empty, create empty schemas.
pub fn query_schemas(
    fts_results: &[RecordBatch],
    vec_results: &[RecordBatch],
) -> (Arc<Schema>, Arc<Schema>) {
    let (fts_schema, vec_schema) = match (
        fts_results.first().map(|r| r.schema()),
        vec_results.first().map(|r| r.schema()),
    ) {
        (Some(fts_schema), Some(vec_schema)) => (fts_schema, vec_schema),
        (None, Some(vec_schema)) => {
            let fts_schema = with_field_name_replaced(&vec_schema, DIST_COL, SCORE_COL);
            (Arc::new(fts_schema), vec_schema)
        }
        (Some(fts_schema), None) => {
            let vec_schema = with_field_name_replaced(&fts_schema, SCORE_COL, DIST_COL);
            (fts_schema, Arc::new(vec_schema))
        }
        (None, None) => (Arc::new(empty_fts_schema()), Arc::new(empty_vec_schema())),
    };

    (fts_schema, vec_schema)
}

pub fn empty_fts_schema() -> Schema {
    Schema::new(vec![
        Arc::new(Field::new(SCORE_COL, DataType::Float32, false)),
        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
    ])
}

pub fn empty_vec_schema() -> Schema {
    Schema::new(vec![
        Arc::new(Field::new(DIST_COL, DataType::Float32, false)),
        Arc::new(Field::new(ROW_ID, DataType::UInt64, false)),
    ])
}

pub fn with_field_name_replaced(schema: &Schema, target: &str, replacement: &str) -> Schema {
    let field_idx = schema.fields().iter().enumerate().find_map(|(i, field)| {
        if field.name() == target {
            Some(i)
        } else {
            None
        }
    });

    let mut fields = schema.fields().to_vec();
    if let Some(idx) = field_idx {
        let new_field = (*fields[idx]).clone().with_name(replacement);
        fields[idx] = Arc::new(new_field);
    }

    Schema::new(fields)
}

/// Normalize the scores column to have values between 0 and 1.
///
/// Expects the `column` argument to be of type Float32 and will panic if it is not.
pub fn normalize_scores(
    results: RecordBatch,
    column: &str,
    invert: Option<bool>,
) -> Result<RecordBatch> {
    let scores = results.column_by_name(column).ok_or(Error::InvalidInput {
        message: format!(
            "expected column {} not found in normalize_scores. found columns {:?}",
            column,
            results
                .schema()
                .fields()
                .iter()
                .map(|f| f.name())
                .collect::<Vec<_>>(),
        ),
    })?;

    if results.num_rows() == 0 {
        return Ok(results);
    }
    let mut scores: Float32Array = downcast_array(scores);

    let max = max(&scores).unwrap_or(0.0);
    let min = min(&scores).unwrap_or(0.0);

    // this is equivalent to np.isclose, which is used in python
    let rng = if max - min < 10e-5 { max } else { max - min };

    // if rng is 0, then min and max are both 0, so we just leave the scores as is
    if rng != 0.0 {
        let tmp = div(
            &sub(&scores, &Float32Array::new_scalar(min))?,
            &Float32Array::new_scalar(rng),
        )?;
        scores = downcast_array(&tmp);
    }

    if invert.unwrap_or(false) {
        let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?;
        scores = downcast_array(&tmp);
    }

    let schema = results.schema();
    let (column_idx, _) = schema.column_with_name(column).unwrap();
    let mut columns = results.columns().to_vec();
    columns[column_idx] = Arc::new(scores);

    let results = RecordBatch::try_new(results.schema(), columns)?;

    Ok(results)
}
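For readers following along from the Node side, the min-max normalization above in a few lines of illustrative TypeScript (the Rust is what actually runs):

```ts
// Min-max normalize scores into [0, 1]; mirrors normalize_scores above,
// including the near-constant guard that plays the role of np.isclose.
function normalizeScores(scores: number[], invert = false): number[] {
  if (scores.length === 0) return scores;
  const max = Math.max(...scores);
  const min = Math.min(...scores);
  const rng = max - min < 1e-4 ? max : max - min;
  return scores.map((s) => {
    // If rng is 0 then min and max are both 0, so leave the score as is.
    const norm = rng !== 0 ? (s - min) / rng : s;
    return invert ? 1 - norm : norm;
  });
}

normalizeScores([-4, 2, 0, 3, 6]); // [0, 0.6, 0.4, 0.7, 1.0]
```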
found columns {:?}", + column, + results + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>(), + ), + })?; + + if results.num_rows() == 0 { + return Ok(results); + } + let mut scores: Float32Array = downcast_array(scores); + + let max = max(&scores).unwrap_or(0.0); + let min = min(&scores).unwrap_or(0.0); + + // this is equivalent to np.isclose which is used in python + let rng = if max - min < 10e-5 { max } else { max - min }; + + // if rng is 0, then min and max are both 0 so we just leave the scores as is + if rng != 0.0 { + let tmp = div( + &sub(&scores, &Float32Array::new_scalar(min))?, + &Float32Array::new_scalar(rng), + )?; + scores = downcast_array(&tmp); + } + + if invert.unwrap_or(false) { + let tmp = sub(&Float32Array::new_scalar(1.0), &scores)?; + scores = downcast_array(&tmp); + } + + let schema = results.schema(); + let (column_idx, _) = schema.column_with_name(column).unwrap(); + let mut columns = results.columns().to_vec(); + columns[column_idx] = Arc::new(scores); + + let results = RecordBatch::try_new(results.schema(), columns).unwrap(); + + Ok(results) +} + +#[cfg(test)] +mod test { + use super::*; + use arrow_array::StringArray; + use arrow_schema::{DataType, Field, Schema}; + + #[test] + fn test_rank() { + let schema = Arc::new(Schema::new(vec![ + Arc::new(Field::new("name", DataType::Utf8, false)), + Arc::new(Field::new("score", DataType::Float32, false)), + ])); + + let names = StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]); + let scores = Float32Array::from(vec![0.2, 0.4, 0.1, 0.6, 0.45]); + + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(names), Arc::new(scores)]).unwrap(); + + let result = rank(batch.clone(), "score", Some(false)).unwrap(); + assert_eq!(2, result.schema().fields().len()); + assert_eq!("name", result.schema().field(0).name()); + assert_eq!("score", result.schema().field(1).name()); + + let names: StringArray = downcast_array(result.column(0)); + assert_eq!( + names.iter().map(|e| e.unwrap()).collect::>(), + vec!["foo", "bar", "baz", "bean", "dog"] + ); + let scores: Float32Array = downcast_array(result.column(1)); + assert_eq!( + scores.iter().map(|e| e.unwrap()).collect::>(), + vec![4.0, 3.0, 5.0, 1.0, 2.0] + ); + + // check sort ascending + let result = rank(batch.clone(), "score", Some(true)).unwrap(); + let names: StringArray = downcast_array(result.column(0)); + assert_eq!( + names.iter().map(|e| e.unwrap()).collect::>(), + vec!["foo", "bar", "baz", "bean", "dog"] + ); + let scores: Float32Array = downcast_array(result.column(1)); + assert_eq!( + scores.iter().map(|e| e.unwrap()).collect::>(), + vec![2.0, 3.0, 1.0, 5.0, 4.0] + ); + + // ensure default sort is ascending + let result = rank(batch.clone(), "score", None).unwrap(); + let names: StringArray = downcast_array(result.column(0)); + assert_eq!( + names.iter().map(|e| e.unwrap()).collect::>(), + vec!["foo", "bar", "baz", "bean", "dog"] + ); + let scores: Float32Array = downcast_array(result.column(1)); + assert_eq!( + scores.iter().map(|e| e.unwrap()).collect::>(), + vec![2.0, 3.0, 1.0, 5.0, 4.0] + ); + + // check it can handle an empty batch + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(Vec::<&str>::new())), + Arc::new(Float32Array::from(Vec::::new())), + ], + ) + .unwrap(); + let result = rank(batch.clone(), "score", None).unwrap(); + assert_eq!(0, result.num_rows()); + assert_eq!(2, result.schema().fields().len()); + assert_eq!("name", result.schema().field(0).name()); + assert_eq!("score", 
    #[test]
    fn test_normalize_scores() {
        let schema = Arc::new(Schema::new(vec![
            Arc::new(Field::new("name", DataType::Utf8, false)),
            Arc::new(Field::new("score", DataType::Float32, false)),
        ]));

        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
        let scores = Arc::new(Float32Array::from(vec![-4.0, 2.0, 0.0, 3.0, 6.0]));

        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();

        let result = normalize_scores(batch.clone(), "score", Some(false)).unwrap();
        let names: StringArray = downcast_array(result.column(0));
        assert_eq!(
            names.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec!["foo", "bar", "baz", "bean", "dog"]
        );
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.6, 0.4, 0.7, 1.0]
        );

        // check it can invert the normalization
        let result = normalize_scores(batch.clone(), "score", Some(true)).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![1.0, 1.0 - 0.6, 0.6, 0.3, 0.0]
        );

        // check that the default is not inverted
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.6, 0.4, 0.7, 1.0]
        );

        // check that it will function correctly if all the values are the same
        let names = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "bean", "dog"]));
        let scores = Arc::new(Float32Array::from(vec![2.1, 2.1, 2.1, 2.1, 2.1]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.0, 0.0, 0.0, 0.0]
        );

        // check it keeps floating point rounding errors for the same score normalized the same,
        // e.g. the behaviour is consistent with python
        let scores = Arc::new(Float32Array::from(vec![1.0, 1.0, 1.0, 1.0, 0.9999999]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                1.0 - 0.9999999,
                0.0
            ]
        );

        // check that it can handle if all the scores are 0
        let scores = Arc::new(Float32Array::from(vec![0.0, 0.0, 0.0, 0.0, 0.0]));
        let batch =
            RecordBatch::try_new(schema.clone(), vec![names.clone(), scores.clone()]).unwrap();
        let result = normalize_scores(batch.clone(), "score", None).unwrap();
        let scores: Float32Array = downcast_array(result.column(1));
        assert_eq!(
            scores.iter().map(|e| e.unwrap()).collect::<Vec<_>>(),
            vec![0.0, 0.0, 0.0, 0.0, 0.0]
        );
    }
}
diff --git a/rust/lancedb/src/rerankers.rs b/rust/lancedb/src/rerankers.rs
new file mode 100644
index 00000000..338230c0
--- /dev/null
+++ b/rust/lancedb/src/rerankers.rs
@@ -0,0 +1,87 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::collections::BTreeSet;

use arrow::{
    array::downcast_array,
    compute::{concat_batches, filter_record_batch},
};
use arrow_array::{BooleanArray, RecordBatch, UInt64Array};
use async_trait::async_trait;
use lance::dataset::ROW_ID;

use crate::error::{Error, Result};

pub mod rrf;

/// Column name for the reranker relevance score.
const RELEVANCE_SCORE: &str = "_relevance_score";

#[derive(Debug, Clone, PartialEq)]
pub enum NormalizeMethod {
    Score,
    Rank,
}

/// Interface for a reranker. A reranker is used to rerank the results from a
/// vector and FTS search. This is useful for combining the results from both
/// search methods.
#[async_trait]
pub trait Reranker: std::fmt::Debug + Sync + Send {
    // TODO support vector reranking and FTS reranking. Currently only hybrid reranking is supported.

    /// The rerank function receives the individual results from the vector and FTS
    /// searches. You can choose to use any of the results to generate the final
    /// results, allowing maximum flexibility.
    async fn rerank_hybrid(
        &self,
        query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch>;

    fn merge_results(
        &self,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        let combined = concat_batches(&fts_results.schema(), [vector_results, fts_results].iter())?;

        let mut mask = BooleanArray::builder(combined.num_rows());
        let mut unique_ids = BTreeSet::new();
        let row_ids = combined.column_by_name(ROW_ID).ok_or(Error::InvalidInput {
            message: format!(
                "could not find expected column {} while merging results. found columns {:?}",
                ROW_ID,
                combined
                    .schema()
                    .fields()
                    .iter()
                    .map(|f| f.name())
                    .collect::<Vec<_>>()
            ),
        })?;
        let row_ids: UInt64Array = downcast_array(row_ids);
        row_ids.values().iter().for_each(|id| {
            mask.append_value(unique_ids.insert(id));
        });

        let combined = filter_record_batch(&combined, &mask.finish())?;

        Ok(combined)
    }
}

pub fn check_reranker_result(result: &RecordBatch) -> Result<()> {
    if result.schema().column_with_name(RELEVANCE_SCORE).is_none() {
        return Err(Error::Schema {
            message: format!(
                "rerank_hybrid must return a RecordBatch with a column named {}",
                RELEVANCE_SCORE
            ),
        });
    }

    Ok(())
}
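The default `merge_results` above concatenates the vector results ahead of the FTS results and keeps only the first occurrence of each row id, so a row found by both searches appears once. An illustrative TypeScript equivalent:

```ts
// Keep the first occurrence of each _rowid when concatenating the vector
// and FTS results; mirrors merge_results above.
function mergeResults<T extends { _rowid: bigint }>(vec: T[], fts: T[]): T[] {
  const seen = new Set<bigint>();
  return [...vec, ...fts].filter((row) => {
    if (seen.has(row._rowid)) return false;
    seen.add(row._rowid);
    return true;
  });
}
```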
diff --git a/rust/lancedb/src/rerankers/rrf.rs b/rust/lancedb/src/rerankers/rrf.rs
new file mode 100644
index 00000000..7b5d8361
--- /dev/null
+++ b/rust/lancedb/src/rerankers/rrf.rs
@@ -0,0 +1,223 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors

use std::collections::BTreeMap;
use std::sync::Arc;

use arrow::{
    array::downcast_array,
    compute::{sort_to_indices, take},
};
use arrow_array::{Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SortOptions};
use async_trait::async_trait;
use lance::dataset::ROW_ID;

use crate::error::{Error, Result};
use crate::rerankers::{Reranker, RELEVANCE_SCORE};

/// Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm, based
/// on the ranks of the vector and FTS search results.
#[derive(Debug)]
pub struct RRFReranker {
    k: f32,
}

impl RRFReranker {
    /// Create a new RRFReranker.
    ///
    /// The parameter `k` is a constant used in the RRF formula (default is 60).
    /// Experiments indicate that k = 60 was near-optimal, but that the choice
    /// is not critical. See the paper:
    /// https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
    pub fn new(k: f32) -> Self {
        Self { k }
    }
}

impl Default for RRFReranker {
    fn default() -> Self {
        Self { k: 60.0 }
    }
}

#[async_trait]
impl Reranker for RRFReranker {
    async fn rerank_hybrid(
        &self,
        _query: &str,
        vector_results: RecordBatch,
        fts_results: RecordBatch,
    ) -> Result<RecordBatch> {
        let vector_ids = vector_results
            .column_by_name(ROW_ID)
            .ok_or(Error::InvalidInput {
                message: format!(
                    "expected column {} not found in vector_results. found columns {:?}",
                    ROW_ID,
                    vector_results
                        .schema()
                        .fields()
                        .iter()
                        .map(|f| f.name())
                        .collect::<Vec<_>>()
                ),
            })?;
        let fts_ids = fts_results
            .column_by_name(ROW_ID)
            .ok_or(Error::InvalidInput {
                message: format!(
                    "expected column {} not found in fts_results. found columns {:?}",
                    ROW_ID,
                    fts_results
                        .schema()
                        .fields()
                        .iter()
                        .map(|f| f.name())
                        .collect::<Vec<_>>()
                ),
            })?;

        let vector_ids: UInt64Array = downcast_array(&vector_ids);
        let fts_ids: UInt64Array = downcast_array(&fts_ids);

        let mut rrf_score_map = BTreeMap::new();
        let mut update_score_map = |(i, result_id)| {
            let score = 1.0 / (i as f32 + self.k);
            rrf_score_map
                .entry(result_id)
                .and_modify(|e| *e += score)
                .or_insert(score);
        };
        vector_ids
            .values()
            .iter()
            .enumerate()
            .for_each(&mut update_score_map);
        fts_ids
            .values()
            .iter()
            .enumerate()
            .for_each(&mut update_score_map);

        let combined_results = self.merge_results(vector_results, fts_results)?;

        let combined_row_ids: UInt64Array =
            downcast_array(combined_results.column_by_name(ROW_ID).unwrap());
        let relevance_scores = Float32Array::from_iter_values(
            combined_row_ids
                .values()
                .iter()
                .map(|row_id| rrf_score_map.get(row_id).unwrap())
                .copied(),
        );

        // keep track of indices sorted by the relevance column
        let sort_indices = sort_to_indices(
            &relevance_scores,
            Some(SortOptions {
                descending: true,
                ..Default::default()
            }),
            None,
        )
        .unwrap();

        // add relevance scores to the columns
        let mut columns = combined_results.columns().to_vec();
        columns.push(Arc::new(relevance_scores));

        // sort by the relevance scores
        let columns = columns
            .iter()
            .map(|c| take(c, &sort_indices, None).unwrap())
            .collect();

        // add the relevance score to the schema
        let mut fields = combined_results.schema().fields().to_vec();
        fields.push(Arc::new(Field::new(
            RELEVANCE_SCORE,
            DataType::Float32,
            false,
        )));
        let schema = Schema::new(fields);

        let combined_results = RecordBatch::try_new(Arc::new(schema), columns)?;

        Ok(combined_results)
    }
}
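The scoring loop above, condensed into illustrative TypeScript: each list contributes `1 / (position + k)` (positions are 0-based here, as in the Rust), and rows present in both lists sum their contributions:

```ts
// Compute RRF scores for two ranked lists of row IDs; mirrors the
// rrf_score_map built in rerank_hybrid above. Default k = 60.
function rrfScores(vecIds: bigint[], ftsIds: bigint[], k = 60): Map<bigint, number> {
  const scores = new Map<bigint, number>();
  const add = (ids: bigint[]) =>
    ids.forEach((id, i) => scores.set(id, (scores.get(id) ?? 0) + 1 / (i + k)));
  add(vecIds);
  add(ftsIds);
  return scores;
}
```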
"bean", "dog"])), + Arc::new(UInt64Array::from(vec![4, 5, 3])), + ], + ) + .unwrap(); + + // scores should be calculated as: + // - foo = 1/1 = 1.0 + // - bar = 1/2 + 1/1 = 1.5 + // - baz = 1/3 = 0.333 + // - bean = 1/4 + 1/2 = 0.75 + // - dog = 1/5 + 1/3 = 0.533 + // then we should get the result ranked in descending order + + let reranker = RRFReranker::new(1.0); + + let result = reranker + .rerank_hybrid("", vec_results, fts_results) + .await + .unwrap(); + + assert_eq!(3, result.schema().fields().len()); + assert_eq!("name", result.schema().fields().first().unwrap().name()); + assert_eq!(ROW_ID, result.schema().fields().get(1).unwrap().name()); + assert_eq!( + RELEVANCE_SCORE, + result.schema().fields().get(2).unwrap().name() + ); + + let names: StringArray = downcast_array(result.column(0)); + assert_eq!( + names.iter().map(|e| e.unwrap()).collect::>(), + vec!["bar", "foo", "bean", "dog", "baz"] + ); + + let ids: UInt64Array = downcast_array(result.column(1)); + assert_eq!( + ids.iter().map(|e| e.unwrap()).collect::>(), + vec![4, 1, 5, 3, 2] + ); + + let scores: Float32Array = downcast_array(result.column(2)); + assert_eq!( + scores.iter().map(|e| e.unwrap()).collect::>(), + vec![1.5, 1.0, 0.75, 1.0 / 5.0 + 1.0 / 3.0, 1.0 / 3.0] + ); + } +}