feat(js): support list of string input (#755)

Add support for input columns whose values are lists of strings (e.g., a
list of categorical labels)

Follow-up items: #757 #758
This commit is contained in:
Chang She
2024-01-02 20:55:33 -08:00
committed by Weston Pace
parent 3aa233f38a
commit 81487f10fe
2 changed files with 46 additions and 1 deletions

View File

@@ -20,7 +20,7 @@ import {
Utf8,
type Vector,
FixedSizeList,
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
@@ -59,6 +59,24 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
if (typeof values[0] === 'string') {
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
records[columnsKey] = vectorFromArray(values, new Utf8())
} else if (Array.isArray(values[0])) {
const elementType = getElementType(values[0])
let innerType
if (elementType === 'string') {
innerType = new Utf8()
} else if (elementType === 'number') {
innerType = new Float64()
} else {
// TODO: pass in schema if it exists, else keep going to the next element
throw new Error(`Unsupported array element type ${elementType}`)
}
const listBuilder = makeBuilder({
type: new List(new Field('item', innerType, true))
})
for (const value of values) {
listBuilder.append(value)
}
records[columnsKey] = listBuilder.finish().toVector()
} else {
records[columnsKey] = vectorFromArray(values)
}
@@ -68,6 +86,14 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
return new ArrowTable(records)
}
/**
 * Infers the JavaScript type of the elements held by a list value.
 *
 * The answer is the `typeof` of the first element that is neither `null` nor
 * `undefined`. Skipping missing entries matters because `typeof null` is
 * `'object'`, so a list such as `[null, 'a']` would otherwise be reported as
 * a list of objects instead of a list of strings.
 *
 * @param arr - The list whose element type should be inferred.
 * @returns The `typeof` string of the first non-nullish element, or
 *          `'undefined'` when the list is empty or contains only
 *          `null`/`undefined` entries.
 */
function getElementType (arr: unknown[]): string {
  for (const element of arr) {
    if (element !== null && element !== undefined) {
      return typeof element
    }
  }
  // Nothing usable to infer from: empty list or all entries missing
  return 'undefined'
}
// Creates a new Arrow ListBuilder that stores a Vector column
function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
return makeBuilder({

View File

@@ -218,6 +218,25 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
})
// Rows carrying variable-length list columns (strings and numbers) must be
// accepted by createTable and come back intact from a filtered query.
it('creates a new table from javascript objects with variable sized list', async function () {
  const tmpDir = await track().mkdir('lancejs')
  const db = await lancedb.connect(tmpDir)

  const tableName = 'with_variable_sized_list'
  const rows = [
    { id: 1, vector: [0.1, 0.2], list_of_str: ['a', 'b', 'c'], list_of_num: [1, 2, 3] },
    { id: 2, vector: [1.1, 1.2], list_of_str: ['x', 'y'], list_of_num: [4, 5, 6] }
  ]

  const created = await db.createTable(tableName, rows) as LocalTable
  assert.equal(created.name, tableName)
  assert.equal(await created.countRows(), 2)

  // Only the second row satisfies the predicate
  const matches = await created.filter('id>1').execute()
  assert.equal(matches.length, 1)

  const hit = matches[0]
  assert.deepEqual(hit.list_of_str, ['x', 'y'])
  assert.isTrue(hit.list_of_num instanceof Float64Array)
})
it('fails to create a new table when the vector column is missing', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)