fix(nodejs): better support for f16 and f64 (#1343)

closes https://github.com/lancedb/lancedb/issues/1292
closes https://github.com/lancedb/lancedb/issues/1293
This commit is contained in:
Cory Grinstead
2024-06-04 13:41:21 -05:00
committed by GitHub
parent 56b4fd2bd9
commit d9fb6457e1
6 changed files with 393 additions and 178 deletions

View File

@@ -31,7 +31,7 @@ import {
Schema,
Struct,
Utf8,
type Vector,
Vector,
makeBuilder,
makeData,
type makeTable,
@@ -182,6 +182,7 @@ export class MakeArrowTableOptions {
vector: new VectorColumnOptions(),
};
embeddings?: EmbeddingFunction<unknown>;
embeddingFunction?: EmbeddingFunctionConfig;
/**
* If true then string columns will be encoded with dictionary encoding
@@ -306,7 +307,11 @@ export function makeArrowTable(
const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema);
opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings);
opt.schema = validateSchemaEmbeddings(
opt.schema,
data,
options?.embeddingFunction,
);
}
const columns: Record<string, Vector> = {};
// TODO: sample dataset to find missing columns
@@ -545,7 +550,6 @@ async function applyEmbeddingsFromMetadata(
dtype,
);
}
const vector = makeVector(vectors, destType);
columns[destColumn] = vector;
}
@@ -835,7 +839,7 @@ export function createEmptyTable(schema: Schema): ArrowTable {
function validateSchemaEmbeddings(
schema: Schema,
data: Array<Record<string, unknown>>,
embeddings: EmbeddingFunction<unknown> | undefined,
embeddings: EmbeddingFunctionConfig | undefined,
) {
const fields = [];
const missingEmbeddingFields = [];

View File

@@ -100,33 +100,55 @@ export abstract class EmbeddingFunction<
* @see {@link lancedb.LanceSchema}
*/
vectorField(
options?: Partial<FieldOptions>,
optionsOrDatatype?: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] {
let dtype: DataType;
const dims = this.ndims() ?? options?.dims;
if (!options?.datatype) {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
let dtype: DataType | undefined;
let vectorType: DataType;
let dims: number | undefined = this.ndims();
// `func.vectorField(new Float32())`
if (isDataType(optionsOrDatatype)) {
dtype = optionsOrDatatype;
} else {
if (isFixedSizeList(options.datatype)) {
dtype = options.datatype;
} else if (isFloat(options.datatype)) {
// `func.vectorField({
// datatype: new Float32(),
// dims: 10
// })`
dims = dims ?? optionsOrDatatype?.dims;
dtype = optionsOrDatatype?.datatype;
}
if (dtype !== undefined) {
// `func.vectorField(new FixedSizeList(dims, new Field("item", new Float32(), true)))`
// or `func.vectorField({datatype: new FixedSizeList(dims, new Field("item", new Float32(), true))})`
if (isFixedSizeList(dtype)) {
vectorType = dtype;
// `func.vectorField(new Float32())`
// or `func.vectorField({datatype: new Float32()})`
} else if (isFloat(dtype)) {
// No `ndims` impl and no `{dims: n}` provided;
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
dtype = newVectorType(dims, options.datatype);
vectorType = newVectorType(dims, dtype);
} else {
throw new Error(
"Expected FixedSizeList or Float as datatype for vector field",
);
}
} else {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
vectorType = new FixedSizeList(
dims,
new Field("item", new Float32(), true),
);
}
const metadata = new Map<string, EmbeddingFunction>();
metadata.set("vector_column_for", this);
return [dtype, metadata];
return [vectorType, metadata];
}
/** The number of dimensions of the embeddings */

View File

@@ -168,10 +168,10 @@ export class QueryBase<
}
/** Collect the results as an array of objects. */
async toArray(): Promise<unknown[]> {
// biome-ignore lint/suspicious/noExplicitAny: arrow.toArrow() returns any[]
async toArray(): Promise<any[]> {
const tbl = await this.toArrow();
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return tbl.toArray();
}
}

View File

@@ -135,6 +135,7 @@ export class Table {
const buffer = await fromDataToBuffer(
data,
functions.values().next().value,
schema,
);
await this.inner.add(buffer, mode);
}