diff --git a/nodejs/__test__/sanitize.test.ts b/nodejs/__test__/sanitize.test.ts new file mode 100644 index 00000000..022b845e --- /dev/null +++ b/nodejs/__test__/sanitize.test.ts @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The LanceDB Authors + +import * as arrow from "../lancedb/arrow"; +import { sanitizeField, sanitizeType } from "../lancedb/sanitize"; + +describe("sanitize", function () { + describe("sanitizeType function", function () { + it("should handle type objects", function () { + const type = new arrow.Int32(); + const result = sanitizeType(type); + + expect(result.typeId).toBe(arrow.Type.Int); + expect((result as arrow.Int).bitWidth).toBe(32); + expect((result as arrow.Int).isSigned).toBe(true); + + const floatType = { + typeId: 3, // Type.Float = 3 + precision: 2, + toString: () => "Float", + isFloat: true, + isFixedWidth: true, + }; + + const floatResult = sanitizeType(floatType); + expect(floatResult).toBeInstanceOf(arrow.DataType); + expect(floatResult.typeId).toBe(arrow.Type.Float); + + const floatResult2 = sanitizeType({ ...floatType, typeId: () => 3 }); + expect(floatResult2).toBeInstanceOf(arrow.DataType); + expect(floatResult2.typeId).toBe(arrow.Type.Float); + }); + + const allTypeNameTestCases = [ + ["null", new arrow.Null()], + ["binary", new arrow.Binary()], + ["utf8", new arrow.Utf8()], + ["bool", new arrow.Bool()], + ["int8", new arrow.Int8()], + ["int16", new arrow.Int16()], + ["int32", new arrow.Int32()], + ["int64", new arrow.Int64()], + ["uint8", new arrow.Uint8()], + ["uint16", new arrow.Uint16()], + ["uint32", new arrow.Uint32()], + ["uint64", new arrow.Uint64()], + ["float16", new arrow.Float16()], + ["float32", new arrow.Float32()], + ["float64", new arrow.Float64()], + ["datemillisecond", new arrow.DateMillisecond()], + ["dateday", new arrow.DateDay()], + ["timenanosecond", new arrow.TimeNanosecond()], + ["timemicrosecond", new arrow.TimeMicrosecond()], + ["timemillisecond", new arrow.TimeMillisecond()], + ["timesecond", new arrow.TimeSecond()], + ["intervaldaytime", new arrow.IntervalDayTime()], + ["intervalyearmonth", new arrow.IntervalYearMonth()], + ["durationnanosecond", new arrow.DurationNanosecond()], + ["durationmicrosecond", new arrow.DurationMicrosecond()], + ["durationmillisecond", new arrow.DurationMillisecond()], + ["durationsecond", new arrow.DurationSecond()], + ] as const; + + it.each(allTypeNameTestCases)( + 'should map type name "%s" to %s', + function (name, expected) { + const result = sanitizeType(name); + expect(result).toBeInstanceOf(expected.constructor); + }, + ); + + const caseVariationTestCases = [ + ["NULL", new arrow.Null()], + ["Utf8", new arrow.Utf8()], + ["FLOAT32", new arrow.Float32()], + ["DaTedAy", new arrow.DateDay()], + ] as const; + + it.each(caseVariationTestCases)( + 'should be case insensitive for type name "%s" mapped to %s', + function (name, expected) { + const result = sanitizeType(name); + expect(result).toBeInstanceOf(expected.constructor); + }, + ); + + it("should throw error for unrecognized type name", function () { + expect(() => sanitizeType("invalid_type")).toThrow( + "Unrecognized type name in schema: invalid_type", + ); + }); + }); + + describe("sanitizeField function", function () { + it("should handle field with string type name", function () { + const field = sanitizeField({ + name: "string_field", + type: "utf8", + nullable: true, + metadata: new Map([["key", "value"]]), + }); + + expect(field).toBeInstanceOf(arrow.Field); + expect(field.name).toBe("string_field"); + expect(field.type).toBeInstanceOf(arrow.Utf8); + expect(field.nullable).toBe(true); + expect(field.metadata?.get("key")).toBe("value"); + }); + + it("should handle field with type object", function () { + const floatType = { + typeId: 3, // Float + precision: 32, + }; + + const field = sanitizeField({ + name: "float_field", + type: floatType, + nullable: false, + }); + + expect(field).toBeInstanceOf(arrow.Field); + expect(field.name).toBe("float_field"); + expect(field.type).toBeInstanceOf(arrow.DataType); + expect(field.type.typeId).toBe(arrow.Type.Float); + expect((field.type as arrow.Float64).precision).toBe(32); + expect(field.nullable).toBe(false); + }); + + it("should handle field with direct Type instance", function () { + const field = sanitizeField({ + name: "bool_field", + type: new arrow.Bool(), + nullable: true, + }); + + expect(field).toBeInstanceOf(arrow.Field); + expect(field.name).toBe("bool_field"); + expect(field.type).toBeInstanceOf(arrow.Bool); + expect(field.nullable).toBe(true); + }); + + it("should throw error for invalid field object", function () { + expect(() => + sanitizeField({ + type: "int32", + nullable: true, + }), + ).toThrow( + "The field passed in is missing a `type`/`name`/`nullable` property", + ); + + // Invalid type + expect(() => + sanitizeField({ + name: "invalid", + type: { invalid: true }, + nullable: true, + }), + ).toThrow("Expected a Type to have a typeId property"); + + // Invalid nullable + expect(() => + sanitizeField({ + name: "invalid_nullable", + type: "int32", + nullable: "not a boolean", + }), + ).toThrow("The field passed in had a non-boolean `nullable` property"); + }); + + it("should report error for invalid type name", function () { + expect(() => + sanitizeField({ + name: "invalid_field", + type: "invalid_type", + nullable: true, + }), + ).toThrow( + "Unable to sanitize type for field: invalid_field due to error: Error: Unrecognized type name in schema: invalid_type", + ); + }); + }); +}); diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 145e7a10..540f74b3 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -10,7 +10,13 @@ import * as arrow16 from "apache-arrow-16"; import * as arrow17 from "apache-arrow-17"; import * as arrow18 from "apache-arrow-18"; -import { MatchQuery, PhraseQuery, Table, connect } from "../lancedb"; +import { + Connection, + MatchQuery, + PhraseQuery, + Table, + connect, +} from "../lancedb"; import { Table as ArrowTable, Field, @@ -21,6 +27,8 @@ import { Int64, List, Schema, + SchemaLike, + Type, Uint8, Utf8, makeArrowTable, @@ -2019,3 +2027,52 @@ describe("column name options", () => { expect(results2.length).toBe(10); }); }); + +describe("when creating an empty table", () => { + let con: Connection; + beforeEach(async () => { + const tmpDir = tmp.dirSync({ unsafeCleanup: true }); + con = await connect(tmpDir.name); + }); + afterEach(() => { + con.close(); + }); + + it("can create an empty table from an arrow Schema", async () => { + const schema = new Schema([ + new Field("id", new Int64()), + new Field("vector", new Float64()), + ]); + const table = await con.createEmptyTable("test", schema); + const actualSchema = await table.schema(); + expect(actualSchema.fields[0].type.typeId).toBe(Type.Int); + expect((actualSchema.fields[0].type as Int64).bitWidth).toBe(64); + expect(actualSchema.fields[1].type.typeId).toBe(Type.Float); + expect((actualSchema.fields[1].type as Float64).precision).toBe(2); + }); + + it("can create an empty table from schema that specifies field types by name", async () => { + const schemaLike = { + fields: [ + { + name: "id", + type: "int64", + nullable: true, + }, + { + name: "vector", + type: "float64", + nullable: true, + }, + ], + metadata: new Map(), + names: ["id", "vector"], + } satisfies SchemaLike; + const table = await con.createEmptyTable("test", schemaLike); + const actualSchema = await table.schema(); + expect(actualSchema.fields[0].type.typeId).toBe(Type.Int); + expect((actualSchema.fields[0].type as Int64).bitWidth).toBe(64); + expect(actualSchema.fields[1].type.typeId).toBe(Type.Float); + expect((actualSchema.fields[1].type as Float64).precision).toBe(2); + }); +}); diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 90cefc0f..0d1a2536 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -73,7 +73,7 @@ export type FieldLike = | { type: string; name: string; - nullable?: boolean; + nullable: boolean; metadata?: Map; }; diff --git a/nodejs/lancedb/sanitize.ts b/nodejs/lancedb/sanitize.ts index 0ea33831..6108ff33 100644 --- a/nodejs/lancedb/sanitize.ts +++ b/nodejs/lancedb/sanitize.ts @@ -326,6 +326,9 @@ export function sanitizeDictionary(typeLike: object) { // biome-ignore lint/suspicious/noExplicitAny: skip export function sanitizeType(typeLike: unknown): DataType { + if (typeof typeLike === "string") { + return dataTypeFromName(typeLike); + } if (typeof typeLike !== "object" || typeLike === null) { throw Error("Expected a Type but object was null/undefined"); } @@ -447,7 +450,7 @@ export function sanitizeType(typeLike: unknown): DataType { case Type.DurationSecond: return new DurationSecond(); default: - throw new Error("Unrecoginized type id in schema: " + typeId); + throw new Error("Unrecognized type id in schema: " + typeId); } } @@ -467,7 +470,15 @@ export function sanitizeField(fieldLike: unknown): Field { "The field passed in is missing a `type`/`name`/`nullable` property", ); } - const type = sanitizeType(fieldLike.type); + let type: DataType; + try { + type = sanitizeType(fieldLike.type); + } catch (error: unknown) { + throw Error( + `Unable to sanitize type for field: ${fieldLike.name} due to error: ${error}`, + { cause: error }, + ); + } const name = fieldLike.name; if (!(typeof name === "string")) { throw Error("The field passed in had a non-string `name` property"); @@ -581,3 +592,46 @@ function sanitizeData( }, ); } + +const constructorsByTypeName = { + null: () => new Null(), + binary: () => new Binary(), + utf8: () => new Utf8(), + bool: () => new Bool(), + int8: () => new Int8(), + int16: () => new Int16(), + int32: () => new Int32(), + int64: () => new Int64(), + uint8: () => new Uint8(), + uint16: () => new Uint16(), + uint32: () => new Uint32(), + uint64: () => new Uint64(), + float16: () => new Float16(), + float32: () => new Float32(), + float64: () => new Float64(), + datemillisecond: () => new DateMillisecond(), + dateday: () => new DateDay(), + timenanosecond: () => new TimeNanosecond(), + timemicrosecond: () => new TimeMicrosecond(), + timemillisecond: () => new TimeMillisecond(), + timesecond: () => new TimeSecond(), + intervaldaytime: () => new IntervalDayTime(), + intervalyearmonth: () => new IntervalYearMonth(), + durationnanosecond: () => new DurationNanosecond(), + durationmicrosecond: () => new DurationMicrosecond(), + durationmillisecond: () => new DurationMillisecond(), + durationsecond: () => new DurationSecond(), +} as const; + +type MappableTypeName = keyof typeof constructorsByTypeName; + +export function dataTypeFromName(typeName: string): DataType { + const normalizedTypeName = typeName.toLowerCase() as MappableTypeName; + const _constructor = constructorsByTypeName[normalizedTypeName]; + + if (!_constructor) { + throw new Error("Unrecognized type name in schema: " + typeName); + } + + return _constructor(); +}