diff --git a/nodejs/__test__/registry.test.ts b/nodejs/__test__/registry.test.ts index 0ab97d9e..e87a38e6 100644 --- a/nodejs/__test__/registry.test.ts +++ b/nodejs/__test__/registry.test.ts @@ -11,18 +11,21 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -import { Float, Float32, Int32, Utf8, Vector } from "apache-arrow"; +import * as arrow from "apache-arrow"; +import * as arrowOld from "apache-arrow-old"; + import * as tmp from "tmp"; + import { connect } from "../lancedb"; import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; import { getRegistry, register } from "../lancedb/embedding/registry"; -describe("LanceSchema", () => { +describe.each([arrow, arrowOld])("LanceSchema", (arrow) => { test("should preserve input order", async () => { const schema = LanceSchema({ - id: new Int32(), - text: new Utf8(), - vector: new Float32(), + id: new arrow.Int32(), + text: new arrow.Utf8(), + vector: new arrow.Float32(), }); expect(schema.fields.map((x) => x.name)).toEqual(["id", "text", "vector"]); }); @@ -53,8 +56,8 @@ describe("Registry", () => { ndims() { return 3; } - embeddingDataType(): Float { - return new Float32(); + embeddingDataType(): arrow.Float { + return new arrow.Float32(); } async computeSourceEmbeddings(data: string[]) { return data.map(() => [1, 2, 3]); @@ -65,8 +68,8 @@ describe("Registry", () => { .create(); const schema = LanceSchema({ - id: new Int32(), - text: func.sourceField(new Utf8()), + id: new arrow.Int32(), + text: func.sourceField(new arrow.Utf8()), vector: func.vectorField(), }); @@ -88,7 +91,7 @@ describe("Registry", () => { .getChild("vector") ?.toArray() .map((x: unknown) => { - if (x instanceof Vector) { + if (x instanceof arrow.Vector) { return [...x]; } else { return x; @@ -109,8 +112,8 @@ describe("Registry", () => { ndims() { return 3; } - embeddingDataType(): Float { - return new Float32(); + embeddingDataType(): arrow.Float { + return new arrow.Float32(); } async computeSourceEmbeddings(data: string[]) { return data.map(() => [1, 2, 3]); @@ -134,8 +137,8 @@ describe("Registry", () => { ndims() { return 3; } - embeddingDataType(): Float { - return new Float32(); + embeddingDataType(): arrow.Float { + return new arrow.Float32(); } async computeSourceEmbeddings(data: string[]) { return data.map(() => [1, 2, 3]); @@ -144,8 +147,8 @@ describe("Registry", () => { const func = new MockEmbeddingFunction(); const schema = LanceSchema({ - id: new Int32(), - text: func.sourceField(new Utf8()), + id: new arrow.Int32(), + text: func.sourceField(new arrow.Utf8()), vector: func.vectorField(), }); const expectedMetadata = new Map([ diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 695f1a9d..9d191d7d 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -16,6 +16,10 @@ import * as fs from "fs"; import * as path from "path"; import * as tmp from "tmp"; +import * as arrow from "apache-arrow"; +import * as arrowOld from "apache-arrow-old"; + +import { Table, connect } from "../lancedb"; import { Field, FixedSizeList, @@ -26,17 +30,20 @@ import { Int64, Schema, Utf8, -} from "apache-arrow"; -import { Table, connect } from "../lancedb"; -import { makeArrowTable } from "../lancedb/arrow"; + makeArrowTable, +} from "../lancedb/arrow"; import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; import { getRegistry, register } from "../lancedb/embedding/registry"; import { Index } from "../lancedb/indices"; -describe("Given a table", () => { +// biome-ignore lint/suspicious/noExplicitAny: +describe.each([arrow, arrowOld])("Given a table", (arrow: any) => { let tmpDir: tmp.DirResult; let table: Table; - const schema = new Schema([new Field("id", new Float64(), true)]); + + const schema = new arrow.Schema([ + new arrow.Field("id", new arrow.Float64(), true), + ]); beforeEach(async () => { tmpDir = tmp.dirSync({ unsafeCleanup: true }); const conn = await connect(tmpDir.name); @@ -551,7 +558,7 @@ describe("embedding functions", () => { const func = getRegistry().get("mock")!.create(); const schema = LanceSchema({ - id: new Float64(), + id: new arrow.Float64(), text: func.sourceField(new Utf8()), vector: func.vectorField(), }); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 7781bd11..57175ca5 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -17,10 +17,14 @@ import { Binary, DataType, Field, + FixedSizeBinary, FixedSizeList, - type Float, + Float, Float32, + Int, + LargeBinary, List, + Null, RecordBatch, RecordBatchFileWriter, RecordBatchStreamWriter, @@ -35,7 +39,98 @@ import { } from "apache-arrow"; import { type EmbeddingFunction } from "./embedding/embedding_function"; import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; -import { sanitizeSchema } from "./sanitize"; +import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize"; +export * from "apache-arrow"; + +export function isArrowTable(value: object): value is ArrowTable { + if (value instanceof ArrowTable) return true; + return "schema" in value && "batches" in value; +} + +export function isDataType(value: unknown): value is DataType { + return ( + value instanceof DataType || + DataType.isNull(value) || + DataType.isInt(value) || + DataType.isFloat(value) || + DataType.isBinary(value) || + DataType.isLargeBinary(value) || + DataType.isUtf8(value) || + DataType.isLargeUtf8(value) || + DataType.isBool(value) || + DataType.isDecimal(value) || + DataType.isDate(value) || + DataType.isTime(value) || + DataType.isTimestamp(value) || + DataType.isInterval(value) || + DataType.isDuration(value) || + DataType.isList(value) || + DataType.isStruct(value) || + DataType.isUnion(value) || + DataType.isFixedSizeBinary(value) || + DataType.isFixedSizeList(value) || + DataType.isMap(value) || + DataType.isDictionary(value) + ); +} +export function isNull(value: unknown): value is Null { + return value instanceof Null || DataType.isNull(value); +} +export function isInt(value: unknown): value is Int { + return value instanceof Int || DataType.isInt(value); +} +export function isFloat(value: unknown): value is Float { + return value instanceof Float || DataType.isFloat(value); +} +export function isBinary(value: unknown): value is Binary { + return value instanceof Binary || DataType.isBinary(value); +} +export function isLargeBinary(value: unknown): value is LargeBinary { + return value instanceof LargeBinary || DataType.isLargeBinary(value); +} +export function isUtf8(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isUtf8(value); +} +export function isLargeUtf8(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isLargeUtf8(value); +} +export function isBool(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isBool(value); +} +export function isDecimal(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isDecimal(value); +} +export function isDate(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isDate(value); +} +export function isTime(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isTime(value); +} +export function isTimestamp(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isTimestamp(value); +} +export function isInterval(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isInterval(value); +} +export function isDuration(value: unknown): value is Utf8 { + return value instanceof Utf8 || DataType.isDuration(value); +} +export function isList(value: unknown): value is List { + return value instanceof List || DataType.isList(value); +} +export function isStruct(value: unknown): value is Struct { + return value instanceof Struct || DataType.isStruct(value); +} +export function isUnion(value: unknown): value is Struct { + return value instanceof Struct || DataType.isUnion(value); +} +export function isFixedSizeBinary(value: unknown): value is FixedSizeBinary { + return value instanceof FixedSizeBinary || DataType.isFixedSizeBinary(value); +} + +export function isFixedSizeList(value: unknown): value is FixedSizeList { + return value instanceof FixedSizeList || DataType.isFixedSizeList(value); +} /** Data type accepted by NodeJS SDK */ export type Data = Record[] | ArrowTable; @@ -442,8 +537,8 @@ async function applyEmbeddingsFromMetadata( } let destType: DataType; const dtype = schema.fields.find((f) => f.name === destColumn)!.type; - if (dtype instanceof FixedSizeList) { - destType = dtype; + if (isFixedSizeList(dtype)) { + destType = sanitizeType(dtype); } else { throw new Error( "Expected FixedSizeList as datatype for vector field, instead got: " + @@ -588,7 +683,7 @@ export function newVectorType( ): FixedSizeList { // in Lance we always default to have the elements nullable, so we need to set it to true // otherwise we often get schema mismatches because the stored data always has schema with nullable elements - const children = new Field("item", innerType, true); + const children = new Field("item", sanitizeType(innerType), true); return new FixedSizeList(dim, children); } @@ -669,7 +764,7 @@ export async function fromDataToBuffer( if (schema !== undefined && schema !== null) { schema = sanitizeSchema(schema); } - if (data instanceof ArrowTable) { + if (isArrowTable(data)) { return fromTableToBuffer(data, embeddings, schema); } else { const table = await convertToTable(data, embeddings, { schema }); @@ -750,8 +845,10 @@ function validateSchemaEmbeddings( // if it does not, we add it to the list of missing embedding fields // Finally, we check if those missing embedding fields are `this._embeddings` // if they are not, we throw an error - for (const field of schema.fields) { - if (field.type instanceof FixedSizeList) { + for (let field of schema.fields) { + if (isFixedSizeList(field.type)) { + field = sanitizeField(field); + if (data.length !== 0 && data?.[0]?.[field.name] === undefined) { if (schema.metadata.has("embedding_functions")) { const embeddings = JSON.parse( diff --git a/nodejs/lancedb/connection.ts b/nodejs/lancedb/connection.ts index 6bee8dc6..8948d869 100644 --- a/nodejs/lancedb/connection.ts +++ b/nodejs/lancedb/connection.ts @@ -12,8 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { Table as ArrowTable, Schema } from "apache-arrow"; -import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow"; +import { Table as ArrowTable, Schema } from "./arrow"; +import { + fromTableToBuffer, + isArrowTable, + makeArrowTable, + makeEmptyTable, +} from "./arrow"; import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; import { ConnectionOptions, Connection as LanceDbConnection } from "./native"; import { Table } from "./table"; @@ -200,7 +205,7 @@ export class Connection { } let table: ArrowTable; - if (data instanceof ArrowTable) { + if (isArrowTable(data)) { table = data; } else { table = makeArrowTable(data, options); diff --git a/nodejs/lancedb/embedding/embedding_function.ts b/nodejs/lancedb/embedding/embedding_function.ts index 360bea9f..4342139d 100644 --- a/nodejs/lancedb/embedding/embedding_function.ts +++ b/nodejs/lancedb/embedding/embedding_function.ts @@ -12,9 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { DataType, Field, FixedSizeList, Float, Float32 } from "apache-arrow"; import "reflect-metadata"; -import { newVectorType } from "../arrow"; +import { + DataType, + Field, + FixedSizeList, + Float, + Float32, + isDataType, + isFixedSizeList, + isFloat, + newVectorType, +} from "../arrow"; +import { sanitizeType } from "../sanitize"; /** * Options for a given embedding function @@ -69,13 +79,13 @@ export abstract class EmbeddingFunction< sourceField( optionsOrDatatype: Partial | DataType, ): [DataType, Map] { - const datatype = - optionsOrDatatype instanceof DataType - ? optionsOrDatatype - : optionsOrDatatype?.datatype; + let datatype = isDataType(optionsOrDatatype) + ? optionsOrDatatype + : optionsOrDatatype?.datatype; if (!datatype) { throw new Error("Datatype is required"); } + datatype = sanitizeType(datatype); const metadata = new Map(); metadata.set("source_column_for", this); @@ -100,9 +110,9 @@ export abstract class EmbeddingFunction< } dtype = new FixedSizeList(dims, new Field("item", new Float32(), true)); } else { - if (options.datatype instanceof FixedSizeList) { + if (isFixedSizeList(options.datatype)) { dtype = options.datatype; - } else if (options.datatype instanceof Float) { + } else if (isFloat(options.datatype)) { if (dims === undefined) { throw new Error("ndims is required for vector field"); } diff --git a/nodejs/lancedb/embedding/index.ts b/nodejs/lancedb/embedding/index.ts index 4cce94cd..095c213b 100644 --- a/nodejs/lancedb/embedding/index.ts +++ b/nodejs/lancedb/embedding/index.ts @@ -12,12 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { DataType, Field, Schema } from "apache-arrow"; +import { DataType, Field, Schema } from "../arrow"; +import { isDataType } from "../arrow"; +import { sanitizeType } from "../sanitize"; import { EmbeddingFunction } from "./embedding_function"; import { EmbeddingFunctionConfig, getRegistry } from "./registry"; export { EmbeddingFunction } from "./embedding_function"; + +// We need to explicitly export '*' so that the `register` decorator actually registers the class. export * from "./openai"; +export * from "./registry"; /** * Create a schema with embedding functions. @@ -42,7 +47,7 @@ export * from "./openai"; * ``` */ export function LanceSchema( - fields: Record] | DataType>, + fields: Record] | object>, ): Schema { const arrowFields: Field[] = []; @@ -51,11 +56,14 @@ export function LanceSchema( Partial >(); Object.entries(fields).forEach(([key, value]) => { - if (value instanceof DataType) { - arrowFields.push(new Field(key, value, true)); + if (isDataType(value)) { + arrowFields.push(new Field(key, sanitizeType(value), true)); } else { - const [dtype, metadata] = value; - arrowFields.push(new Field(key, dtype, true)); + const [dtype, metadata] = value as [ + object, + Map, + ]; + arrowFields.push(new Field(key, sanitizeType(dtype), true)); parseEmbeddingFunctions(embeddingFunctions, key, metadata); } }); diff --git a/nodejs/lancedb/embedding/openai.ts b/nodejs/lancedb/embedding/openai.ts index cc101819..e055b175 100644 --- a/nodejs/lancedb/embedding/openai.ts +++ b/nodejs/lancedb/embedding/openai.ts @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { Float, Float32 } from "apache-arrow"; import type OpenAI from "openai"; +import { Float, Float32 } from "../arrow"; import { EmbeddingFunction } from "./embedding_function"; import { register } from "./registry"; diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index 430b3da5..d3566959 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { Table as ArrowTable, RecordBatch, tableFromIPC } from "apache-arrow"; +import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow"; import { type IvfPqOptions } from "./indices"; import { RecordBatchIterator as NativeBatchIterator, diff --git a/nodejs/lancedb/sanitize.ts b/nodejs/lancedb/sanitize.ts index 8e127656..cebbc9e3 100644 --- a/nodejs/lancedb/sanitize.ts +++ b/nodejs/lancedb/sanitize.ts @@ -20,6 +20,7 @@ // comes from the exact same library instance. This is not always the case // and so we must sanitize the input to ensure that it is compatible. +import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type"; import { Binary, Bool, @@ -75,10 +76,9 @@ import { Uint64, Union, Utf8, -} from "apache-arrow"; -import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type"; +} from "./arrow"; -function sanitizeMetadata( +export function sanitizeMetadata( metadataLike?: unknown, ): Map | undefined { if (metadataLike === undefined || metadataLike === null) { @@ -97,7 +97,7 @@ function sanitizeMetadata( return metadataLike as Map; } -function sanitizeInt(typeLike: object) { +export function sanitizeInt(typeLike: object) { if ( !("bitWidth" in typeLike) || typeof typeLike.bitWidth !== "number" || @@ -111,14 +111,14 @@ function sanitizeInt(typeLike: object) { return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth); } -function sanitizeFloat(typeLike: object) { +export function sanitizeFloat(typeLike: object) { if (!("precision" in typeLike) || typeof typeLike.precision !== "number") { throw Error("Expected a Float Type to have a `precision` property"); } return new Float(typeLike.precision as Precision); } -function sanitizeDecimal(typeLike: object) { +export function sanitizeDecimal(typeLike: object) { if ( !("scale" in typeLike) || typeof typeLike.scale !== "number" || @@ -134,14 +134,14 @@ function sanitizeDecimal(typeLike: object) { return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth); } -function sanitizeDate(typeLike: object) { +export function sanitizeDate(typeLike: object) { if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { throw Error("Expected a Date type to have a `unit` property"); } return new Date_(typeLike.unit as DateUnit); } -function sanitizeTime(typeLike: object) { +export function sanitizeTime(typeLike: object) { if ( !("unit" in typeLike) || typeof typeLike.unit !== "number" || @@ -155,7 +155,7 @@ function sanitizeTime(typeLike: object) { return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth); } -function sanitizeTimestamp(typeLike: object) { +export function sanitizeTimestamp(typeLike: object) { if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { throw Error("Expected a Timestamp type to have a `unit` property"); } @@ -166,7 +166,7 @@ function sanitizeTimestamp(typeLike: object) { return new Timestamp(typeLike.unit, timezone); } -function sanitizeTypedTimestamp( +export function sanitizeTypedTimestamp( typeLike: object, // eslint-disable-next-line @typescript-eslint/naming-convention Datatype: @@ -182,14 +182,14 @@ function sanitizeTypedTimestamp( return new Datatype(timezone); } -function sanitizeInterval(typeLike: object) { +export function sanitizeInterval(typeLike: object) { if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { throw Error("Expected an Interval type to have a `unit` property"); } return new Interval(typeLike.unit); } -function sanitizeList(typeLike: object) { +export function sanitizeList(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( "Expected a List type to have an array-like `children` property", @@ -201,7 +201,7 @@ function sanitizeList(typeLike: object) { return new List(sanitizeField(typeLike.children[0])); } -function sanitizeStruct(typeLike: object) { +export function sanitizeStruct(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( "Expected a Struct type to have an array-like `children` property", @@ -210,7 +210,7 @@ function sanitizeStruct(typeLike: object) { return new Struct(typeLike.children.map((child) => sanitizeField(child))); } -function sanitizeUnion(typeLike: object) { +export function sanitizeUnion(typeLike: object) { if ( !("typeIds" in typeLike) || !("mode" in typeLike) || @@ -234,7 +234,7 @@ function sanitizeUnion(typeLike: object) { ); } -function sanitizeTypedUnion( +export function sanitizeTypedUnion( typeLike: object, // eslint-disable-next-line @typescript-eslint/naming-convention UnionType: typeof DenseUnion | typeof SparseUnion, @@ -256,7 +256,7 @@ function sanitizeTypedUnion( ); } -function sanitizeFixedSizeBinary(typeLike: object) { +export function sanitizeFixedSizeBinary(typeLike: object) { if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") { throw Error( "Expected a FixedSizeBinary type to have a `byteWidth` property", @@ -265,7 +265,7 @@ function sanitizeFixedSizeBinary(typeLike: object) { return new FixedSizeBinary(typeLike.byteWidth); } -function sanitizeFixedSizeList(typeLike: object) { +export function sanitizeFixedSizeList(typeLike: object) { if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") { throw Error("Expected a FixedSizeList type to have a `listSize` property"); } @@ -283,7 +283,7 @@ function sanitizeFixedSizeList(typeLike: object) { ); } -function sanitizeMap(typeLike: object) { +export function sanitizeMap(typeLike: object) { if (!("children" in typeLike) || !Array.isArray(typeLike.children)) { throw Error( "Expected a Map type to have an array-like `children` property", @@ -300,14 +300,14 @@ function sanitizeMap(typeLike: object) { ); } -function sanitizeDuration(typeLike: object) { +export function sanitizeDuration(typeLike: object) { if (!("unit" in typeLike) || typeof typeLike.unit !== "number") { throw Error("Expected a Duration type to have a `unit` property"); } return new Duration(typeLike.unit); } -function sanitizeDictionary(typeLike: object) { +export function sanitizeDictionary(typeLike: object) { if (!("id" in typeLike) || typeof typeLike.id !== "number") { throw Error("Expected a Dictionary type to have an `id` property"); } @@ -329,7 +329,7 @@ function sanitizeDictionary(typeLike: object) { } // biome-ignore lint/suspicious/noExplicitAny: skip -function sanitizeType(typeLike: unknown): DataType { +export function sanitizeType(typeLike: unknown): DataType { if (typeof typeLike !== "object" || typeLike === null) { throw Error("Expected a Type but object was null/undefined"); } @@ -449,7 +449,7 @@ function sanitizeType(typeLike: unknown): DataType { } } -function sanitizeField(fieldLike: unknown): Field { +export function sanitizeField(fieldLike: unknown): Field { if (fieldLike instanceof Field) { return fieldLike; } diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 0516ef96..7e40669c 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { Schema, tableFromIPC } from "apache-arrow"; -import { Data, fromDataToBuffer } from "./arrow"; +import { Data, Schema, fromDataToBuffer, tableFromIPC } from "./arrow"; + import { getRegistry } from "./embedding/registry"; import { IndexOptions } from "./indices"; import { diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index f7a0298e..e5bf7470 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -1,12 +1,12 @@ { "name": "@lancedb/lancedb", - "version": "0.4.20", + "version": "0.5.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@lancedb/lancedb", - "version": "0.4.20", + "version": "0.5.0", "cpu": [ "x64", "arm64" diff --git a/nodejs/package.json b/nodejs/package.json index cc9b1365..49ec53de 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -1,8 +1,12 @@ { "name": "@lancedb/lancedb", "version": "0.5.0", - "main": "./dist/index.js", - "types": "./dist/index.d.ts", + "main": "dist/index.js", + "exports": { + ".": "./dist/index.js", + "./embedding": "./dist/embedding/index.js" + }, + "types": "dist/index.d.ts", "napi": { "name": "lancedb", "triples": {