mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-26 08:20:39 +00:00
feat: add hybrid search to node and rust SDKs (#1940)
Support hybrid search in both rust and node SDKs. - Adds a new rerankers package to rust LanceDB, with the implementation of the default RRF reranker - Adds a new hybrid package to lancedb, with some helper methods related to hybrid search such as normalizing scores and converting score column to rank columns - Adds capability to LanceDB VectorQuery to perform hybrid search if it has both a nearest vector and full text search parameters. - Adds wrappers for reranker implementations to nodejs SDK. Additional rerankers will be added in followup PRs https://github.com/lancedb/lancedb/issues/1921 --- Notes about how the rust rerankers are wrapped for calling from JS: I wanted to keep the core reranker logic, and the invocation of the reranker by the query code, in Rust. This aligns with the philosophy of the new node SDK where it's just a thin wrapper around Rust. However, I also wanted to have support for users who want to add custom rerankers written in Javascript. When we add a reranker to the query from Javascript, it adds a special Rust reranker that has a callback to the Javascript code (which could then turn around and call an underlying Rust reranker implementation if desired). This adds a bit of complexity, but overall I think it moves us in the right direction of having the majority of the query logic in the underlying Rust SDK while keeping the option open to support custom Javascript Rerankers.
This commit is contained in:
@@ -12,7 +12,10 @@ categories.workspace = true
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
async-trait.workspace = true
|
||||
arrow-ipc.workspace = true
|
||||
arrow-array.workspace = true
|
||||
arrow-schema.workspace = true
|
||||
env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
lancedb = { path = "../rust/lancedb", features = ["remote"] }
|
||||
|
||||
@@ -20,6 +20,8 @@ import * as arrow18 from "apache-arrow-18";
|
||||
|
||||
import {
|
||||
convertToTable,
|
||||
fromBufferToRecordBatch,
|
||||
fromRecordBatchToBuffer,
|
||||
fromTableToBuffer,
|
||||
makeArrowTable,
|
||||
makeEmptyTable,
|
||||
@@ -553,5 +555,28 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("converting record batches to buffers", function () {
|
||||
it("can convert to buffered record batch and back again", async function () {
|
||||
const records = [
|
||||
{ text: "dog", vector: [0.1, 0.2] },
|
||||
{ text: "cat", vector: [0.3, 0.4] },
|
||||
];
|
||||
const table = await convertToTable(records);
|
||||
const batch = table.batches[0];
|
||||
|
||||
const buffer = await fromRecordBatchToBuffer(batch);
|
||||
const result = await fromBufferToRecordBatch(buffer);
|
||||
|
||||
expect(JSON.stringify(batch.toArray())).toEqual(
|
||||
JSON.stringify(result?.toArray()),
|
||||
);
|
||||
});
|
||||
|
||||
it("converting from buffer returns null if buffer has no record batches", async function () {
|
||||
const result = await fromBufferToRecordBatch(Buffer.from([0x01, 0x02])); // bad data
|
||||
expect(result).toEqual(null);
|
||||
});
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
79
nodejs/__test__/rerankers.test.ts
Normal file
79
nodejs/__test__/rerankers.test.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import { RecordBatch } from "apache-arrow";
|
||||
import * as tmp from "tmp";
|
||||
import { Connection, Index, Table, connect, makeArrowTable } from "../lancedb";
|
||||
import { RRFReranker } from "../lancedb/rerankers";
|
||||
|
||||
describe("rerankers", function () {
|
||||
let tmpDir: tmp.DirResult;
|
||||
let conn: Connection;
|
||||
let table: Table;
|
||||
|
||||
beforeEach(async () => {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
conn = await connect(tmpDir.name);
|
||||
table = await conn.createTable("mytable", [
|
||||
{ vector: [0.1, 0.1], text: "dog" },
|
||||
{ vector: [0.2, 0.2], text: "cat" },
|
||||
]);
|
||||
await table.createIndex("text", {
|
||||
config: Index.fts(),
|
||||
replace: true,
|
||||
});
|
||||
});
|
||||
|
||||
it("will query with the custom reranker", async function () {
|
||||
const expectedResult = [
|
||||
{
|
||||
text: "albert",
|
||||
// biome-ignore lint/style/useNamingConvention: this is the lance field name
|
||||
_relevance_score: 0.99,
|
||||
},
|
||||
];
|
||||
class MyCustomReranker {
|
||||
async rerankHybrid(
|
||||
_query: string,
|
||||
_vecResults: RecordBatch,
|
||||
_ftsResults: RecordBatch,
|
||||
): Promise<RecordBatch> {
|
||||
// no reranker logic, just return some static data
|
||||
const table = makeArrowTable(expectedResult);
|
||||
return table.batches[0];
|
||||
}
|
||||
}
|
||||
|
||||
let result = await table
|
||||
.query()
|
||||
.nearestTo([0.1, 0.1])
|
||||
.fullTextSearch("dog")
|
||||
.rerank(new MyCustomReranker())
|
||||
.select(["text"])
|
||||
.limit(5)
|
||||
.toArray();
|
||||
|
||||
result = JSON.parse(JSON.stringify(result)); // convert StructRow to Object
|
||||
expect(result).toEqual([
|
||||
{
|
||||
text: "albert",
|
||||
// biome-ignore lint/style/useNamingConvention: this is the lance field name
|
||||
_relevance_score: 0.99,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it("will query with RRFReranker", async function () {
|
||||
// smoke test to see if the Rust wrapping Typescript is wired up correctly
|
||||
const result = await table
|
||||
.query()
|
||||
.nearestTo([0.1, 0.1])
|
||||
.fullTextSearch("dog")
|
||||
.rerank(await RRFReranker.create())
|
||||
.select(["text"])
|
||||
.limit(5)
|
||||
.toArray();
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
@@ -27,7 +27,9 @@ import {
|
||||
List,
|
||||
Null,
|
||||
RecordBatch,
|
||||
RecordBatchFileReader,
|
||||
RecordBatchFileWriter,
|
||||
RecordBatchReader,
|
||||
RecordBatchStreamWriter,
|
||||
Schema,
|
||||
Struct,
|
||||
@@ -810,6 +812,30 @@ export async function fromDataToBuffer(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read a single record batch from a buffer.
|
||||
*
|
||||
* Returns null if the buffer does not contain a record batch
|
||||
*/
|
||||
export async function fromBufferToRecordBatch(
|
||||
data: Buffer,
|
||||
): Promise<RecordBatch | null> {
|
||||
const iter = await RecordBatchFileReader.readAll(Buffer.from(data)).next()
|
||||
.value;
|
||||
const recordBatch = iter?.next().value;
|
||||
return recordBatch || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a buffer containing a single record batch
|
||||
*/
|
||||
export async function fromRecordBatchToBuffer(
|
||||
batch: RecordBatch,
|
||||
): Promise<Buffer> {
|
||||
const writer = new RecordBatchFileWriter().writeAll([batch]);
|
||||
return Buffer.from(await writer.toUint8Array());
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
|
||||
*
|
||||
|
||||
@@ -62,6 +62,7 @@ export { Index, IndexOptions, IvfPqOptions } from "./indices";
|
||||
export { Table, AddDataOptions, UpdateOptions, OptimizeOptions } from "./table";
|
||||
|
||||
export * as embedding from "./embedding";
|
||||
export * as rerankers from "./rerankers";
|
||||
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI.
|
||||
|
||||
@@ -16,6 +16,8 @@ import {
|
||||
Table as ArrowTable,
|
||||
type IntoVector,
|
||||
RecordBatch,
|
||||
fromBufferToRecordBatch,
|
||||
fromRecordBatchToBuffer,
|
||||
tableFromIPC,
|
||||
} from "./arrow";
|
||||
import { type IvfPqOptions } from "./indices";
|
||||
@@ -25,6 +27,7 @@ import {
|
||||
Table as NativeTable,
|
||||
VectorQuery as NativeVectorQuery,
|
||||
} from "./native";
|
||||
import { Reranker } from "./rerankers";
|
||||
export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
|
||||
private promisedInner?: Promise<NativeBatchIterator>;
|
||||
private inner?: NativeBatchIterator;
|
||||
@@ -542,6 +545,27 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
rerank(reranker: Reranker): VectorQuery {
|
||||
super.doCall((inner) =>
|
||||
inner.rerank({
|
||||
rerankHybrid: async (_, args) => {
|
||||
const vecResults = await fromBufferToRecordBatch(args.vecResults);
|
||||
const ftsResults = await fromBufferToRecordBatch(args.ftsResults);
|
||||
const result = await reranker.rerankHybrid(
|
||||
args.query,
|
||||
vecResults as RecordBatch,
|
||||
ftsResults as RecordBatch,
|
||||
);
|
||||
|
||||
const buffer = fromRecordBatchToBuffer(result);
|
||||
return buffer;
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
/** A builder for LanceDB queries. */
|
||||
|
||||
17
nodejs/lancedb/rerankers/index.ts
Normal file
17
nodejs/lancedb/rerankers/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import { RecordBatch } from "apache-arrow";
|
||||
|
||||
export * from "./rrf";
|
||||
|
||||
// Interface for a reranker. A reranker is used to rerank the results from a
|
||||
// vector and FTS search. This is useful for combining the results from both
|
||||
// search methods.
|
||||
export interface Reranker {
|
||||
rerankHybrid(
|
||||
query: string,
|
||||
vecResults: RecordBatch,
|
||||
ftsResults: RecordBatch,
|
||||
): Promise<RecordBatch>;
|
||||
}
|
||||
40
nodejs/lancedb/rerankers/rrf.ts
Normal file
40
nodejs/lancedb/rerankers/rrf.ts
Normal file
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import { RecordBatch } from "apache-arrow";
|
||||
import { fromBufferToRecordBatch, fromRecordBatchToBuffer } from "../arrow";
|
||||
import { RrfReranker as NativeRRFReranker } from "../native";
|
||||
|
||||
/**
|
||||
* Reranks the results using the Reciprocal Rank Fusion (RRF) algorithm.
|
||||
*
|
||||
* Internally this uses the Rust implementation
|
||||
*/
|
||||
export class RRFReranker {
|
||||
private inner: NativeRRFReranker;
|
||||
|
||||
constructor(inner: NativeRRFReranker) {
|
||||
this.inner = inner;
|
||||
}
|
||||
|
||||
public static async create(k: number = 60) {
|
||||
return new RRFReranker(
|
||||
await NativeRRFReranker.tryNew(new Float32Array([k])),
|
||||
);
|
||||
}
|
||||
|
||||
async rerankHybrid(
|
||||
query: string,
|
||||
vecResults: RecordBatch,
|
||||
ftsResults: RecordBatch,
|
||||
): Promise<RecordBatch> {
|
||||
const buffer = await this.inner.rerankHybrid(
|
||||
query,
|
||||
await fromRecordBatchToBuffer(vecResults),
|
||||
await fromRecordBatchToBuffer(ftsResults),
|
||||
);
|
||||
const recordBatch = await fromBufferToRecordBatch(buffer);
|
||||
|
||||
return recordBatch as RecordBatch;
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ mod iterator;
|
||||
pub mod merge;
|
||||
mod query;
|
||||
pub mod remote;
|
||||
mod rerankers;
|
||||
mod table;
|
||||
mod util;
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use lancedb::index::scalar::FullTextSearchQuery;
|
||||
use lancedb::query::ExecutableQuery;
|
||||
use lancedb::query::Query as LanceDbQuery;
|
||||
@@ -25,6 +27,8 @@ use napi_derive::napi;
|
||||
use crate::error::convert_error;
|
||||
use crate::error::NapiErrorExt;
|
||||
use crate::iterator::RecordBatchIterator;
|
||||
use crate::rerankers::Reranker;
|
||||
use crate::rerankers::RerankerCallbacks;
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
#[napi]
|
||||
@@ -218,6 +222,14 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().with_row_id();
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn rerank(&mut self, callbacks: RerankerCallbacks) {
|
||||
self.inner = self
|
||||
.inner
|
||||
.clone()
|
||||
.rerank(Arc::new(Reranker::new(callbacks)));
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
|
||||
147
nodejs/src/rerankers.rs
Normal file
147
nodejs/src/rerankers.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use arrow_array::RecordBatch;
|
||||
use async_trait::async_trait;
|
||||
use napi::{
|
||||
bindgen_prelude::*,
|
||||
threadsafe_function::{ErrorStrategy, ThreadsafeFunction},
|
||||
};
|
||||
use napi_derive::napi;
|
||||
|
||||
use lancedb::ipc::batches_to_ipc_file;
|
||||
use lancedb::rerankers::Reranker as LanceDBReranker;
|
||||
use lancedb::{error::Error, ipc::ipc_file_to_batches};
|
||||
|
||||
use crate::error::NapiErrorExt;
|
||||
|
||||
/// Reranker implementation that "wraps" a NodeJS Reranker implementation.
|
||||
/// This contains references to the callbacks that can be used to invoke the
|
||||
/// reranking methods on the NodeJS implementation and handles serializing the
|
||||
/// record batches to Arrow IPC buffers.
|
||||
#[napi]
|
||||
pub struct Reranker {
|
||||
/// callback to the Javascript which will call the rerankHybrid method of
|
||||
/// some Reranker implementation
|
||||
rerank_hybrid: ThreadsafeFunction<RerankHybridCallbackArgs, ErrorStrategy::CalleeHandled>,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl Reranker {
|
||||
#[napi]
|
||||
pub fn new(callbacks: RerankerCallbacks) -> Self {
|
||||
let rerank_hybrid = callbacks
|
||||
.rerank_hybrid
|
||||
.create_threadsafe_function(0, move |ctx| Ok(vec![ctx.value]))
|
||||
.unwrap();
|
||||
|
||||
Self { rerank_hybrid }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl lancedb::rerankers::Reranker for Reranker {
|
||||
async fn rerank_hybrid(
|
||||
&self,
|
||||
query: &str,
|
||||
vector_results: RecordBatch,
|
||||
fts_results: RecordBatch,
|
||||
) -> lancedb::error::Result<RecordBatch> {
|
||||
let callback_args = RerankHybridCallbackArgs {
|
||||
query: query.to_string(),
|
||||
vec_results: batches_to_ipc_file(&[vector_results])?,
|
||||
fts_results: batches_to_ipc_file(&[fts_results])?,
|
||||
};
|
||||
let promised_buffer: Promise<Buffer> = self
|
||||
.rerank_hybrid
|
||||
.call_async(Ok(callback_args))
|
||||
.await
|
||||
.map_err(|e| Error::Runtime {
|
||||
message: format!("napi error status={}, reason={}", e.status, e.reason),
|
||||
})?;
|
||||
let buffer = promised_buffer.await.map_err(|e| Error::Runtime {
|
||||
message: format!("napi error status={}, reason={}", e.status, e.reason),
|
||||
})?;
|
||||
let mut reader = ipc_file_to_batches(buffer.to_vec())?;
|
||||
let result = reader.next().ok_or(Error::Runtime {
|
||||
message: "reranker result deserialization failed".to_string(),
|
||||
})??;
|
||||
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Reranker {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str("NodeJSRerankerWrapper")
|
||||
}
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct RerankerCallbacks {
|
||||
pub rerank_hybrid: JsFunction,
|
||||
}
|
||||
|
||||
#[napi(object)]
|
||||
pub struct RerankHybridCallbackArgs {
|
||||
pub query: String,
|
||||
pub vec_results: Vec<u8>,
|
||||
pub fts_results: Vec<u8>,
|
||||
}
|
||||
|
||||
fn buffer_to_record_batch(buffer: Buffer) -> Result<RecordBatch> {
|
||||
let mut reader = ipc_file_to_batches(buffer.to_vec()).default_error()?;
|
||||
reader
|
||||
.next()
|
||||
.ok_or(Error::InvalidInput {
|
||||
message: "expected buffer containing record batch".to_string(),
|
||||
})
|
||||
.default_error()?
|
||||
.map_err(Error::from)
|
||||
.default_error()
|
||||
}
|
||||
|
||||
/// Wrapper around rust RRFReranker
|
||||
#[napi]
|
||||
pub struct RRFReranker {
|
||||
inner: lancedb::rerankers::rrf::RRFReranker,
|
||||
}
|
||||
|
||||
#[napi]
|
||||
impl RRFReranker {
|
||||
#[napi]
|
||||
pub async fn try_new(k: &[f32]) -> Result<Self> {
|
||||
let k = k
|
||||
.first()
|
||||
.copied()
|
||||
.ok_or(Error::InvalidInput {
|
||||
message: "must supply RRF Reranker constructor arg 'k'".to_string(),
|
||||
})
|
||||
.default_error()?;
|
||||
|
||||
Ok(Self {
|
||||
inner: lancedb::rerankers::rrf::RRFReranker::new(k),
|
||||
})
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn rerank_hybrid(
|
||||
&self,
|
||||
query: String,
|
||||
vec_results: Buffer,
|
||||
fts_results: Buffer,
|
||||
) -> Result<Buffer> {
|
||||
let vec_results = buffer_to_record_batch(vec_results)?;
|
||||
let fts_results = buffer_to_record_batch(fts_results)?;
|
||||
|
||||
let result = self
|
||||
.inner
|
||||
.rerank_hybrid(&query, vec_results, fts_results)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result_buff = batches_to_ipc_file(&[result]).default_error()?;
|
||||
|
||||
Ok(Buffer::from(result_buff.as_ref()))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user