mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-25 07:50:40 +00:00
feat(node): support Float16, Float64, and Uint8 vector queries (#3193)
Fixes #2716 ## Summary Add support for querying with Float16Array, Float64Array, and Uint8Array vectors in the Node.js SDK, eliminating precision loss from the previous \Float32Array.from()\ conversion. ## Implementation Follows @wjones127's [5-step plan](https://github.com/lancedb/lancedb/issues/2716#issuecomment-3447750543): ### Rust (\ odejs/src/query.rs\) 1. \ytes_to_arrow_array(data: Uint8Array, dtype: String)\ helper that: - Creates an Arrow \Buffer\ from the raw bytes - Wraps it in a typed \ScalarBuffer<T>\ based on the dtype enum - Constructs a \PrimitiveArray\ and returns \Arc<dyn Array>\ 2. \ earest_to_raw(data, dtype)\ and \dd_query_vector_raw(data, dtype)\ NAPI methods that pass the type-erased array to the core \ earest_to\/\dd_query_vector\ which already accept \impl IntoQueryVector\ for \Arc<dyn Array>\ ### TypeScript (\ odejs/lancedb/query.ts\, \rrow.ts\) 3. Extended \IntoVector\ type to include \Uint8Array\ (and \Float16Array\ via runtime check for Node 22+) 4. \xtractVectorBuffer()\ helper detects non-Float32 typed arrays and extracts their underlying byte buffer + dtype string 5. \ earestTo()\ and \ddQueryVector()\ route through the raw NAPI path when the input is Float16/Float64/Uint8 ### Backward compatibility Existing \Float32Array\ and \ umber[]\ inputs are unchanged -- they still use the original \ earest_to(Float32Array)\ NAPI method. The new raw path is only used when a non-Float32 typed array is detected. ## Usage \\\ ypescript // Float16Array (Node 22+) -- no precision loss const f16vec = new Float16Array([0.1, 0.2, 0.3]); const results = await table.query().nearestTo(f16vec).limit(10).toArray(); // Float64Array -- no precision loss const f64vec = new Float64Array([0.1, 0.2, 0.3]); const results = await table.query().nearestTo(f64vec).limit(10).toArray(); // Uint8Array (binary embeddings) const u8vec = new Uint8Array([1, 0, 1, 1, 0]); const results = await table.query().nearestTo(u8vec).limit(10).toArray(); // Existing usage unchanged const results = await table.query().nearestTo([0.1, 0.2, 0.3]).limit(10).toArray(); \\\ ## Note on dependencies The Rust side uses \rrow_array\, \rrow_buffer\, and \half\ crates. These should already be in the dependency tree via \lancedb\ core, but \Cargo.toml\ may need explicit entries for \half\ and the arrow sub-crates in the nodejs workspace. --------- Signed-off-by: Vedant Madane <6527493+VedantMadane@users.noreply.github.com> Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -4700,6 +4700,7 @@ name = "lancedb-nodejs"
|
||||
version = "0.27.2-beta.1"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
"arrow-ipc",
|
||||
"arrow-schema",
|
||||
"async-trait",
|
||||
@@ -4707,6 +4708,7 @@ dependencies = [
|
||||
"aws-lc-sys",
|
||||
"env_logger",
|
||||
"futures",
|
||||
"half",
|
||||
"lancedb",
|
||||
"log",
|
||||
"lzma-sys",
|
||||
|
||||
@@ -52,7 +52,7 @@ new EmbeddingFunction<T, M>(): EmbeddingFunction<T, M>
|
||||
### computeQueryEmbeddings()
|
||||
|
||||
```ts
|
||||
computeQueryEmbeddings(data): Promise<number[] | Float32Array | Float64Array>
|
||||
computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Float64Array>
|
||||
```
|
||||
|
||||
Compute the embeddings for a single query
|
||||
@@ -63,7 +63,7 @@ Compute the embeddings for a single query
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`number`[] \| `Float32Array` \| `Float64Array`>
|
||||
`Promise`<`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`>
|
||||
|
||||
***
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ new TextEmbeddingFunction<M>(): TextEmbeddingFunction<M>
|
||||
### computeQueryEmbeddings()
|
||||
|
||||
```ts
|
||||
computeQueryEmbeddings(data): Promise<number[] | Float32Array | Float64Array>
|
||||
computeQueryEmbeddings(data): Promise<number[] | Uint8Array | Float32Array | Float64Array>
|
||||
```
|
||||
|
||||
Compute the embeddings for a single query
|
||||
@@ -48,7 +48,7 @@ Compute the embeddings for a single query
|
||||
|
||||
#### Returns
|
||||
|
||||
`Promise`<`number`[] \| `Float32Array` \| `Float64Array`>
|
||||
`Promise`<`number`[] \| `Uint8Array` \| `Float32Array` \| `Float64Array`>
|
||||
|
||||
#### Overrides
|
||||
|
||||
|
||||
@@ -7,5 +7,10 @@
|
||||
# Type Alias: IntoVector
|
||||
|
||||
```ts
|
||||
type IntoVector: Float32Array | Float64Array | number[] | Promise<Float32Array | Float64Array | number[]>;
|
||||
type IntoVector:
|
||||
| Float32Array
|
||||
| Float64Array
|
||||
| Uint8Array
|
||||
| number[]
|
||||
| Promise<Float32Array | Float64Array | Uint8Array | number[]>;
|
||||
```
|
||||
|
||||
@@ -15,6 +15,8 @@ crate-type = ["cdylib"]
|
||||
async-trait.workspace = true
|
||||
arrow-ipc.workspace = true
|
||||
arrow-array.workspace = true
|
||||
arrow-buffer = "57.2"
|
||||
half.workspace = true
|
||||
arrow-schema.workspace = true
|
||||
env_logger.workspace = true
|
||||
futures.workspace = true
|
||||
|
||||
110
nodejs/__test__/vector_types.test.ts
Normal file
110
nodejs/__test__/vector_types.test.ts
Normal file
@@ -0,0 +1,110 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import * as tmp from "tmp";
|
||||
|
||||
import { type Table, connect } from "../lancedb";
|
||||
import {
|
||||
Field,
|
||||
FixedSizeList,
|
||||
Float32,
|
||||
Int64,
|
||||
Schema,
|
||||
makeArrowTable,
|
||||
} from "../lancedb/arrow";
|
||||
|
||||
describe("Vector query with different typed arrays", () => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
|
||||
afterEach(() => {
|
||||
tmpDir?.removeCallback();
|
||||
});
|
||||
|
||||
async function createFloat32Table(): Promise<Table> {
|
||||
tmpDir = tmp.dirSync({ unsafeCleanup: true });
|
||||
const db = await connect(tmpDir.name);
|
||||
const schema = new Schema([
|
||||
new Field("id", new Int64(), true),
|
||||
new Field(
|
||||
"vec",
|
||||
new FixedSizeList(2, new Field("item", new Float32())),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
const data = makeArrowTable(
|
||||
[
|
||||
{ id: 1n, vec: [1.0, 0.0] },
|
||||
{ id: 2n, vec: [0.0, 1.0] },
|
||||
{ id: 3n, vec: [1.0, 1.0] },
|
||||
],
|
||||
{ schema },
|
||||
);
|
||||
return db.createTable("test_f32", data);
|
||||
}
|
||||
|
||||
it("should search with Float32Array (baseline)", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo(new Float32Array([1.0, 0.0]))
|
||||
.limit(1)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBe(1);
|
||||
expect(Number(results[0].id)).toBe(1);
|
||||
});
|
||||
|
||||
it("should search with number[] (backward compat)", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo([1.0, 0.0])
|
||||
.limit(1)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBe(1);
|
||||
expect(Number(results[0].id)).toBe(1);
|
||||
});
|
||||
|
||||
it("should search with Float64Array via raw path", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo(new Float64Array([1.0, 0.0]))
|
||||
.limit(1)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBe(1);
|
||||
expect(Number(results[0].id)).toBe(1);
|
||||
});
|
||||
|
||||
it("should add multiple query vectors with Float64Array", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo(new Float64Array([1.0, 0.0]))
|
||||
.addQueryVector(new Float64Array([0.0, 1.0]))
|
||||
.limit(2)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
// Float16Array is only available in Node 22+; not in TypeScript's standard lib yet
|
||||
const float16ArrayCtor = (globalThis as unknown as Record<string, unknown>)
|
||||
.Float16Array as (new (values: number[]) => unknown) | undefined;
|
||||
const hasFloat16 = float16ArrayCtor !== undefined;
|
||||
const f16it = hasFloat16 ? it : it.skip;
|
||||
|
||||
f16it("should search with Float16Array via raw path", async () => {
|
||||
const table = await createFloat32Table();
|
||||
const results = await table
|
||||
.query()
|
||||
.nearestTo(new float16ArrayCtor!([1.0, 0.0]) as Float32Array)
|
||||
.limit(1)
|
||||
.toArray();
|
||||
|
||||
expect(results.length).toBe(1);
|
||||
expect(Number(results[0].id)).toBe(1);
|
||||
});
|
||||
});
|
||||
@@ -117,8 +117,9 @@ export type TableLike =
|
||||
export type IntoVector =
|
||||
| Float32Array
|
||||
| Float64Array
|
||||
| Uint8Array
|
||||
| number[]
|
||||
| Promise<Float32Array | Float64Array | number[]>;
|
||||
| Promise<Float32Array | Float64Array | Uint8Array | number[]>;
|
||||
|
||||
export type MultiVector = IntoVector[];
|
||||
|
||||
@@ -126,14 +127,48 @@ export function isMultiVector(value: unknown): value is MultiVector {
|
||||
return Array.isArray(value) && isIntoVector(value[0]);
|
||||
}
|
||||
|
||||
// Float16Array is not in TypeScript's standard lib yet; access dynamically
|
||||
type Float16ArrayCtor = new (
|
||||
...args: unknown[]
|
||||
) => { buffer: ArrayBuffer; byteOffset: number; byteLength: number };
|
||||
const float16ArrayCtor = (globalThis as unknown as Record<string, unknown>)
|
||||
.Float16Array as Float16ArrayCtor | undefined;
|
||||
|
||||
export function isIntoVector(value: unknown): value is IntoVector {
|
||||
return (
|
||||
value instanceof Float32Array ||
|
||||
value instanceof Float64Array ||
|
||||
value instanceof Uint8Array ||
|
||||
(float16ArrayCtor !== undefined && value instanceof float16ArrayCtor) ||
|
||||
(Array.isArray(value) && !Array.isArray(value[0]))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the underlying byte buffer and data type from a typed array
|
||||
* for passing to the Rust NAPI layer without precision loss.
|
||||
*/
|
||||
export function extractVectorBuffer(
|
||||
vector: Float32Array | Float64Array | Uint8Array,
|
||||
): { data: Uint8Array; dtype: string } | null {
|
||||
if (float16ArrayCtor !== undefined && vector instanceof float16ArrayCtor) {
|
||||
return {
|
||||
data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength),
|
||||
dtype: "float16",
|
||||
};
|
||||
}
|
||||
if (vector instanceof Float64Array) {
|
||||
return {
|
||||
data: new Uint8Array(vector.buffer, vector.byteOffset, vector.byteLength),
|
||||
dtype: "float64",
|
||||
};
|
||||
}
|
||||
if (vector instanceof Uint8Array && !(vector instanceof Float32Array)) {
|
||||
return { data: vector, dtype: "uint8" };
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export function isArrowTable(value: object): value is TableLike {
|
||||
if (value instanceof ArrowTable) return true;
|
||||
return "schema" in value && "batches" in value;
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
Table as ArrowTable,
|
||||
type IntoVector,
|
||||
RecordBatch,
|
||||
extractVectorBuffer,
|
||||
fromBufferToRecordBatch,
|
||||
fromRecordBatchToBuffer,
|
||||
tableFromIPC,
|
||||
@@ -661,10 +662,8 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
|
||||
const res = (async () => {
|
||||
try {
|
||||
const v = await vector;
|
||||
const arr = Float32Array.from(v);
|
||||
//
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||
const value: any = this.addQueryVector(arr);
|
||||
const value: any = this.addQueryVector(v);
|
||||
const inner = value.inner as
|
||||
| NativeVectorQuery
|
||||
| Promise<NativeVectorQuery>;
|
||||
@@ -676,7 +675,12 @@ export class VectorQuery extends StandardQueryBase<NativeVectorQuery> {
|
||||
return new VectorQuery(res);
|
||||
} else {
|
||||
super.doCall((inner) => {
|
||||
inner.addQueryVector(Float32Array.from(vector));
|
||||
const raw = Array.isArray(vector) ? null : extractVectorBuffer(vector);
|
||||
if (raw) {
|
||||
inner.addQueryVectorRaw(raw.data, raw.dtype);
|
||||
} else {
|
||||
inner.addQueryVector(Float32Array.from(vector as number[]));
|
||||
}
|
||||
});
|
||||
return this;
|
||||
}
|
||||
@@ -765,14 +769,23 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
||||
* a default `limit` of 10 will be used. @see {@link Query#limit}
|
||||
*/
|
||||
nearestTo(vector: IntoVector): VectorQuery {
|
||||
const callNearestTo = (
|
||||
inner: NativeQuery,
|
||||
resolved: Float32Array | Float64Array | Uint8Array | number[],
|
||||
): NativeVectorQuery => {
|
||||
const raw = Array.isArray(resolved)
|
||||
? null
|
||||
: extractVectorBuffer(resolved);
|
||||
if (raw) {
|
||||
return inner.nearestToRaw(raw.data, raw.dtype);
|
||||
}
|
||||
return inner.nearestTo(Float32Array.from(resolved as number[]));
|
||||
};
|
||||
|
||||
if (this.inner instanceof Promise) {
|
||||
const nativeQuery = this.inner.then(async (inner) => {
|
||||
if (vector instanceof Promise) {
|
||||
const arr = await vector.then((v) => Float32Array.from(v));
|
||||
return inner.nearestTo(arr);
|
||||
} else {
|
||||
return inner.nearestTo(Float32Array.from(vector));
|
||||
}
|
||||
const resolved = vector instanceof Promise ? await vector : vector;
|
||||
return callNearestTo(inner, resolved);
|
||||
});
|
||||
return new VectorQuery(nativeQuery);
|
||||
}
|
||||
@@ -780,10 +793,8 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
||||
const res = (async () => {
|
||||
try {
|
||||
const v = await vector;
|
||||
const arr = Float32Array.from(v);
|
||||
//
|
||||
// biome-ignore lint/suspicious/noExplicitAny: we need to get the `inner`, but js has no package scoping
|
||||
const value: any = this.nearestTo(arr);
|
||||
const value: any = this.nearestTo(v);
|
||||
const inner = value.inner as
|
||||
| NativeVectorQuery
|
||||
| Promise<NativeVectorQuery>;
|
||||
@@ -794,7 +805,7 @@ export class Query extends StandardQueryBase<NativeQuery> {
|
||||
})();
|
||||
return new VectorQuery(res);
|
||||
} else {
|
||||
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
|
||||
const vectorQuery = callNearestTo(this.inner, vector);
|
||||
return new VectorQuery(vectorQuery);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{
|
||||
Array, Float16Array as ArrowFloat16Array, Float32Array as ArrowFloat32Array,
|
||||
Float64Array as ArrowFloat64Array, UInt8Array as ArrowUInt8Array,
|
||||
};
|
||||
use arrow_buffer::ScalarBuffer;
|
||||
use half::f16;
|
||||
use lancedb::index::scalar::{
|
||||
BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
|
||||
Operator, PhraseQuery,
|
||||
@@ -24,6 +30,33 @@ use crate::rerankers::RerankHybridCallbackArgs;
|
||||
use crate::rerankers::Reranker;
|
||||
use crate::util::{parse_distance_type, schema_to_buffer};
|
||||
|
||||
fn bytes_to_arrow_array(data: Uint8Array, dtype: String) -> napi::Result<Arc<dyn Array>> {
|
||||
let buf = arrow_buffer::Buffer::from(data.to_vec());
|
||||
let num_bytes = buf.len();
|
||||
match dtype.as_str() {
|
||||
"float16" => {
|
||||
let scalar_buf = ScalarBuffer::<f16>::new(buf, 0, num_bytes / 2);
|
||||
Ok(Arc::new(ArrowFloat16Array::new(scalar_buf, None)))
|
||||
}
|
||||
"float32" => {
|
||||
let scalar_buf = ScalarBuffer::<f32>::new(buf, 0, num_bytes / 4);
|
||||
Ok(Arc::new(ArrowFloat32Array::new(scalar_buf, None)))
|
||||
}
|
||||
"float64" => {
|
||||
let scalar_buf = ScalarBuffer::<f64>::new(buf, 0, num_bytes / 8);
|
||||
Ok(Arc::new(ArrowFloat64Array::new(scalar_buf, None)))
|
||||
}
|
||||
"uint8" => {
|
||||
let scalar_buf = ScalarBuffer::<u8>::new(buf, 0, num_bytes);
|
||||
Ok(Arc::new(ArrowUInt8Array::new(scalar_buf, None)))
|
||||
}
|
||||
_ => Err(napi::Error::from_reason(format!(
|
||||
"Unsupported vector dtype: {}. Expected one of: float16, float32, float64, uint8",
|
||||
dtype
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub struct Query {
|
||||
inner: LanceDbQuery,
|
||||
@@ -78,6 +111,13 @@ impl Query {
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn nearest_to_raw(&mut self, data: Uint8Array, dtype: String) -> Result<VectorQuery> {
|
||||
let array = bytes_to_arrow_array(data, dtype)?;
|
||||
let inner = self.inner.clone().nearest_to(array).default_error()?;
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn fast_search(&mut self) {
|
||||
self.inner = self.inner.clone().fast_search();
|
||||
@@ -163,6 +203,13 @@ impl VectorQuery {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn add_query_vector_raw(&mut self, data: Uint8Array, dtype: String) -> Result<()> {
|
||||
let array = bytes_to_arrow_array(data, dtype)?;
|
||||
self.inner = self.inner.clone().add_query_vector(array).default_error()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
|
||||
Reference in New Issue
Block a user