From 1ba19d728e6bceb7148f27ad573583de3a96786a Mon Sep 17 00:00:00 2001 From: Vedant Madane <6527493+VedantMadane@users.noreply.github.com> Date: Mon, 30 Mar 2026 23:45:35 +0530 Subject: [PATCH] feat(node): support Float16, Float64, and Uint8 vector queries (#3193) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #2716 ## Summary Add support for querying with Float16Array, Float64Array, and Uint8Array vectors in the Node.js SDK, eliminating precision loss from the previous \Float32Array.from()\ conversion. ## Implementation Follows @wjones127's [5-step plan](https://github.com/lancedb/lancedb/issues/2716#issuecomment-3447750543): ### Rust (\ odejs/src/query.rs\) 1. \ytes_to_arrow_array(data: Uint8Array, dtype: String)\ helper that: - Creates an Arrow \Buffer\ from the raw bytes - Wraps it in a typed \ScalarBuffer\ based on the dtype enum - Constructs a \PrimitiveArray\ and returns \Arc\ 2. \ earest_to_raw(data, dtype)\ and \dd_query_vector_raw(data, dtype)\ NAPI methods that pass the type-erased array to the core \ earest_to\/\dd_query_vector\ which already accept \impl IntoQueryVector\ for \Arc\ ### TypeScript (\ odejs/lancedb/query.ts\, \rrow.ts\) 3. Extended \IntoVector\ type to include \Uint8Array\ (and \Float16Array\ via runtime check for Node 22+) 4. \xtractVectorBuffer()\ helper detects non-Float32 typed arrays and extracts their underlying byte buffer + dtype string 5. \ earestTo()\ and \ddQueryVector()\ route through the raw NAPI path when the input is Float16/Float64/Uint8 ### Backward compatibility Existing \Float32Array\ and \ umber[]\ inputs are unchanged -- they still use the original \ earest_to(Float32Array)\ NAPI method. The new raw path is only used when a non-Float32 typed array is detected. ## Usage \\\ ypescript // Float16Array (Node 22+) -- no precision loss const f16vec = new Float16Array([0.1, 0.2, 0.3]); const results = await table.query().nearestTo(f16vec).limit(10).toArray(); // Float64Array -- no precision loss const f64vec = new Float64Array([0.1, 0.2, 0.3]); const results = await table.query().nearestTo(f64vec).limit(10).toArray(); // Uint8Array (binary embeddings) const u8vec = new Uint8Array([1, 0, 1, 1, 0]); const results = await table.query().nearestTo(u8vec).limit(10).toArray(); // Existing usage unchanged const results = await table.query().nearestTo([0.1, 0.2, 0.3]).limit(10).toArray(); \\\ ## Note on dependencies The Rust side uses \rrow_array\, \rrow_buffer\, and \half\ crates. These should already be in the dependency tree via \lancedb\ core, but \Cargo.toml\ may need explicit entries for \half\ and the arrow sub-crates in the nodejs workspace. --------- Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com> Co-authored-by: Will Jones --- Cargo.lock | 2 + .../embedding/classes/EmbeddingFunction.md | 4 +- .../classes/TextEmbeddingFunction.md | 4 +- docs/src/js/type-aliases/IntoVector.md | 7 +- nodejs/Cargo.toml | 2 + nodejs/__test__/vector_types.test.ts | 110 ++++++++++++++++++ nodejs/lancedb/arrow.ts | 37 +++++- nodejs/lancedb/query.ts | 39 ++++--- nodejs/src/query.rs | 47 ++++++++ 9 files changed, 232 insertions(+), 20 deletions(-) create mode 100644 nodejs/__test__/vector_types.test.ts diff --git a/Cargo.lock b/Cargo.lock index 6a6843865..91eb6b63d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4700,6 +4700,7 @@ name = "lancedb-nodejs" version = "0.27.2-beta.1" dependencies = [ "arrow-array", + "arrow-buffer", "arrow-ipc", "arrow-schema", "async-trait", @@ -4707,6 +4708,7 @@ dependencies = [ "aws-lc-sys", "env_logger", "futures", + "half", "lancedb", "log", "lzma-sys", diff --git a/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md b/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md index 66d6ee162..574a0d71f 100644 --- a/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md +++ b/docs/src/js/namespaces/embedding/classes/EmbeddingFunction.md @@ -52,7 +52,7 @@ new EmbeddingFunction(): EmbeddingFunction ### computeQueryEmbeddings() ```ts -computeQueryEmbeddings(data): Promise +computeQueryEmbeddings(data): Promise ``` Compute the embeddings for a single query @@ -63,7 +63,7 @@ Compute the embeddings for a single query #### Returns -`Promise`<`number`[] \| `Float32Array` \| `Float64Array`> +`Promise`<`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`> *** diff --git a/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md b/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md index 8aee4f44c..444c4c3f0 100644 --- a/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md +++ b/docs/src/js/namespaces/embedding/classes/TextEmbeddingFunction.md @@ -37,7 +37,7 @@ new TextEmbeddingFunction(): TextEmbeddingFunction ### computeQueryEmbeddings() ```ts -computeQueryEmbeddings(data): Promise +computeQueryEmbeddings(data): Promise ``` Compute the embeddings for a single query @@ -48,7 +48,7 @@ Compute the embeddings for a single query #### Returns -`Promise`<`number`[] \| `Float32Array` \| `Float64Array`> +`Promise`<`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`> #### Overrides diff --git a/docs/src/js/type-aliases/IntoVector.md b/docs/src/js/type-aliases/IntoVector.md index 813d3cdc8..bfab8a0d3 100644 --- a/docs/src/js/type-aliases/IntoVector.md +++ b/docs/src/js/type-aliases/IntoVector.md @@ -7,5 +7,10 @@ # Type Alias: IntoVector ```ts -type IntoVector: Float32Array | Float64Array | number[] | Promise; +type IntoVector: + | Float32Array + | Float64Array + | Uint8Array + | number[] + | Promise; ``` diff --git a/nodejs/Cargo.toml b/nodejs/Cargo.toml index 2062ffa93..fe31b3c2d 100644 --- a/nodejs/Cargo.toml +++ b/nodejs/Cargo.toml @@ -15,6 +15,8 @@ crate-type = ["cdylib"] async-trait.workspace = true arrow-ipc.workspace = true arrow-array.workspace = true +arrow-buffer = "57.2" +half.workspace = true arrow-schema.workspace = true env_logger.workspace = true futures.workspace = true diff --git a/nodejs/__test__/vector_types.test.ts b/nodejs/__test__/vector_types.test.ts new file mode 100644 index 000000000..4ac524cc6 --- /dev/null +++ b/nodejs/__test__/vector_types.test.ts @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +import * as tmp from "tmp"; + +import { type Table, connect } from "../lancedb"; +import { + Field, + FixedSizeList, + Float32, + Int64, + Schema, + makeArrowTable, +} from "../lancedb/arrow"; + +describe("Vector query with different typed arrays", () => { + let tmpDir: tmp.DirResult; + + afterEach(() => { + tmpDir?.removeCallback(); + }); + + async function createFloat32Table(): Promise { + tmpDir = tmp.dirSync({ unsafeCleanup: true }); + const db = await connect(tmpDir.name); + const schema = new Schema([ + new Field("id", new Int64(), true), + new Field( + "vec", + new FixedSizeList(2, new Field("item", new Float32())), + true, + ), + ]); + const data = makeArrowTable( + [ + { id: 1n, vec: [1.0, 0.0] }, + { id: 2n, vec: [0.0, 1.0] }, + { id: 3n, vec: [1.0, 1.0] }, + ], + { schema }, + ); + return db.createTable("test_f32", data); + } + + it("should search with Float32Array (baseline)", async () => { + const table = await createFloat32Table(); + const results = await table + .query() + .nearestTo(new Float32Array([1.0, 0.0])) + .limit(1) + .toArray(); + + expect(results.length).toBe(1); + expect(Number(results[0].id)).toBe(1); + }); + + it("should search with number[] (backward compat)", async () => { + const table = await createFloat32Table(); + const results = await table + .query() + .nearestTo([1.0, 0.0]) + .limit(1) + .toArray(); + + expect(results.length).toBe(1); + expect(Number(results[0].id)).toBe(1); + }); + + it("should search with Float64Array via raw path", async () => { + const table = await createFloat32Table(); + const results = await table + .query() + .nearestTo(new Float64Array([1.0, 0.0])) + .limit(1) + .toArray(); + + expect(results.length).toBe(1); + expect(Number(results[0].id)).toBe(1); + }); + + it("should add multiple query vectors with Float64Array", async () => { + const table = await createFloat32Table(); + const results = await table + .query() + .nearestTo(new Float64Array([1.0, 0.0])) + .addQueryVector(new Float64Array([0.0, 1.0])) + .limit(2) + .toArray(); + + expect(results.length).toBeGreaterThanOrEqual(2); + }); + + // Float16Array is only available in Node 22+; not in TypeScript's standard lib yet + const float16ArrayCtor = (globalThis as unknown as Record) + .Float16Array as (new (values: number[]) => unknown) | undefined; + const hasFloat16 = float16ArrayCtor !== undefined; + const f16it = hasFloat16 ? it : it.skip; + + f16it("should search with Float16Array via raw path", async () => { + const table = await createFloat32Table(); + const results = await table + .query() + .nearestTo(new float16ArrayCtor!([1.0, 0.0]) as Float32Array) + .limit(1) + .toArray(); + + expect(results.length).toBe(1); + expect(Number(results[0].id)).toBe(1); + }); +}); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 7fff42f47..84f5ddf7b 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -117,8 +117,9 @@ export type TableLike = export type IntoVector = | Float32Array | Float64Array + | Uint8Array | number[] - | Promise; + | Promise; export type MultiVector = IntoVector[]; @@ -126,14 +127,48 @@ export function isMultiVector(value: unknown): value is MultiVector { return Array.isArray(value) && isIntoVector(value[0]); } +// Float16Array is not in TypeScript's standard lib yet; access dynamically +type Float16ArrayCtor = new ( + ...args: unknown[] +) => { buffer: ArrayBuffer; byteOffset: number; byteLength: number }; +const float16ArrayCtor = (globalThis as unknown as Record) + .Float16Array as Float16ArrayCtor | undefined; + export function isIntoVector(value: unknown): value is IntoVector { return ( value instanceof Float32Array || value instanceof Float64Array || + value instanceof Uint8Array || + (float16ArrayCtor !== undefined && value instanceof float16ArrayCtor) || (Array.isArray(value) && !Array.isArray(value[0])) ); } +/** + * Extract the underlying byte buffer and data type from a typed array + * for passing to the Rust NAPI layer without precision loss. + */ +export function extractVectorBuffer( + vector: Float32Array | Float64Array | Uint8Array, +): { data: Uint8Array; dtype: string } | null { + if (float16ArrayCtor !== undefined && vector instanceof float16ArrayCtor) { + return { + data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength), + dtype: "float16", + }; + } + if (vector instanceof Float64Array) { + return { + data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength), + dtype: "float64", + }; + } + if (vector instanceof Uint8Array && !(vector instanceof Float32Array)) { + return { data: vector, dtype: "uint8" }; + } + return null; +} + export function isArrowTable(value: object): value is TableLike { if (value instanceof ArrowTable) return true; return "schema" in value && "batches" in value; diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index a10191ed2..c077234ec 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -5,6 +5,7 @@ import { Table as ArrowTable, type IntoVector, RecordBatch, + extractVectorBuffer, fromBufferToRecordBatch, fromRecordBatchToBuffer, tableFromIPC, @@ -661,10 +662,8 @@ export class VectorQuery extends StandardQueryBase { const res = (async () => { try { const v = await vector; - const arr = Float32Array.from(v); - // // biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping - const value: any = this.addQueryVector(arr); + const value: any = this.addQueryVector(v); const inner = value.inner as | NativeVectorQuery | Promise; @@ -676,7 +675,12 @@ export class VectorQuery extends StandardQueryBase { return new VectorQuery(res); } else { super.doCall((inner) => { - inner.addQueryVector(Float32Array.from(vector)); + const raw = Array.isArray(vector) ? null : extractVectorBuffer(vector); + if (raw) { + inner.addQueryVectorRaw(raw.data, raw.dtype); + } else { + inner.addQueryVector(Float32Array.from(vector as number[])); + } }); return this; } @@ -765,14 +769,23 @@ export class Query extends StandardQueryBase { * a default `limit` of 10 will be used. @see {@link Query#limit} */ nearestTo(vector: IntoVector): VectorQuery { + const callNearestTo = ( + inner: NativeQuery, + resolved: Float32Array | Float64Array | Uint8Array | number[], + ): NativeVectorQuery => { + const raw = Array.isArray(resolved) + ? null + : extractVectorBuffer(resolved); + if (raw) { + return inner.nearestToRaw(raw.data, raw.dtype); + } + return inner.nearestTo(Float32Array.from(resolved as number[])); + }; + if (this.inner instanceof Promise) { const nativeQuery = this.inner.then(async (inner) => { - if (vector instanceof Promise) { - const arr = await vector.then((v) => Float32Array.from(v)); - return inner.nearestTo(arr); - } else { - return inner.nearestTo(Float32Array.from(vector)); - } + const resolved = vector instanceof Promise ? await vector : vector; + return callNearestTo(inner, resolved); }); return new VectorQuery(nativeQuery); } @@ -780,10 +793,8 @@ export class Query extends StandardQueryBase { const res = (async () => { try { const v = await vector; - const arr = Float32Array.from(v); - // // biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping - const value: any = this.nearestTo(arr); + const value: any = this.nearestTo(v); const inner = value.inner as | NativeVectorQuery | Promise; @@ -794,7 +805,7 @@ export class Query extends StandardQueryBase { })(); return new VectorQuery(res); } else { - const vectorQuery = this.inner.nearestTo(Float32Array.from(vector)); + const vectorQuery = callNearestTo(this.inner, vector); return new VectorQuery(vectorQuery); } } diff --git a/nodejs/src/query.rs b/nodejs/src/query.rs index 4ad42f32f..4516385d5 100644 --- a/nodejs/src/query.rs +++ b/nodejs/src/query.rs @@ -3,6 +3,12 @@ use std::sync::Arc; +use arrow_array::{ + Array, Float16Array as ArrowFloat16Array, Float32Array as ArrowFloat32Array, + Float64Array as ArrowFloat64Array, UInt8Array as ArrowUInt8Array, +}; +use arrow_buffer::ScalarBuffer; +use half::f16; use lancedb::index::scalar::{ BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur, Operator, PhraseQuery, @@ -24,6 +30,33 @@ use crate::rerankers::RerankHybridCallbackArgs; use crate::rerankers::Reranker; use crate::util::{parse_distance_type, schema_to_buffer}; +fn bytes_to_arrow_array(data: Uint8Array, dtype: String) -> napi::Result> { + let buf = arrow_buffer::Buffer::from(data.to_vec()); + let num_bytes = buf.len(); + match dtype.as_str() { + "float16" => { + let scalar_buf = ScalarBuffer::::new(buf, 0, num_bytes / 2); + Ok(Arc::new(ArrowFloat16Array::new(scalar_buf, None))) + } + "float32" => { + let scalar_buf = ScalarBuffer::::new(buf, 0, num_bytes / 4); + Ok(Arc::new(ArrowFloat32Array::new(scalar_buf, None))) + } + "float64" => { + let scalar_buf = ScalarBuffer::::new(buf, 0, num_bytes / 8); + Ok(Arc::new(ArrowFloat64Array::new(scalar_buf, None))) + } + "uint8" => { + let scalar_buf = ScalarBuffer::::new(buf, 0, num_bytes); + Ok(Arc::new(ArrowUInt8Array::new(scalar_buf, None))) + } + _ => Err(napi::Error::from_reason(format!( + "Unsupported vector dtype: {}. Expected one of: float16, float32, float64, uint8", + dtype + ))), + } +} + #[napi] pub struct Query { inner: LanceDbQuery, @@ -78,6 +111,13 @@ impl Query { Ok(VectorQuery { inner }) } + #[napi] + pub fn nearest_to_raw(&mut self, data: Uint8Array, dtype: String) -> Result { + let array = bytes_to_arrow_array(data, dtype)?; + let inner = self.inner.clone().nearest_to(array).default_error()?; + Ok(VectorQuery { inner }) + } + #[napi] pub fn fast_search(&mut self) { self.inner = self.inner.clone().fast_search(); @@ -163,6 +203,13 @@ impl VectorQuery { Ok(()) } + #[napi] + pub fn add_query_vector_raw(&mut self, data: Uint8Array, dtype: String) -> Result<()> { + let array = bytes_to_arrow_array(data, dtype)?; + self.inner = self.inner.clone().add_query_vector(array).default_error()?; + Ok(()) + } + #[napi] pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> { let distance_type = parse_distance_type(distance_type)?;