diff --git a/nodejs/__test__/registry.test.ts b/nodejs/__test__/registry.test.ts index 1c57dc4a..a31d15e5 100644 --- a/nodejs/__test__/registry.test.ts +++ b/nodejs/__test__/registry.test.ts @@ -1,3 +1,4 @@ +import * as apiArrow from "apache-arrow"; // Copyright 2024 Lance Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -69,7 +70,7 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])( return 3; } embeddingDataType() { - return new arrow.Float32(); + return new arrow.Float32() as apiArrow.Float; } async computeSourceEmbeddings(data: string[]) { return data.map(() => [1, 2, 3]); @@ -82,7 +83,7 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])( const schema = LanceSchema({ id: new arrow.Int32(), - text: func.sourceField(new arrow.Utf8()), + text: func.sourceField(new arrow.Utf8() as apiArrow.DataType), vector: func.vectorField(), }); @@ -119,7 +120,7 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])( return 3; } embeddingDataType() { - return new arrow.Float32(); + return new arrow.Float32() as apiArrow.Float; } async computeSourceEmbeddings(data: string[]) { return data.map(() => [1, 2, 3]); @@ -144,7 +145,7 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])( return 3; } embeddingDataType() { - return new arrow.Float32(); + return new arrow.Float32() as apiArrow.Float; } async computeSourceEmbeddings(data: string[]) { return data.map(() => [1, 2, 3]); @@ -154,7 +155,7 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])( const schema = LanceSchema({ id: new arrow.Int32(), - text: func.sourceField(new arrow.Utf8()), + text: func.sourceField(new arrow.Utf8() as apiArrow.DataType), vector: func.vectorField(), }); const expectedMetadata = new Map([ diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 8096ccbc..cd015ca4 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -103,50 +103,11 @@ export type IntoVector = | number[] | Promise; -export type FloatLike = - | import("apache-arrow-13").Float - | import("apache-arrow-14").Float - | import("apache-arrow-15").Float - | import("apache-arrow-16").Float - | import("apache-arrow-17").Float; -export type DataTypeLike = - | import("apache-arrow-13").DataType - | import("apache-arrow-14").DataType - | import("apache-arrow-15").DataType - | import("apache-arrow-16").DataType - | import("apache-arrow-17").DataType; - export function isArrowTable(value: object): value is TableLike { if (value instanceof ArrowTable) return true; return "schema" in value && "batches" in value; } -export function isDataType(value: unknown): value is DataTypeLike { - return ( - value instanceof DataType || - DataType.isNull(value) || - DataType.isInt(value) || - DataType.isFloat(value) || - DataType.isBinary(value) || - DataType.isLargeBinary(value) || - DataType.isUtf8(value) || - DataType.isLargeUtf8(value) || - DataType.isBool(value) || - DataType.isDecimal(value) || - DataType.isDate(value) || - DataType.isTime(value) || - DataType.isTimestamp(value) || - DataType.isInterval(value) || - DataType.isDuration(value) || - DataType.isList(value) || - DataType.isStruct(value) || - DataType.isUnion(value) || - DataType.isFixedSizeBinary(value) || - DataType.isFixedSizeList(value) || - DataType.isMap(value) || - DataType.isDictionary(value) - ); -} export function isNull(value: unknown): value is Null { return value instanceof Null || DataType.isNull(value); } diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index 11e6d153..c59f6094 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -15,14 +15,12 @@ import "reflect-metadata"; import { DataType, - DataTypeLike, Field, FixedSizeList, + Float, Float32, - FloatLike, type IntoVector, Utf8, - isDataType, isFixedSizeList, isFloat, newVectorType, @@ -94,11 +92,12 @@ export abstract class EmbeddingFunction< * @see {@link lancedb.LanceSchema} */ sourceField( - optionsOrDatatype: Partial | DataTypeLike, - ): [DataTypeLike, Map] { - let datatype = isDataType(optionsOrDatatype) - ? optionsOrDatatype - : optionsOrDatatype?.datatype; + optionsOrDatatype: Partial | DataType, + ): [DataType, Map] { + let datatype = + "datatype" in optionsOrDatatype + ? optionsOrDatatype.datatype + : optionsOrDatatype; if (!datatype) { throw new Error("Datatype is required"); } @@ -124,15 +123,17 @@ export abstract class EmbeddingFunction< let dims: number | undefined = this.ndims(); // `func.vectorField(new Float32())` - if (isDataType(optionsOrDatatype)) { - dtype = optionsOrDatatype; + if (optionsOrDatatype === undefined) { + dtype = new Float32(); + } else if (!("datatype" in optionsOrDatatype)) { + dtype = sanitizeType(optionsOrDatatype); } else { // `func.vectorField({ // datatype: new Float32(), // dims: 10 // })` dims = dims ?? optionsOrDatatype?.dims; - dtype = optionsOrDatatype?.datatype; + dtype = sanitizeType(optionsOrDatatype?.datatype); } if (dtype !== undefined) { @@ -174,7 +175,7 @@ export abstract class EmbeddingFunction< } /** The datatype of the embeddings */ - abstract embeddingDataType(): FloatLike; + abstract embeddingDataType(): Float; /** * Creates a vector representation for the given values. @@ -210,11 +211,11 @@ export abstract class TextEmbeddingFunction< return this.generateEmbeddings([data]).then((data) => data[0]); } - embeddingDataType(): FloatLike { + embeddingDataType(): Float { return new Float32(); } - override sourceField(): [DataTypeLike, Map] { + override sourceField(): [DataType, Map] { return super.sourceField(new Utf8()); } diff --git a/nodejs/lancedb/embedding/index.ts b/nodejs/lancedb/embedding/index.ts index 8045b0af..cf9090e6 100644 --- a/nodejs/lancedb/embedding/index.ts +++ b/nodejs/lancedb/embedding/index.ts @@ -13,7 +13,6 @@ // limitations under the License. import { Field, Schema } from "../arrow"; -import { isDataType } from "../arrow"; import { sanitizeType } from "../sanitize"; import { EmbeddingFunction } from "./embedding_function"; import { EmbeddingFunctionConfig, getRegistry } from "./registry"; @@ -57,15 +56,15 @@ export function LanceSchema( Partial >(); Object.entries(fields).forEach(([key, value]) => { - if (isDataType(value)) { - arrowFields.push(new Field(key, sanitizeType(value), true)); - } else { + if (Array.isArray(value)) { const [dtype, metadata] = value as [ object, Map, ]; arrowFields.push(new Field(key, sanitizeType(dtype), true)); parseEmbeddingFunctions(embeddingFunctions, key, metadata); + } else { + arrowFields.push(new Field(key, sanitizeType(value), true)); } }); const registry = getRegistry(); diff --git a/nodejs/lancedb/embedding/openai.ts b/nodejs/lancedb/embedding/openai.ts index f5144d00..813a9930 100644 --- a/nodejs/lancedb/embedding/openai.ts +++ b/nodejs/lancedb/embedding/openai.ts @@ -13,7 +13,7 @@ // limitations under the License. import type OpenAI from "openai"; -import { type EmbeddingCreateParams } from "openai/resources"; +import type { EmbeddingCreateParams } from "openai/resources/index"; import { Float, Float32 } from "../arrow"; import { EmbeddingFunction } from "./embedding_function"; import { register } from "./registry"; diff --git a/nodejs/lancedb/sanitize.ts b/nodejs/lancedb/sanitize.ts index 35c08c8d..50298b8a 100644 --- a/nodejs/lancedb/sanitize.ts +++ b/nodejs/lancedb/sanitize.ts @@ -340,8 +340,14 @@ export function sanitizeType(typeLike: unknown): DataType { if (typeof typeLike !== "object" || typeLike === null) { throw Error("Expected a Type but object was null/undefined"); } - if (!("typeId" in typeLike) || !(typeof typeLike.typeId !== "function")) { - throw Error("Expected a Type to have a typeId function"); + if ( + !("typeId" in typeLike) || + !( + typeof typeLike.typeId !== "function" || + typeof typeLike.typeId !== "number" + ) + ) { + throw Error("Expected a Type to have a typeId property"); } let typeId: Type; if (typeof typeLike.typeId === "function") {