Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-26 06:39:57 +00:00)
feat(napi): Issue queries as node SDK (#868)
* Query as a fluent API and `AsyncIterator<RecordBatch>`
* Much more docs
* Add tests for auto-inferring vector search columns with different dimensions
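In practice the new fluent API reads like the sketch below. It is a minimal example only: the import path and the `tbl` argument are assumptions, but every method used (query, filter, select, limit, nearestTo, search, toArrow, async iteration) is added in this diff, and the `"vec"` column name follows the tests.

```typescript
// Sketch only: the import path is an assumption; this diff defines Table in nodejs/vectordb/table.ts.
import { Table } from "vectordb";

async function demo(tbl: Table): Promise<void> {
  // SQL-style scan: filter + projection, consumed as an AsyncIterator<RecordBatch>.
  for await (const batch of tbl.query().filter("id > 1").select(["id"]).limit(20)) {
    console.log(batch.numRows);
  }

  // Vector search: the vector column is inferred from the query dimension,
  // or pinned explicitly via search(vector, column). "vec" is a hypothetical column.
  const hits = await tbl.search([1.0, 2.0, 0.5, 6.7], "vec").limit(10).toArrow();
  console.log(hits.numRows);
}
```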
@@ -10,14 +10,15 @@ crate-type = ["cdylib"]

[dependencies]
arrow-ipc.workspace = true
futures.workspace = true
lance-linalg.workspace = true
lance.workspace = true
vectordb = { path = "../rust/vectordb" }
napi = { version = "2.14", default-features = false, features = [
  "napi7",
  "async"
] }
napi-derive = "2.14"
vectordb = { path = "../rust/vectordb" }
lance.workspace = true
lance-linalg.workspace = true

[build-dependencies]
napi-build = "2.1"
@@ -53,6 +53,16 @@ describe("Test creating index", () => {
    const indexDir = path.join(tmpDir, "test.lance", "_indices");
    expect(fs.readdirSync(indexDir)).toHaveLength(1);
    // TODO: check index type.

    // Search without specifying the column
    let query_vector = data.toArray()[5].vec.toJSON();
    let rst = await tbl.query().nearestTo(query_vector).limit(2).toArrow();
    expect(rst.numRows).toBe(2);

    // Search with specifying the column
    let rst2 = await tbl.search(query_vector, "vec").limit(2).toArrow();
    expect(rst2.numRows).toBe(2);
    expect(rst.toString()).toEqual(rst2.toString());
  });

  test("no vector column available", async () => {
@@ -71,6 +81,80 @@
    await tbl.createIndex("val").build();
    const indexDir = path.join(tmpDir, "no_vec.lance", "_indices");
    expect(fs.readdirSync(indexDir)).toHaveLength(1);

    for await (const r of tbl.query().filter("id > 1").select(["id"])) {
      expect(r.numRows).toBe(1);
    }
  });

  test("two columns with different dimensions", async () => {
    const db = await connect(tmpDir);
    const schema = new Schema([
      new Field("id", new Int32(), true),
      new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))),
      new Field(
        "vec2",
        new FixedSizeList(64, new Field("item", new Float32()))
      ),
    ]);
    const tbl = await db.createTable(
      "two_vectors",
      makeArrowTable(
        Array(300)
          .fill(1)
          .map((_, i) => ({
            id: i,
            vec: Array(32)
              .fill(1)
              .map(() => Math.random()),
            vec2: Array(64) // different dimension
              .fill(1)
              .map(() => Math.random()),
          })),
        { schema }
      )
    );

    // Only build index over "vec"
    await expect(tbl.createIndex().build()).rejects.toThrow(
      /.*More than one vector columns found.*/
    );
    tbl
      .createIndex("vec")
      .ivf_pq({ num_partitions: 2, num_sub_vectors: 2 })
      .build();

    const rst = await tbl
      .query()
      .nearestTo(
        Array(32)
          .fill(1)
          .map(() => Math.random())
      )
      .limit(2)
      .toArrow();
    expect(rst.numRows).toBe(2);

    // Search with specifying the column
    await expect(
      tbl
        .search(
          Array(64)
            .fill(1)
            .map(() => Math.random()),
          "vec"
        )
        .limit(2)
        .toArrow()
    ).rejects.toThrow(/.*does not match the dimension.*/);

    const query64 = Array(64)
      .fill(1)
      .map(() => Math.random());
    const rst64_1 = await tbl.query().nearestTo(query64).limit(2).toArrow();
    const rst64_2 = await tbl.search(query64, "vec2").limit(2).toArrow();
    expect(rst64_1.toString()).toEqual(rst64_2.toString());
    expect(rst64_1.numRows).toBe(2);
  });

  test("create scalar index", async () => {
@@ -91,7 +91,6 @@ impl IndexBuilder {

    #[napi]
    pub async fn build(&self) -> napi::Result<()> {
        println!("nodejs::index.rs : build");
        self.inner
            .build()
            .await
nodejs/src/iterator.rs (new file, 47 lines)
@@ -0,0 +1,47 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use futures::StreamExt;
use lance::io::RecordBatchStream;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use vectordb::ipc::batches_to_ipc_file;

/** Typescript-style Async Iterator over RecordBatches */
#[napi]
pub struct RecordBatchIterator {
    inner: Box<dyn RecordBatchStream + Unpin>,
}

#[napi]
impl RecordBatchIterator {
    pub(crate) fn new(inner: Box<dyn RecordBatchStream + Unpin>) -> Self {
        Self { inner }
    }

    #[napi]
    pub async unsafe fn next(&mut self) -> napi::Result<Option<Buffer>> {
        if let Some(rst) = self.inner.next().await {
            let batch = rst.map_err(|e| {
                napi::Error::from_reason(format!("Failed to get next batch from stream: {}", e))
            })?;
            batches_to_ipc_file(&[batch])
                .map_err(|e| napi::Error::from_reason(format!("Failed to write IPC file: {}", e)))
                .map(|buf| Some(Buffer::from(buf)))
        } else {
            // We are done with the stream.
            Ok(None)
        }
    }
}
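Each Buffer handed back by `next()` is a complete single-batch Arrow IPC file. A rough TypeScript sketch of decoding one batch, mirroring the wrapper added later in this diff (`tableFromIPC` is from apache-arrow; the `./native` import is the generated binding used by query.ts):

```typescript
import { tableFromIPC, RecordBatch } from "apache-arrow";
import { RecordBatchIterator as NativeBatchIterator } from "./native";

// Sketch: pull a single batch from the native stream, or undefined once it is exhausted.
async function nextBatch(it: NativeBatchIterator): Promise<RecordBatch | undefined> {
  const buf = await it.next();       // Buffer | null: a one-batch Arrow IPC file
  if (buf == null) return undefined; // stream finished
  const tbl = tableFromIPC(buf);     // decode the IPC payload
  return tbl.batches[0];             // exactly one RecordBatch per buffer
}
```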
@@ -17,6 +17,7 @@ use napi_derive::*;

mod connection;
mod index;
mod iterator;
mod query;
mod table;
@@ -16,7 +16,7 @@ use napi::bindgen_prelude::*;
use napi_derive::napi;
use vectordb::query::Query as LanceDBQuery;

use crate::table::Table;
use crate::{iterator::RecordBatchIterator, table::Table};

#[napi]
pub struct Query {
@@ -32,17 +32,50 @@ impl Query {
    }

    #[napi]
    pub fn vector(&mut self, vector: Float32Array) {
        let inn = self.inner.clone().nearest_to(&vector);
        self.inner = inn;
    pub fn column(&mut self, column: String) {
        self.inner = self.inner.clone().column(&column);
    }

    #[napi]
    pub fn to_arrow(&self) -> napi::Result<()> {
        // let buf = self.inner.to_arrow().map_err(|e| {
        //     napi::Error::from_reason(format!("Failed to convert query to arrow: {}", e))
        // })?;
        // Ok(buf)
        todo!()
    pub fn filter(&mut self, filter: String) {
        self.inner = self.inner.clone().filter(filter);
    }

    #[napi]
    pub fn select(&mut self, columns: Vec<String>) {
        self.inner = self.inner.clone().select(&columns);
    }

    #[napi]
    pub fn limit(&mut self, limit: u32) {
        self.inner = self.inner.clone().limit(limit as usize);
    }

    #[napi]
    pub fn prefilter(&mut self, prefilter: bool) {
        self.inner = self.inner.clone().prefilter(prefilter);
    }

    #[napi]
    pub fn nearest_to(&mut self, vector: Float32Array) {
        self.inner = self.inner.clone().nearest_to(&vector);
    }

    #[napi]
    pub fn refine_factor(&mut self, refine_factor: u32) {
        self.inner = self.inner.clone().refine_factor(refine_factor);
    }

    #[napi]
    pub fn nprobes(&mut self, nprobe: u32) {
        self.inner = self.inner.clone().nprobes(nprobe as usize);
    }

    #[napi]
    pub async fn execute_stream(&self) -> napi::Result<RecordBatchIterator> {
        let inner_stream = self.inner.execute_stream().await.map_err(|e| {
            napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
        })?;
        Ok(RecordBatchIterator::new(Box::new(inner_stream)))
    }
}
nodejs/vectordb/native.d.ts (vendored, 15 lines changed)
@@ -54,9 +54,20 @@ export class IndexBuilder {
  scalar(): void
  build(): Promise<void>
}
/** Typescript-style Async Iterator over RecordBatches */
export class RecordBatchIterator {
  next(): Promise<Buffer | null>
}
export class Query {
  vector(vector: Float32Array): void
  toArrow(): void
  column(column: string): void
  filter(filter: string): void
  select(columns: Array<string>): void
  limit(limit: number): void
  prefilter(prefilter: boolean): void
  nearestTo(vector: Float32Array): void
  refineFactor(refineFactor: number): void
  nprobes(nprobe: number): void
  executeStream(): Promise<RecordBatchIterator>
}
export class Table {
  /** Return Schema as empty Arrow IPC file. */
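These generated declarations are what the TypeScript wrapper drives; the builder methods on the native Query mutate in place and return void. A small sketch of calling the binding directly (normally the `Query` wrapper does this for you), assuming only what the declarations above expose:

```typescript
import { tableFromIPC } from "apache-arrow";
import { Table as NativeTable } from "./native";

// Sketch: pull filtered batches straight from the native binding.
async function rawScan(tbl: NativeTable): Promise<void> {
  const q = tbl.query();   // native Query builder
  q.filter("id > 1");
  q.select(["id"]);
  q.limit(20);
  const stream = await q.executeStream(); // native RecordBatchIterator
  for (let buf = await stream.next(); buf != null; buf = await stream.next()) {
    console.log(tableFromIPC(buf).numRows); // each Buffer is a one-batch IPC file
  }
}
```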
@@ -295,12 +295,13 @@ if (!nativeBinding) {
  throw new Error(`Failed to load native binding`)
}

const { Connection, IndexType, MetricType, IndexBuilder, Query, Table, WriteMode, connect } = nativeBinding
const { Connection, IndexType, MetricType, IndexBuilder, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding

module.exports.Connection = Connection
module.exports.IndexType = IndexType
module.exports.MetricType = MetricType
module.exports.IndexBuilder = IndexBuilder
module.exports.RecordBatchIterator = RecordBatchIterator
module.exports.Query = Query
module.exports.Table = Table
module.exports.WriteMode = WriteMode
@@ -12,46 +12,73 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import { RecordBatch } from "apache-arrow";
import { Table } from "./table";
import { RecordBatch, tableFromIPC, Table as ArrowTable } from "apache-arrow";
import {
  RecordBatchIterator as NativeBatchIterator,
  Query as NativeQuery,
  Table as NativeTable,
} from "./native";

// TODO: re-enable eslint once we have a real implementation
/* eslint-disable */
class RecordBatchIterator implements AsyncIterator<RecordBatch> {
  next(
    ...args: [] | [undefined]
  ): Promise<IteratorResult<RecordBatch<any>, any>> {
    throw new Error("Method not implemented.");
  private promised_inner?: Promise<NativeBatchIterator>;
  private inner?: NativeBatchIterator;

  constructor(
    inner?: NativeBatchIterator,
    promise?: Promise<NativeBatchIterator>
  ) {
    // TODO: check promise reliably so we don't need to pass two arguments.
    this.inner = inner;
    this.promised_inner = promise;
  }
  return?(value?: any): Promise<IteratorResult<RecordBatch<any>, any>> {
    throw new Error("Method not implemented.");
  }
  throw?(e?: any): Promise<IteratorResult<RecordBatch<any>, any>> {
    throw new Error("Method not implemented.");

  async next(): Promise<IteratorResult<RecordBatch<any>, any>> {
    if (this.inner === undefined) {
      this.inner = await this.promised_inner;
    }
    if (this.inner === undefined) {
      throw new Error("Invalid iterator state");
    }
    const n = await this.inner.next();
    if (n == null) {
      return Promise.resolve({ done: true, value: null });
    }
    const tbl = tableFromIPC(n);
    if (tbl.batches.length != 1) {
      throw new Error("Expected only one batch");
    }
    return Promise.resolve({ done: false, value: tbl.batches[0] });
  }
}
/* eslint-enable */

/** Query executor */
export class Query implements AsyncIterable<RecordBatch> {
  private readonly tbl: Table;
  private _filter?: string;
  private _limit?: number;
  private readonly inner: NativeQuery;

  // Vector search
  private _vector?: Float32Array;
  private _nprobes?: number;
  private _refine_factor?: number = 1;
  constructor(tbl: NativeTable) {
    this.inner = tbl.query();
  }

  constructor(tbl: Table) {
    this.tbl = tbl;
  /** Set the column to run the query on. */
  column(column: string): Query {
    this.inner.column(column);
    return this;
  }

  /** Set the filter predicate, only returns the results that satisfy the filter.
   *
   */
  filter(predicate: string): Query {
    this._filter = predicate;
    this.inner.filter(predicate);
    return this;
  }

  /**
   * Select the columns to return. If not set, all columns are returned.
   */
  select(columns: string[]): Query {
    this.inner.select(columns);
    return this;
  }
@@ -59,35 +86,67 @@ export class Query implements AsyncIterable<RecordBatch> {
   * Set the limit of rows to return.
   */
  limit(limit: number): Query {
    this._limit = limit;
    this.inner.limit(limit);
    return this;
  }

  prefilter(prefilter: boolean): Query {
    this.inner.prefilter(prefilter);
    return this;
  }

  /**
   * Set the query vector.
   */
  vector(vector: number[]): Query {
    this._vector = Float32Array.from(vector);
  nearestTo(vector: number[]): Query {
    this.inner.nearestTo(Float32Array.from(vector));
    return this;
  }

  /**
   * Set the number of probes to use for the query.
   * Set the number of IVF partitions to use for the query.
   */
  nprobes(nprobes: number): Query {
    this._nprobes = nprobes;
    this.inner.nprobes(nprobes);
    return this;
  }

  /**
   * Set the refine factor for the query.
   */
  refine_factor(refine_factor: number): Query {
    this._refine_factor = refine_factor;
  refineFactor(refine_factor: number): Query {
    this.inner.refineFactor(refine_factor);
    return this;
  }

  [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
    throw new RecordBatchIterator();
  /**
   * Execute the query and return the results as an AsyncIterator.
   */
  async executeStream(): Promise<RecordBatchIterator> {
    const inner = await this.inner.executeStream();
    return new RecordBatchIterator(inner);
  }

  /** Collect the results as an Arrow Table. */
  async toArrow(): Promise<ArrowTable> {
    const batches = [];
    for await (const batch of this) {
      batches.push(batch);
    }
    return new ArrowTable(batches);
  }

  /** Returns a JSON Array of all results.
   *
   */
  async toArray(): Promise<any[]> {
    const tbl = await this.toArrow();
    // eslint-disable-next-line @typescript-eslint/no-unsafe-return
    return tbl.toArray();
  }

  [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
    const promise = this.inner.executeStream();
    return new RecordBatchIterator(undefined, promise);
  }
}
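Besides for-await iteration, the wrapper above offers three ways to consume results. A quick sketch, assuming `q` is a Query obtained from the methods defined in this diff:

```typescript
// q: Query, e.g. built as table.query().filter("id > 1").limit(20)
const iter = await q.executeStream(); // stream RecordBatches incrementally
const tbl = await q.toArrow();        // collect everything into an Arrow Table
const rows = await q.toArray();       // collect everything as plain JS objects
```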
@@ -95,10 +95,58 @@ export class Table {
    return builder;
  }

  search(vector?: number[]): Query {
    const q = new Query(this);
    if (vector !== undefined) {
      q.vector(vector);
  /**
   * Create a generic {@link Query} Builder.
   *
   * When appropriate, various indices and statistics based pruning will be used to
   * accelerate the query.
   *
   * @example
   *
   * ### Run a SQL-style query
   * ```typescript
   * for await (const batch of table.query()
   *   .filter("id > 1").select(["id"]).limit(20)) {
   *   console.log(batch);
   * }
   * ```
   *
   * ### Run Top-10 vector similarity search
   * ```typescript
   * for await (const batch of table.query()
   *   .nearestTo([1, 2, 3])
   *   .refineFactor(5).nprobes(10)
   *   .limit(10)) {
   *   console.log(batch);
   * }
   * ```
   *
   * ### Scan the full dataset
   * ```typescript
   * for await (const batch of table.query()) {
   *   console.log(batch);
   * }
   * ```
   *
   * ### Return the full dataset as Arrow Table
   * ```typescript
   * let arrowTbl = await table.query().nearestTo([1.0, 2.0, 0.5, 6.7]).toArrow();
   * ```
   *
   * @returns {@link Query}
   */
  query(): Query {
    return new Query(this.inner);
  }

  /** Search the table with a given query vector.
   *
   * This is a convenience method for preparing an ANN {@link Query}.
   */
  search(vector: number[], column?: string): Query {
    const q = this.query();
    q.nearestTo(vector);
    if (column !== undefined) {
      q.column(column);
    }
    return q;
  }
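A short usage note for the convenience `search()` above, mirroring the tests earlier in this diff: the vector column is either inferred from the query dimension or pinned explicitly. The `tbl` variable and the 32-dim `"vec"` column are assumptions taken from those tests.

```typescript
// Assumes a table `tbl` with a 32-dim "vec" column, as in the tests above.
const q32 = Array.from({ length: 32 }, () => Math.random());

// Inferred: query() picks the vector column whose dimension matches q32.
const inferred = await tbl.query().nearestTo(q32).limit(2).toArrow();

// Explicit: search(vector, column) prepares query().nearestTo(vector).column(column).
const explicit = await tbl.search(q32, "vec").limit(2).toArrow();
```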
@@ -16,10 +16,10 @@

use std::io::Cursor;

use arrow_array::RecordBatchReader;
use arrow_ipc::reader::StreamReader;
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_ipc::{reader::StreamReader, writer::FileWriter};

use crate::Result;
use crate::{Error, Result};

/// Convert a Arrow IPC file to a batch reader
pub fn ipc_file_to_batches(buf: Vec<u8>) -> Result<impl RecordBatchReader> {
@@ -28,6 +28,22 @@ pub fn ipc_file_to_batches(buf: Vec<u8>) -> Result<impl RecordBatchReader> {
    Ok(reader)
}

/// Convert record batches to Arrow IPC file
pub fn batches_to_ipc_file(batches: &[RecordBatch]) -> Result<Vec<u8>> {
    if batches.is_empty() {
        return Err(Error::Store {
            message: "No batches to write".to_string(),
        });
    }
    let schema = batches[0].schema();
    let mut writer = FileWriter::try_new(vec![], &schema)?;
    for batch in batches {
        writer.write(batch)?;
    }
    writer.finish()?;
    Ok(writer.into_inner()?)
}

#[cfg(test)]
mod tests {
@@ -22,6 +22,7 @@ use lance_linalg::distance::MetricType;

use crate::error::Result;
use crate::utils::default_vector_column;
use crate::Error;

const DEFAULT_TOP_K: usize = 10;

@@ -93,6 +94,19 @@ impl Query {
            let arrow_schema = Schema::from(self.dataset.schema());
            default_vector_column(&arrow_schema, Some(query.len() as i32))?
        };
        let field = self.dataset.schema().field(&column).ok_or(Error::Store {
            message: format!("Column {} not found in dataset schema", column),
        })?;
        if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query.len() as i32)
        {
            return Err(Error::Store {
                message: format!(
                    "Vector column '{}' does not match the dimension of the query vector: dim={}",
                    column,
                    query.len(),
                ),
            });
        }
        scanner.nearest(&column, query, self.limit.unwrap_or(DEFAULT_TOP_K))?;
    } else {
        // If there is no vector query, it's ok to not have a limit