From 7747c9bcbf5d8564b1476c29f4b27b3adb966fd4 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 12 Mar 2025 09:57:36 -0700 Subject: [PATCH] feat(node): parse arrow types in `alterColumns()` (#2208) Previously, users could only specify new data types in `alterColumns` as strings: ```ts await tbl.alterColumns([ path: "price", dataType: "float" ]); ``` But this has some problems: 1. It wasn't clear what were valid types 2. It was impossible to specify nested types, like lists and vector columns. This PR changes it to take an Arrow data type, similar to how the Python API works. This allows casting vector types: ```ts await tbl.alterColumns([ { path: "vector", dataType: new arrow.FixedSizeList( 2, new arrow.Field("item", new arrow.Float16(), false), ), }, ]); ``` Closes #2185 --- Cargo.lock | 8 +- docs/src/guides/tables.md | 22 ++++ docs/src/js/interfaces/ColumnAlteration.md | 2 +- nodejs/__test__/table.test.ts | 88 +++++++++++++ nodejs/examples/basic.test.ts | 11 ++ nodejs/lancedb/arrow.ts | 141 +++++++++++++++++++++ nodejs/lancedb/index.ts | 2 +- nodejs/lancedb/table.ts | 61 ++++++++- python/python/tests/docs/test_basic.py | 30 +++++ rust/lancedb/src/utils.rs | 12 +- 10 files changed, 365 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4f553056..ad29ea93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3931,7 +3931,7 @@ dependencies = [ [[package]] name = "lancedb" -version = "0.18.0-beta.0" +version = "0.18.1" dependencies = [ "arrow", "arrow-array", @@ -4017,7 +4017,7 @@ dependencies = [ [[package]] name = "lancedb-node" -version = "0.18.0-beta.0" +version = "0.18.1" dependencies = [ "arrow-array", "arrow-ipc", @@ -4042,7 +4042,7 @@ dependencies = [ [[package]] name = "lancedb-nodejs" -version = "0.18.0-beta.0" +version = "0.18.1" dependencies = [ "arrow-array", "arrow-ipc", @@ -4060,7 +4060,7 @@ dependencies = [ [[package]] name = "lancedb-python" -version = "0.21.0-beta.1" +version = "0.21.1" dependencies = [ "arrow", "env_logger", diff --git a/docs/src/guides/tables.md b/docs/src/guides/tables.md index a202d2cc..57b51874 100644 --- a/docs/src/guides/tables.md +++ b/docs/src/guides/tables.md @@ -942,6 +942,28 @@ rewriting the column, which can be a heavy operation. ``` **API Reference:** [lancedb.Table.alterColumns](../js/classes/Table.md/#altercolumns) +You can even cast the a vector column to a different dimension: + +=== "Python" + + === "Sync API" + + ```python + --8<-- "python/python/tests/docs/test_guide_tables.py:import-pyarrow" + --8<-- "python/python/tests/docs/test_basic.py:alter_columns_vector" + ``` + === "Async API" + + ```python + --8<-- "python/python/tests/docs/test_guide_tables.py:import-pyarrow" + --8<-- "python/python/tests/docs/test_basic.py:alter_columns_async_vector" + ``` +=== "Typescript" + + ```typescript + --8<-- "nodejs/examples/basic.test.ts:alter_columns_vector" + ``` + ### Dropping columns You can drop columns from the table with the `drop_columns` method. This will diff --git a/docs/src/js/interfaces/ColumnAlteration.md b/docs/src/js/interfaces/ColumnAlteration.md index 28aeeb21..30113d3a 100644 --- a/docs/src/js/interfaces/ColumnAlteration.md +++ b/docs/src/js/interfaces/ColumnAlteration.md @@ -16,7 +16,7 @@ must be provided. ### dataType? ```ts -optional dataType: string; +optional dataType: string | DataType; ``` A new data type for the column. If not provided then the data type will not be changed. diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 1171415c..22f6b9c3 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -24,6 +24,7 @@ import { Utf8, makeArrowTable, } from "../lancedb/arrow"; +import * as arrow from "../lancedb/arrow"; import { EmbeddingFunction, LanceSchema, @@ -920,6 +921,93 @@ describe("schema evolution", function () { new Field("price", new Float64(), true), ]); expect(await table.schema()).toEqual(expectedSchema2); + + await table.alterColumns([ + { + path: "vector", + dataType: new FixedSizeList(2, new Field("item", new Float64(), true)), + }, + ]); + const expectedSchema3 = new Schema([ + new Field("new_id", new Int32(), true), + new Field( + "vector", + new FixedSizeList(2, new Field("item", new Float64(), true)), + true, + ), + new Field("price", new Float64(), true), + ]); + expect(await table.schema()).toEqual(expectedSchema3); + }); + + it("can cast to various types", async function () { + const con = await connect(tmpDir.name); + + // integers + const intTypes = [ + new arrow.Int8(), + new arrow.Int16(), + new arrow.Int32(), + new arrow.Int64(), + new arrow.Uint8(), + new arrow.Uint16(), + new arrow.Uint32(), + new arrow.Uint64(), + ]; + const tableInts = await con.createTable("ints", [{ id: 1n }], { + schema: new Schema([new Field("id", new Int64(), true)]), + }); + for (const intType of intTypes) { + await tableInts.alterColumns([{ path: "id", dataType: intType }]); + const schema = new Schema([new Field("id", intType, true)]); + expect(await tableInts.schema()).toEqual(schema); + } + + // floats + const floatTypes = [ + new arrow.Float16(), + new arrow.Float32(), + new arrow.Float64(), + ]; + const tableFloats = await con.createTable("floats", [{ val: 2.1 }], { + schema: new Schema([new Field("val", new Float32(), true)]), + }); + for (const floatType of floatTypes) { + await tableFloats.alterColumns([{ path: "val", dataType: floatType }]); + const schema = new Schema([new Field("val", floatType, true)]); + expect(await tableFloats.schema()).toEqual(schema); + } + + // Lists of floats + const listTypes = [ + new arrow.List(new arrow.Field("item", new arrow.Float32(), true)), + new arrow.FixedSizeList( + 2, + new arrow.Field("item", new arrow.Float64(), true), + ), + new arrow.FixedSizeList( + 2, + new arrow.Field("item", new arrow.Float16(), true), + ), + new arrow.FixedSizeList( + 2, + new arrow.Field("item", new arrow.Float32(), true), + ), + ]; + const tableLists = await con.createTable("lists", [{ val: [2.1, 3.2] }], { + schema: new Schema([ + new Field( + "val", + new FixedSizeList(2, new arrow.Field("item", new Float32())), + true, + ), + ]), + }); + for (const listType of listTypes) { + await tableLists.alterColumns([{ path: "val", dataType: listType }]); + const schema = new Schema([new Field("val", listType, true)]); + expect(await tableLists.schema()).toEqual(schema); + } }); it("can drop a column from the schema", async function () { diff --git a/nodejs/examples/basic.test.ts b/nodejs/examples/basic.test.ts index 14f48b5f..b56bb2f9 100644 --- a/nodejs/examples/basic.test.ts +++ b/nodejs/examples/basic.test.ts @@ -132,6 +132,17 @@ test("basic table examples", async () => { }, ]); // --8<-- [end:alter_columns] + // --8<-- [start:alter_columns_vector] + await tbl.alterColumns([ + { + path: "vector", + dataType: new arrow.FixedSizeList( + 2, + new arrow.Field("item", new arrow.Float16(), false), + ), + }, + ]); + // --8<-- [end:alter_columns_vector] // --8<-- [start:drop_columns] await tbl.dropColumns(["dbl_price"]); // --8<-- [end:drop_columns] diff --git a/nodejs/lancedb/arrow.ts b/nodejs/lancedb/arrow.ts index 8daca04c..00f8ffed 100644 --- a/nodejs/lancedb/arrow.ts +++ b/nodejs/lancedb/arrow.ts @@ -8,7 +8,11 @@ import { Bool, BufferType, DataType, + DateUnit, + Date_, + Decimal, Dictionary, + Duration, Field, FixedSizeBinary, FixedSizeList, @@ -21,12 +25,15 @@ import { LargeBinary, List, Null, + Precision, RecordBatch, RecordBatchFileReader, RecordBatchFileWriter, RecordBatchStreamWriter, Schema, Struct, + Timestamp, + Type, Utf8, Vector, makeVector as arrowMakeVector, @@ -1170,3 +1177,137 @@ function validateSchemaEmbeddings( return new Schema(fields, schema.metadata); } + +interface JsonDataType { + type: string; + fields?: JsonField[]; + length?: number; +} + +interface JsonField { + name: string; + type: JsonDataType; + nullable: boolean; + metadata: Map; +} + +// Matches format of https://github.com/lancedb/lance/blob/main/rust/lance/src/arrow/json.rs +export function dataTypeToJson(dataType: DataType): JsonDataType { + switch (dataType.typeId) { + // For primitives, matches https://github.com/lancedb/lance/blob/e12bb9eff2a52f753668d4b62c52e4d72b10d294/rust/lance-core/src/datatypes.rs#L185 + case Type.Null: + return { type: "null" }; + case Type.Bool: + return { type: "bool" }; + case Type.Int8: + return { type: "int8" }; + case Type.Int16: + return { type: "int16" }; + case Type.Int32: + return { type: "int32" }; + case Type.Int64: + return { type: "int64" }; + case Type.Uint8: + return { type: "uint8" }; + case Type.Uint16: + return { type: "uint16" }; + case Type.Uint32: + return { type: "uint32" }; + case Type.Uint64: + return { type: "uint64" }; + case Type.Int: { + const bitWidth = (dataType as Int).bitWidth; + const signed = (dataType as Int).isSigned; + const prefix = signed ? "" : "u"; + return { type: `${prefix}int${bitWidth}` }; + } + case Type.Float: { + switch ((dataType as Float).precision) { + case Precision.HALF: + return { type: "halffloat" }; + case Precision.SINGLE: + return { type: "float" }; + case Precision.DOUBLE: + return { type: "double" }; + } + throw Error("Unsupported float precision"); + } + case Type.Float16: + return { type: "halffloat" }; + case Type.Float32: + return { type: "float" }; + case Type.Float64: + return { type: "double" }; + case Type.Utf8: + return { type: "string" }; + case Type.Binary: + return { type: "binary" }; + case Type.LargeUtf8: + return { type: "large_string" }; + case Type.LargeBinary: + return { type: "large_binary" }; + case Type.List: + return { + type: "list", + fields: [fieldToJson((dataType as List).children[0])], + }; + case Type.FixedSizeList: { + const fixedSizeList = dataType as FixedSizeList; + return { + type: "fixed_size_list", + fields: [fieldToJson(fixedSizeList.children[0])], + length: fixedSizeList.listSize, + }; + } + case Type.Struct: + return { + type: "struct", + fields: (dataType as Struct).children.map(fieldToJson), + }; + case Type.Date: { + const unit = (dataType as Date_).unit; + return { + type: unit === DateUnit.DAY ? "date32:day" : "date64:ms", + }; + } + case Type.Timestamp: { + const timestamp = dataType as Timestamp; + const timezone = timestamp.timezone || "-"; + return { + type: `timestamp:${timestamp.unit}:${timezone}`, + }; + } + case Type.Decimal: { + const decimal = dataType as Decimal; + return { + type: `decimal:${decimal.bitWidth}:${decimal.precision}:${decimal.scale}`, + }; + } + case Type.Duration: { + const duration = dataType as Duration; + return { type: `duration:${duration.unit}` }; + } + case Type.FixedSizeBinary: { + const byteWidth = (dataType as FixedSizeBinary).byteWidth; + return { type: `fixed_size_binary:${byteWidth}` }; + } + case Type.Dictionary: { + const dict = dataType as Dictionary; + const indexType = dataTypeToJson(dict.indices); + const valueType = dataTypeToJson(dict.valueType); + return { + type: `dict:${valueType.type}:${indexType.type}:false`, + }; + } + } + throw new Error("Unsupported data type"); +} + +function fieldToJson(field: Field): JsonField { + return { + name: field.name, + type: dataTypeToJson(field.type), + nullable: field.nullable, + metadata: field.metadata, + }; +} diff --git a/nodejs/lancedb/index.ts b/nodejs/lancedb/index.ts index 66500c68..482910e3 100644 --- a/nodejs/lancedb/index.ts +++ b/nodejs/lancedb/index.ts @@ -14,7 +14,6 @@ import { export { AddColumnsSql, - ColumnAlteration, ConnectionOptions, IndexStatistics, IndexConfig, @@ -65,6 +64,7 @@ export { UpdateOptions, OptimizeOptions, Version, + ColumnAlteration, } from "./table"; export { MergeInsertBuilder } from "./merge"; diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index c0ac1d10..81bccb38 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -4,8 +4,10 @@ import { Table as ArrowTable, Data, + DataType, IntoVector, Schema, + dataTypeToJson, fromDataToBuffer, tableFromIPC, } from "./arrow"; @@ -15,13 +17,13 @@ import { IndexOptions } from "./indices"; import { MergeInsertBuilder } from "./merge"; import { AddColumnsSql, - ColumnAlteration, IndexConfig, IndexStatistics, OptimizeStats, Table as _NativeTable, } from "./native"; import { Query, VectorQuery } from "./query"; +import { sanitizeType } from "./sanitize"; import { IntoSql, toSQL } from "./util"; export { IndexConfig } from "./native"; @@ -618,7 +620,27 @@ export class LocalTable extends Table { } async alterColumns(columnAlterations: ColumnAlteration[]): Promise { - await this.inner.alterColumns(columnAlterations); + const processedAlterations = columnAlterations.map((alteration) => { + if (typeof alteration.dataType === "string") { + return { + ...alteration, + dataType: JSON.stringify({ type: alteration.dataType }), + }; + } else if (alteration.dataType === undefined) { + return { + ...alteration, + dataType: undefined, + }; + } else { + const dataType = sanitizeType(alteration.dataType); + return { + ...alteration, + dataType: JSON.stringify(dataTypeToJson(dataType)), + }; + } + }); + + await this.inner.alterColumns(processedAlterations); } async dropColumns(columnNames: string[]): Promise { @@ -711,3 +733,38 @@ export class LocalTable extends Table { await this.inner.migrateManifestPathsV2(); } } + +/** + * A definition of a column alteration. The alteration changes the column at + * `path` to have the new name `name`, to be nullable if `nullable` is true, + * and to have the data type `data_type`. At least one of `rename` or `nullable` + * must be provided. + */ +export interface ColumnAlteration { + /** + * The path to the column to alter. This is a dot-separated path to the column. + * If it is a top-level column then it is just the name of the column. If it is + * a nested column then it is the path to the column, e.g. "a.b.c" for a column + * `c` nested inside a column `b` nested inside a column `a`. + */ + path: string; + /** + * The new name of the column. If not provided then the name will not be changed. + * This must be distinct from the names of all other columns in the table. + */ + rename?: string; + /** + * A new data type for the column. If not provided then the data type will not be changed. + * Changing data types is limited to casting to the same general type. For example, these + * changes are valid: + * * `int32` -> `int64` (integers) + * * `double` -> `float` (floats) + * * `string` -> `large_string` (strings) + * But these changes are not: + * * `int32` -> `double` (mix integers and floats) + * * `string` -> `int32` (mix strings and integers) + */ + dataType?: string | DataType; + /** Set the new nullability. Note that a nullable column cannot be made non-nullable. */ + nullable?: boolean; +} diff --git a/python/python/tests/docs/test_basic.py b/python/python/tests/docs/test_basic.py index 0e7f4897..2a824371 100644 --- a/python/python/tests/docs/test_basic.py +++ b/python/python/tests/docs/test_basic.py @@ -83,6 +83,21 @@ def test_quickstart(tmp_path): } ) # --8<-- [end:alter_columns] + # --8<-- [start:alter_columns_vector] + tbl.alter_columns( + { + "path": "vector", + "data_type": pa.list_(pa.float16(), list_size=2), + } + ) + # --8<-- [end:alter_columns_vector] + # Change it back since we can get a panic with fp16 + tbl.alter_columns( + { + "path": "vector", + "data_type": pa.list_(pa.float32(), list_size=2), + } + ) # --8<-- [start:drop_columns] tbl.drop_columns(["dbl_price"]) # --8<-- [end:drop_columns] @@ -162,6 +177,21 @@ async def test_quickstart_async(tmp_path): } ) # --8<-- [end:alter_columns_async] + # --8<-- [start:alter_columns_async_vector] + await tbl.alter_columns( + { + "path": "vector", + "data_type": pa.list_(pa.float16(), list_size=2), + } + ) + # --8<-- [end:alter_columns_async_vector] + # Change it back since we can get a panic with fp16 + await tbl.alter_columns( + { + "path": "vector", + "data_type": pa.list_(pa.float32(), list_size=2), + } + ) # --8<-- [start:drop_columns_async] await tbl.drop_columns(["dbl_price"]) # --8<-- [end:drop_columns_async] diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index 3056dd9b..f2fd130a 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -166,10 +166,14 @@ pub fn supported_vector_data_type(dtype: &DataType) -> bool { /// Note: this is temporary until we get a proper datatype conversion in Lance. pub fn string_to_datatype(s: &str) -> Option { - let data_type = serde_json::Value::String(s.to_string()); - let json_type = - serde_json::Value::Object([("type".to_string(), data_type)].iter().cloned().collect()); - let json_type: JsonDataType = serde_json::from_value(json_type).ok()?; + let data_type: serde_json::Value = { + if let Ok(data_type) = serde_json::from_str(s) { + data_type + } else { + serde_json::json!({ "type": s }) + } + }; + let json_type: JsonDataType = serde_json::from_value(data_type).ok()?; (&json_type).try_into().ok() }