feat: support multivector for JS SDK (#2527)

Signed-off-by: BubbleCal <bubble-cal@outlook.com>
This commit is contained in:
BubbleCal
2025-07-22 21:19:34 +08:00
committed by GitHub
parent 0579303602
commit 96c66fd087
16 changed files with 262 additions and 35 deletions

View File

@@ -1863,4 +1863,43 @@ describe("column name options", () => {
expect(results[0].query_index).toBe(0);
expect(results[1].query_index).toBe(1);
});
test("index and search multivectors", async () => {
const db = await connect(tmpDir.name);
const data = [];
// generate 512 random multivectors
for (let i = 0; i < 256; i++) {
data.push({
multivector: Array.from({ length: 10 }, () =>
Array(2).fill(Math.random()),
),
});
}
const table = await db.createTable("multivectors", data, {
schema: new Schema([
new Field(
"multivector",
new List(
new Field(
"item",
new FixedSizeList(2, new Field("item", new Float32())),
),
),
),
]),
});
const results = await table.search(data[0].multivector).limit(10).toArray();
expect(results.length).toBe(10);
await table.createIndex("multivector", {
config: Index.ivfPq({ numPartitions: 2, distanceType: "cosine" }),
});
const results2 = await table
.search(data[0].multivector)
.limit(10)
.toArray();
expect(results2.length).toBe(10);
});
});

View File

@@ -107,6 +107,20 @@ export type IntoVector =
| number[]
| Promise<Float32Array | Float64Array | number[]>;
export type MultiVector = IntoVector[];
export function isMultiVector(value: unknown): value is MultiVector {
return Array.isArray(value) && isIntoVector(value[0]);
}
export function isIntoVector(value: unknown): value is IntoVector {
return (
value instanceof Float32Array ||
value instanceof Float64Array ||
(Array.isArray(value) && !Array.isArray(value[0]))
);
}
export function isArrowTable(value: object): value is TableLike {
if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value;

View File

@@ -100,6 +100,7 @@ export {
RecordBatchLike,
DataLike,
IntoVector,
MultiVector,
} from "./arrow";
export { IntoSql, packBits } from "./util";

View File

@@ -6,9 +6,11 @@ import {
Data,
DataType,
IntoVector,
MultiVector,
Schema,
dataTypeToJson,
fromDataToBuffer,
isMultiVector,
tableFromIPC,
} from "./arrow";
@@ -346,7 +348,7 @@ export abstract class Table {
* if the query is a string and no embedding function is defined, it will be treated as a full text search query
*/
abstract search(
query: string | IntoVector | FullTextQuery,
query: string | IntoVector | MultiVector | FullTextQuery,
queryType?: string,
ftsColumns?: string | string[],
): VectorQuery | Query;
@@ -357,7 +359,7 @@ export abstract class Table {
* is the same thing as calling `nearestTo` on the builder returned
* by `query`. @see {@link Query#nearestTo} for more details.
*/
abstract vectorSearch(vector: IntoVector): VectorQuery;
abstract vectorSearch(vector: IntoVector | MultiVector): VectorQuery;
/**
* Add new columns with defined values.
* @param {AddColumnsSql[]} newColumnTransforms pairs of column names and
@@ -668,7 +670,7 @@ export class LocalTable extends Table {
}
search(
query: string | IntoVector | FullTextQuery,
query: string | IntoVector | MultiVector | FullTextQuery,
queryType: string = "auto",
ftsColumns?: string | string[],
): VectorQuery | Query {
@@ -715,7 +717,15 @@ export class LocalTable extends Table {
return this.query().nearestTo(queryPromise);
}
vectorSearch(vector: IntoVector): VectorQuery {
vectorSearch(vector: IntoVector | MultiVector): VectorQuery {
if (isMultiVector(vector)) {
const query = this.query().nearestTo(vector[0]);
for (const v of vector.slice(1)) {
query.addQueryVector(v);
}
return query;
}
return this.query().nearestTo(vector);
}