diff --git a/nodejs/Cargo.toml b/nodejs/Cargo.toml index f014145b..fe7b1876 100644 --- a/nodejs/Cargo.toml +++ b/nodejs/Cargo.toml @@ -10,14 +10,15 @@ crate-type = ["cdylib"] [dependencies] arrow-ipc.workspace = true +futures.workspace = true +lance-linalg.workspace = true +lance.workspace = true +vectordb = { path = "../rust/vectordb" } napi = { version = "2.14", default-features = false, features = [ "napi7", "async" ] } napi-derive = "2.14" -vectordb = { path = "../rust/vectordb" } -lance.workspace = true -lance-linalg.workspace = true [build-dependencies] napi-build = "2.1" diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index caf3bfaf..a8ccf989 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -53,6 +53,16 @@ describe("Test creating index", () => { const indexDir = path.join(tmpDir, "test.lance", "_indices"); expect(fs.readdirSync(indexDir)).toHaveLength(1); // TODO: check index type. + + // Search without specifying the column + let query_vector = data.toArray()[5].vec.toJSON(); + let rst = await tbl.query().nearestTo(query_vector).limit(2).toArrow(); + expect(rst.numRows).toBe(2); + + // Search with specifying the column + let rst2 = await tbl.search(query_vector, "vec").limit(2).toArrow(); + expect(rst2.numRows).toBe(2); + expect(rst.toString()).toEqual(rst2.toString()); }); test("no vector column available", async () => { @@ -71,6 +81,80 @@ describe("Test creating index", () => { await tbl.createIndex("val").build(); const indexDir = path.join(tmpDir, "no_vec.lance", "_indices"); expect(fs.readdirSync(indexDir)).toHaveLength(1); + + for await (const r of tbl.query().filter("id > 1").select(["id"])) { + expect(r.numRows).toBe(1); + } + }); + + test("two columns with different dimensions", async () => { + const db = await connect(tmpDir); + const schema = new Schema([ + new Field("id", new Int32(), true), + new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))), + new Field( + "vec2", + new FixedSizeList(64, new Field("item", new Float32())) + ), + ]); + const tbl = await db.createTable( + "two_vectors", + makeArrowTable( + Array(300) + .fill(1) + .map((_, i) => ({ + id: i, + vec: Array(32) + .fill(1) + .map(() => Math.random()), + vec2: Array(64) // different dimension + .fill(1) + .map(() => Math.random()), + })), + { schema } + ) + ); + + // Only build index over v1 + await expect(tbl.createIndex().build()).rejects.toThrow( + /.*More than one vector columns found.*/ + ); + tbl + .createIndex("vec") + .ivf_pq({ num_partitions: 2, num_sub_vectors: 2 }) + .build(); + + const rst = await tbl + .query() + .nearestTo( + Array(32) + .fill(1) + .map(() => Math.random()) + ) + .limit(2) + .toArrow(); + expect(rst.numRows).toBe(2); + + // Search with specifying the column + await expect( + tbl + .search( + Array(64) + .fill(1) + .map(() => Math.random()), + "vec" + ) + .limit(2) + .toArrow() + ).rejects.toThrow(/.*does not match the dimension.*/); + + const query64 = Array(64) + .fill(1) + .map(() => Math.random()); + const rst64_1 = await tbl.query().nearestTo(query64).limit(2).toArrow(); + const rst64_2 = await tbl.search(query64, "vec2").limit(2).toArrow(); + expect(rst64_1.toString()).toEqual(rst64_2.toString()); + expect(rst64_1.numRows).toBe(2); }); test("create scalar index", async () => { diff --git a/nodejs/src/index.rs b/nodejs/src/index.rs index 5dd60e4d..c8b06257 100644 --- a/nodejs/src/index.rs +++ b/nodejs/src/index.rs @@ -91,7 +91,6 @@ impl IndexBuilder { #[napi] pub async fn build(&self) -> napi::Result<()> { - println!("nodejs::index.rs : build"); self.inner .build() .await diff --git a/nodejs/src/iterator.rs b/nodejs/src/iterator.rs new file mode 100644 index 00000000..50b3b110 --- /dev/null +++ b/nodejs/src/iterator.rs @@ -0,0 +1,47 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures::StreamExt; +use lance::io::RecordBatchStream; +use napi::bindgen_prelude::*; +use napi_derive::napi; +use vectordb::ipc::batches_to_ipc_file; + +/** Typescript-style Async Iterator over RecordBatches */ +#[napi] +pub struct RecordBatchIterator { + inner: Box, +} + +#[napi] +impl RecordBatchIterator { + pub(crate) fn new(inner: Box) -> Self { + Self { inner } + } + + #[napi] + pub async unsafe fn next(&mut self) -> napi::Result> { + if let Some(rst) = self.inner.next().await { + let batch = rst.map_err(|e| { + napi::Error::from_reason(format!("Failed to get next batch from stream: {}", e)) + })?; + batches_to_ipc_file(&[batch]) + .map_err(|e| napi::Error::from_reason(format!("Failed to write IPC file: {}", e))) + .map(|buf| Some(Buffer::from(buf))) + } else { + // We are done with the stream. + Ok(None) + } + } +} diff --git a/nodejs/src/lib.rs b/nodejs/src/lib.rs index c0f6e953..463ec4ce 100644 --- a/nodejs/src/lib.rs +++ b/nodejs/src/lib.rs @@ -17,6 +17,7 @@ use napi_derive::*; mod connection; mod index; +mod iterator; mod query; mod table; diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index f96df3ec..5bea8714 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -16,7 +16,7 @@ use napi::bindgen_prelude::*; use napi_derive::napi; use vectordb::query::Query as LanceDBQuery; -use crate::table::Table; +use crate::{iterator::RecordBatchIterator, table::Table}; #[napi] pub struct Query { @@ -32,17 +32,50 @@ impl Query { } #[napi] - pub fn vector(&mut self, vector: Float32Array) { - let inn = self.inner.clone().nearest_to(&vector); - self.inner = inn; + pub fn column(&mut self, column: String) { + self.inner = self.inner.clone().column(&column); } #[napi] - pub fn to_arrow(&self) -> napi::Result<()> { - // let buf = self.inner.to_arrow().map_err(|e| { - // napi::Error::from_reason(format!("Failed to convert query to arrow: {}", e)) - // })?; - // Ok(buf) - todo!() + pub fn filter(&mut self, filter: String) { + self.inner = self.inner.clone().filter(filter); + } + + #[napi] + pub fn select(&mut self, columns: Vec) { + self.inner = self.inner.clone().select(&columns); + } + + #[napi] + pub fn limit(&mut self, limit: u32) { + self.inner = self.inner.clone().limit(limit as usize); + } + + #[napi] + pub fn prefilter(&mut self, prefilter: bool) { + self.inner = self.inner.clone().prefilter(prefilter); + } + + #[napi] + pub fn nearest_to(&mut self, vector: Float32Array) { + self.inner = self.inner.clone().nearest_to(&vector); + } + + #[napi] + pub fn refine_factor(&mut self, refine_factor: u32) { + self.inner = self.inner.clone().refine_factor(refine_factor); + } + + #[napi] + pub fn nprobes(&mut self, nprobe: u32) { + self.inner = self.inner.clone().nprobes(nprobe as usize); + } + + #[napi] + pub async fn execute_stream(&self) -> napi::Result { + let inner_stream = self.inner.execute_stream().await.map_err(|e| { + napi::Error::from_reason(format!("Failed to execute query stream: {}", e)) + })?; + Ok(RecordBatchIterator::new(Box::new(inner_stream))) } } diff --git a/nodejs/vectordb/native.d.ts b/nodejs/vectordb/native.d.ts index e5574f33..192f39a6 100644 --- a/nodejs/vectordb/native.d.ts +++ b/nodejs/vectordb/native.d.ts @@ -54,9 +54,20 @@ export class IndexBuilder { scalar(): void build(): Promise } +/** Typescript-style Async Iterator over RecordBatches */ +export class RecordBatchIterator { + next(): Promise +} export class Query { - vector(vector: Float32Array): void - toArrow(): void + column(column: string): void + filter(filter: string): void + select(columns: Array): void + limit(limit: number): void + prefilter(prefilter: boolean): void + nearestTo(vector: Float32Array): void + refineFactor(refineFactor: number): void + nprobes(nprobe: number): void + executeStream(): Promise } export class Table { /** Return Schema as empty Arrow IPC file. */ diff --git a/nodejs/vectordb/native.js b/nodejs/vectordb/native.js index 3a2ed038..4abf5eb5 100644 --- a/nodejs/vectordb/native.js +++ b/nodejs/vectordb/native.js @@ -295,12 +295,13 @@ if (!nativeBinding) { throw new Error(`Failed to load native binding`) } -const { Connection, IndexType, MetricType, IndexBuilder, Query, Table, WriteMode, connect } = nativeBinding +const { Connection, IndexType, MetricType, IndexBuilder, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding module.exports.Connection = Connection module.exports.IndexType = IndexType module.exports.MetricType = MetricType module.exports.IndexBuilder = IndexBuilder +module.exports.RecordBatchIterator = RecordBatchIterator module.exports.Query = Query module.exports.Table = Table module.exports.WriteMode = WriteMode diff --git a/nodejs/vectordb/query.ts b/nodejs/vectordb/query.ts index 18e2e298..1a662e1c 100644 --- a/nodejs/vectordb/query.ts +++ b/nodejs/vectordb/query.ts @@ -12,46 +12,73 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { RecordBatch } from "apache-arrow"; -import { Table } from "./table"; +import { RecordBatch, tableFromIPC, Table as ArrowTable } from "apache-arrow"; +import { + RecordBatchIterator as NativeBatchIterator, + Query as NativeQuery, + Table as NativeTable, +} from "./native"; -// TODO: re-eanble eslint once we have a real implementation -/* eslint-disable */ class RecordBatchIterator implements AsyncIterator { - next( - ...args: [] | [undefined] - ): Promise, any>> { - throw new Error("Method not implemented."); + private promised_inner?: Promise; + private inner?: NativeBatchIterator; + + constructor( + inner?: NativeBatchIterator, + promise?: Promise + ) { + // TODO: check promise reliably so we dont need to pass two arguments. + this.inner = inner; + this.promised_inner = promise; } - return?(value?: any): Promise, any>> { - throw new Error("Method not implemented."); - } - throw?(e?: any): Promise, any>> { - throw new Error("Method not implemented."); + + async next(): Promise, any>> { + if (this.inner === undefined) { + this.inner = await this.promised_inner; + } + if (this.inner === undefined) { + throw new Error("Invalid iterator state state"); + } + const n = await this.inner.next(); + if (n == null) { + return Promise.resolve({ done: true, value: null }); + } + const tbl = tableFromIPC(n); + if (tbl.batches.length != 1) { + throw new Error("Expected only one batch"); + } + return Promise.resolve({ done: false, value: tbl.batches[0] }); } } /* eslint-enable */ /** Query executor */ export class Query implements AsyncIterable { - private readonly tbl: Table; - private _filter?: string; - private _limit?: number; + private readonly inner: NativeQuery; - // Vector search - private _vector?: Float32Array; - private _nprobes?: number; - private _refine_factor?: number = 1; + constructor(tbl: NativeTable) { + this.inner = tbl.query(); + } - constructor(tbl: Table) { - this.tbl = tbl; + /** Set the column to run query. */ + column(column: string): Query { + this.inner.column(column); + return this; } /** Set the filter predicate, only returns the results that satisfy the filter. * */ filter(predicate: string): Query { - this._filter = predicate; + this.inner.filter(predicate); + return this; + } + + /** + * Select the columns to return. If not set, all columns are returned. + */ + select(columns: string[]): Query { + this.inner.select(columns); return this; } @@ -59,35 +86,67 @@ export class Query implements AsyncIterable { * Set the limit of rows to return. */ limit(limit: number): Query { - this._limit = limit; + this.inner.limit(limit); + return this; + } + + prefilter(prefilter: boolean): Query { + this.inner.prefilter(prefilter); return this; } /** * Set the query vector. */ - vector(vector: number[]): Query { - this._vector = Float32Array.from(vector); + nearestTo(vector: number[]): Query { + this.inner.nearestTo(Float32Array.from(vector)); return this; } /** - * Set the number of probes to use for the query. + * Set the number of IVF partitions to use for the query. */ nprobes(nprobes: number): Query { - this._nprobes = nprobes; + this.inner.nprobes(nprobes); return this; } /** * Set the refine factor for the query. */ - refine_factor(refine_factor: number): Query { - this._refine_factor = refine_factor; + refineFactor(refine_factor: number): Query { + this.inner.refineFactor(refine_factor); return this; } - [Symbol.asyncIterator](): AsyncIterator, any, undefined> { - throw new RecordBatchIterator(); + /** + * Execute the query and return the results as an AsyncIterator. + */ + async executeStream(): Promise { + const inner = await this.inner.executeStream(); + return new RecordBatchIterator(inner); + } + + /** Collect the results as an Arrow Table. */ + async toArrow(): Promise { + const batches = []; + for await (const batch of this) { + batches.push(batch); + } + return new ArrowTable(batches); + } + + /** Returns a JSON Array of All results. + * + */ + async toArray(): Promise { + const tbl = await this.toArrow(); + // eslint-disable-next-line @typescript-eslint/no-unsafe-return + return tbl.toArray(); + } + + [Symbol.asyncIterator](): AsyncIterator> { + const promise = this.inner.executeStream(); + return new RecordBatchIterator(undefined, promise); } } diff --git a/nodejs/vectordb/table.ts b/nodejs/vectordb/table.ts index 05c88028..ec3d31b9 100644 --- a/nodejs/vectordb/table.ts +++ b/nodejs/vectordb/table.ts @@ -95,10 +95,58 @@ export class Table { return builder; } - search(vector?: number[]): Query { - const q = new Query(this); - if (vector !== undefined) { - q.vector(vector); + /** + * Create a generic {@link Query} Builder. + * + * When appropriate, various indices and statistics based pruning will be used to + * accelerate the query. + * + * @example + * + * ### Run a SQL-style query + * ```typescript + * for await (const batch of table.query() + * .filter("id > 1").select(["id"]).limit(20)) { + * console.log(batch); + * } + * ``` + * + * ### Run Top-10 vector similarity search + * ```typescript + * for await (const batch of table.query() + * .nearestTo([1, 2, 3]) + * .refineFactor(5).nprobe(10) + * .limit(10)) { + * console.log(batch); + * } + *``` + * + * ### Scan the full dataset + * ```typescript + * for await (const batch of table.query()) { + * console.log(batch); + * } + * + * ### Return the full dataset as Arrow Table + * ```typescript + * let arrowTbl = await table.query().nearestTo([1.0, 2.0, 0.5, 6.7]).toArrow(); + * ``` + * + * @returns {@link Query} + */ + query(): Query { + return new Query(this.inner); + } + + /** Search the table with a given query vector. + * + * This is a convenience method for preparing an ANN {@link Query}. + */ + search(vector: number[], column?: string): Query { + const q = this.query(); + q.nearestTo(vector); + if (column !== undefined) { + q.column(column); } return q; } diff --git a/rust/vectordb/src/ipc.rs b/rust/vectordb/src/ipc.rs index 70cf324d..54a17a8a 100644 --- a/rust/vectordb/src/ipc.rs +++ b/rust/vectordb/src/ipc.rs @@ -16,10 +16,10 @@ use std::io::Cursor; -use arrow_array::RecordBatchReader; -use arrow_ipc::reader::StreamReader; +use arrow_array::{RecordBatch, RecordBatchReader}; +use arrow_ipc::{reader::StreamReader, writer::FileWriter}; -use crate::Result; +use crate::{Error, Result}; /// Convert a Arrow IPC file to a batch reader pub fn ipc_file_to_batches(buf: Vec) -> Result { @@ -28,6 +28,22 @@ pub fn ipc_file_to_batches(buf: Vec) -> Result { Ok(reader) } +/// Convert record batches to Arrow IPC file +pub fn batches_to_ipc_file(batches: &[RecordBatch]) -> Result> { + if batches.is_empty() { + return Err(Error::Store { + message: "No batches to write".to_string(), + }); + } + let schema = batches[0].schema(); + let mut writer = FileWriter::try_new(vec![], &schema)?; + for batch in batches { + writer.write(batch)?; + } + writer.finish()?; + Ok(writer.into_inner()?) +} + #[cfg(test)] mod tests { diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs index e7f87fdd..7ce900b7 100644 --- a/rust/vectordb/src/query.rs +++ b/rust/vectordb/src/query.rs @@ -22,6 +22,7 @@ use lance_linalg::distance::MetricType; use crate::error::Result; use crate::utils::default_vector_column; +use crate::Error; const DEFAULT_TOP_K: usize = 10; @@ -93,6 +94,19 @@ impl Query { let arrow_schema = Schema::from(self.dataset.schema()); default_vector_column(&arrow_schema, Some(query.len() as i32))? }; + let field = self.dataset.schema().field(&column).ok_or(Error::Store { + message: format!("Column {} not found in dataset schema", column), + })?; + if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query.len() as i32) + { + return Err(Error::Store { + message: format!( + "Vector column '{}' does not match the dimension of the query vector: dim={}", + column, + query.len(), + ), + }); + } scanner.nearest(&column, query, self.limit.unwrap_or(DEFAULT_TOP_K))?; } else { // If there is no vector query, it's ok to not have a limit