mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 04:12:59 +00:00
feat(js): add helper function to create Arrow Table with schema (#838)
Add support for building an Apache Arrow Table from an array of JavaScript records, with an optionally provided Schema.
This commit is contained in:
@@ -13,18 +13,168 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Field, type FixedSizeListBuilder,
|
Field,
|
||||||
|
type FixedSizeListBuilder,
|
||||||
Float32,
|
Float32,
|
||||||
makeBuilder,
|
makeBuilder,
|
||||||
RecordBatchFileWriter,
|
RecordBatchFileWriter,
|
||||||
Utf8, type Vector,
|
Utf8,
|
||||||
|
type Vector,
|
||||||
FixedSizeList,
|
FixedSizeList,
|
||||||
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct
|
vectorFromArray,
|
||||||
|
type Schema,
|
||||||
|
Table as ArrowTable,
|
||||||
|
RecordBatchStreamWriter,
|
||||||
|
List,
|
||||||
|
Float64,
|
||||||
|
RecordBatch,
|
||||||
|
makeData,
|
||||||
|
Struct,
|
||||||
|
type Float
|
||||||
} from 'apache-arrow'
|
} from 'apache-arrow'
|
||||||
import { type EmbeddingFunction } from './index'
|
import { type EmbeddingFunction } from './index'
|
||||||
|
|
||||||
|
/** Options describing how a single vector column is built. */
export class VectorColumnOptions {
  /** Vector column type. */
  type: Float = new Float32()

  /**
   * @param values optional overrides; any property that is omitted or
   * explicitly `undefined` keeps its default.
   */
  constructor (values?: Partial<VectorColumnOptions>) {
    // Copy only defined overrides. `Object.assign(this, values)` would also
    // copy an explicit `type: undefined` and wipe out the Float32 default.
    if (values?.type !== undefined) {
      this.type = values.type
    }
  }
}
|
||||||
|
|
||||||
|
/** Options to control the makeArrowTable call. */
export class MakeArrowTableOptions {
  /** Provided schema. */
  schema?: Schema

  /**
   * Vector columns, keyed by column name.
   *
   * By default the column named `vector` is treated as a vector column.
   * Supplying this option replaces the default map entirely.
   */
  vectorColumns: Record<string, VectorColumnOptions> = {
    vector: new VectorColumnOptions()
  }

  /**
   * @param values optional overrides; any property that is omitted or
   * explicitly `undefined` keeps its default.
   */
  constructor (values?: Partial<MakeArrowTableOptions>) {
    // Copy only defined overrides. `Object.assign(this, values)` would also
    // copy an explicit `vectorColumns: undefined`, replacing the default map
    // with `undefined` and breaking later `vectorColumns[name]` lookups.
    if (values?.schema !== undefined) {
      this.schema = values.schema
    }
    if (values?.vectorColumns !== undefined) {
      this.vectorColumns = values.vectorColumns
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * An enhanced version of the {@link makeTable} function from Apache Arrow
 * that supports nested fields and embeddings columns.
 *
 * Note that it currently does not support nulls.
 *
 * Column names are taken from the keys of the first record only. The type of
 * each column is chosen as follows:
 *
 * 1. If `options.schema` is provided, the type of the schema field with the
 *    same name is used (`options.vectorColumns` is ignored in this case).
 * 2. Otherwise, if the column is listed in `options.vectorColumns`, it becomes
 *    a `FixedSizeList` whose size is the length of the first record's value
 *    and whose item type is the configured vector type.
 * 3. Otherwise no explicit type is passed to `vectorFromArray`.
 *
 * @param data input data
 * @param options options to control the makeArrowTable call.
 * @returns an Arrow Table with one column per key of the first record.
 * @throws Error if `data` is empty, or if the first value of a configured
 * vector column is not array-like.
 *
 * @example
 *
 * ```ts
 *
 * import { makeArrowTable } from "../arrow";
 * import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow";
 *
 * const schema = new Schema([
 *   new Field("a", new Int32()),
 *   new Field("b", new Float32()),
 *   new Field("c", new FixedSizeList(3, new Field("item", new Float16()))),
 * ]);
 * const table = makeArrowTable([
 *   { a: 1, b: 2, c: [1, 2, 3] },
 *   { a: 4, b: 5, c: [4, 5, 6] },
 *   { a: 7, b: 8, c: [7, 8, 9] },
 * ], { schema });
 * ```
 *
 * It guesses the vector columns if the schema is not provided. For example,
 * by default it assumes that the column named `vector` is a vector column.
 *
 * ```ts
 *
 * const schema = new Schema([
 *   new Field("a", new Float64()),
 *   new Field("b", new Float64()),
 *   new Field(
 *     "vector",
 *     new FixedSizeList(3, new Field("item", new Float32()))
 *   ),
 * ]);
 * const table = makeArrowTable([
 *   { a: 1, b: 2, vector: [1, 2, 3] },
 *   { a: 4, b: 5, vector: [4, 5, 6] },
 *   { a: 7, b: 8, vector: [7, 8, 9] },
 * ]);
 * assert.deepEqual(table.schema, schema);
 * ```
 *
 * You can specify the vector column types and names using the options as well
 *
 * ```typescript
 *
 * const schema = new Schema([
 *   new Field('a', new Float64()),
 *   new Field('b', new Float64()),
 *   new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
 *   new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
 * ]);
 * const table = makeArrowTable([
 *   { a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] },
 *   { a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] },
 *   { a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] }
 * ], {
 *   vectorColumns: {
 *     vec1: { type: new Float16() },
 *     vec2: { type: new Float16() }
 *   }
 * });
 * assert.deepEqual(table.schema, schema)
 * ```
 */
export function makeArrowTable (
  data: Array<Record<string, any>>,
  options?: Partial<MakeArrowTableOptions>
): ArrowTable {
  if (data.length === 0) {
    throw new Error('At least one record needs to be provided')
  }
  const opt = new MakeArrowTableOptions(options)
  const columns: Record<string, Vector> = {}
  // TODO: sample dataset to find missing columns
  const columnNames = Object.keys(data[0])
  for (const colName of columnNames) {
    const values = data.map((datum) => datum[colName])
    let vector: Vector

    if (opt.schema !== undefined) {
      // Explicit schema is provided, highest priority. When the schema has no
      // field with this name, the type argument is simply `undefined`.
      vector = vectorFromArray(
        values,
        opt.schema.fields.find((f) => f.name === colName)?.type
      )
    } else {
      const vectorColumnOptions = opt.vectorColumns[colName]
      if (vectorColumnOptions !== undefined) {
        // The list size is taken from the first record, so that value must be
        // array-like; fail early with a readable message otherwise.
        const first = values[0]
        if (
          first === undefined ||
          first === null ||
          typeof first.length !== 'number'
        ) {
          throw new Error(
            `Column '${colName}' is configured as a vector column but its first value is not an array`
          )
        }
        const fslType = new FixedSizeList(
          first.length,
          new Field('item', vectorColumnOptions.type, false)
        )
        vector = vectorFromArray(values, fslType)
      } else {
        // Normal case
        vector = vectorFromArray(values)
      }
    }
    columns[colName] = vector
  }

  return new ArrowTable(columns)
}
|
||||||
|
|
||||||
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
|
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
|
||||||
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<ArrowTable> {
|
export async function convertToTable<T> (
|
||||||
|
data: Array<Record<string, unknown>>,
|
||||||
|
embeddings?: EmbeddingFunction<T>
|
||||||
|
): Promise<ArrowTable> {
|
||||||
if (data.length === 0) {
|
if (data.length === 0) {
|
||||||
throw new Error('At least one record needs to be provided')
|
throw new Error('At least one record needs to be provided')
|
||||||
}
|
}
|
||||||
@@ -52,7 +202,10 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
|
|||||||
|
|
||||||
if (columnsKey === embeddings?.sourceColumn) {
|
if (columnsKey === embeddings?.sourceColumn) {
|
||||||
const vectors = await embeddings.embed(values as T[])
|
const vectors = await embeddings.embed(values as T[])
|
||||||
records.vector = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
records.vector = vectorFromArray(
|
||||||
|
vectors,
|
||||||
|
newVectorType(vectors[0].length)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (typeof values[0] === 'string') {
|
if (typeof values[0] === 'string') {
|
||||||
@@ -110,7 +263,11 @@ function newVectorType (dim: number): FixedSizeList<Float32> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Array of records into Arrow IPC format
|
// Converts an Array of records into Arrow IPC format
|
||||||
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
export async function fromRecordsToBuffer<T> (
|
||||||
|
data: Array<Record<string, unknown>>,
|
||||||
|
embeddings?: EmbeddingFunction<T>,
|
||||||
|
schema?: Schema
|
||||||
|
): Promise<Buffer> {
|
||||||
let table = await convertToTable(data, embeddings)
|
let table = await convertToTable(data, embeddings)
|
||||||
if (schema !== undefined) {
|
if (schema !== undefined) {
|
||||||
table = alignTable(table, schema)
|
table = alignTable(table, schema)
|
||||||
@@ -120,7 +277,11 @@ export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Array of records into Arrow IPC stream format
|
// Converts an Array of records into Arrow IPC stream format
|
||||||
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
export async function fromRecordsToStreamBuffer<T> (
|
||||||
|
data: Array<Record<string, unknown>>,
|
||||||
|
embeddings?: EmbeddingFunction<T>,
|
||||||
|
schema?: Schema
|
||||||
|
): Promise<Buffer> {
|
||||||
let table = await convertToTable(data, embeddings)
|
let table = await convertToTable(data, embeddings)
|
||||||
if (schema !== undefined) {
|
if (schema !== undefined) {
|
||||||
table = alignTable(table, schema)
|
table = alignTable(table, schema)
|
||||||
@@ -130,12 +291,18 @@ export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, u
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Arrow Table into Arrow IPC format
|
// Converts an Arrow Table into Arrow IPC format
|
||||||
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
export async function fromTableToBuffer<T> (
|
||||||
|
table: ArrowTable,
|
||||||
|
embeddings?: EmbeddingFunction<T>,
|
||||||
|
schema?: Schema
|
||||||
|
): Promise<Buffer> {
|
||||||
if (embeddings !== undefined) {
|
if (embeddings !== undefined) {
|
||||||
const source = table.getChild(embeddings.sourceColumn)
|
const source = table.getChild(embeddings.sourceColumn)
|
||||||
|
|
||||||
if (source === null) {
|
if (source === null) {
|
||||||
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
|
throw new Error(
|
||||||
|
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
const vectors = await embeddings.embed(source.toArray() as T[])
|
||||||
@@ -150,12 +317,18 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Arrow Table into Arrow IPC stream format
|
// Converts an Arrow Table into Arrow IPC stream format
|
||||||
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
export async function fromTableToStreamBuffer<T> (
|
||||||
|
table: ArrowTable,
|
||||||
|
embeddings?: EmbeddingFunction<T>,
|
||||||
|
schema?: Schema
|
||||||
|
): Promise<Buffer> {
|
||||||
if (embeddings !== undefined) {
|
if (embeddings !== undefined) {
|
||||||
const source = table.getChild(embeddings.sourceColumn)
|
const source = table.getChild(embeddings.sourceColumn)
|
||||||
|
|
||||||
if (source === null) {
|
if (source === null) {
|
||||||
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
|
throw new Error(
|
||||||
|
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
const vectors = await embeddings.embed(source.toArray() as T[])
|
||||||
@@ -172,9 +345,13 @@ export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?
|
|||||||
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
|
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
|
||||||
const alignedChildren = []
|
const alignedChildren = []
|
||||||
for (const field of schema.fields) {
|
for (const field of schema.fields) {
|
||||||
const indexInBatch = batch.schema.fields?.findIndex((f) => f.name === field.name)
|
const indexInBatch = batch.schema.fields?.findIndex(
|
||||||
|
(f) => f.name === field.name
|
||||||
|
)
|
||||||
if (indexInBatch < 0) {
|
if (indexInBatch < 0) {
|
||||||
throw new Error(`The column ${field.name} was not found in the Arrow Table`)
|
throw new Error(
|
||||||
|
`The column ${field.name} was not found in the Arrow Table`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
alignedChildren.push(batch.data.children[indexInBatch])
|
alignedChildren.push(batch.data.children[indexInBatch])
|
||||||
}
|
}
|
||||||
@@ -188,7 +365,9 @@ function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
 * Builds a new Arrow Table laid out according to `schema`.
 *
 * Every record batch of `table` is passed through `alignBatch` with the same
 * schema, and the results are wrapped in a table created with that schema.
 */
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
  const realigned: RecordBatch[] = []
  for (const batch of table.batches) {
    realigned.push(alignBatch(batch, schema))
  }
  return new ArrowTable(schema, realigned)
}
|
||||||
|
|
||||||
|
|||||||
@@ -41,12 +41,13 @@ const {
|
|||||||
tableListIndices,
|
tableListIndices,
|
||||||
tableIndexStats,
|
tableIndexStats,
|
||||||
tableSchema
|
tableSchema
|
||||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||||
} = require('../native.js')
|
} = require('../native.js')
|
||||||
|
|
||||||
export { Query }
|
export { Query }
|
||||||
export type { EmbeddingFunction }
|
export type { EmbeddingFunction }
|
||||||
export { OpenAIEmbeddingFunction } from './embedding/openai'
|
export { OpenAIEmbeddingFunction } from './embedding/openai'
|
||||||
|
export { makeArrowTable, type MakeArrowTableOptions } from './arrow'
|
||||||
|
|
||||||
const defaultAwsRegion = 'us-west-2'
|
const defaultAwsRegion = 'us-west-2'
|
||||||
|
|
||||||
@@ -859,7 +860,10 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
private checkElectron (): boolean {
|
private checkElectron (): boolean {
|
||||||
try {
|
try {
|
||||||
// eslint-disable-next-line no-prototype-builtins
|
// eslint-disable-next-line no-prototype-builtins
|
||||||
return (process?.versions?.hasOwnProperty('electron') || navigator?.userAgent?.toLowerCase()?.includes(' electron'))
|
return (
|
||||||
|
Object.prototype.hasOwnProperty.call(process?.versions, 'electron') ||
|
||||||
|
navigator?.userAgent?.toLowerCase()?.includes(' electron')
|
||||||
|
)
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
108
node/src/test/arrow.test.ts
Normal file
108
node/src/test/arrow.test.ts
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
// Copyright 2024 Lance Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
import { describe } from 'mocha'
|
||||||
|
import { assert } from 'chai'
|
||||||
|
|
||||||
|
import { fromTableToBuffer, makeArrowTable } from '../arrow'
|
||||||
|
import {
|
||||||
|
Field,
|
||||||
|
FixedSizeList,
|
||||||
|
Float16,
|
||||||
|
Float32,
|
||||||
|
Int32,
|
||||||
|
tableFromIPC,
|
||||||
|
Schema,
|
||||||
|
Float64
|
||||||
|
} from 'apache-arrow'
|
||||||
|
|
||||||
|
describe('Apache Arrow tables', function () {
  /**
   * Serializes `table` to the Arrow IPC format, reads it back and checks that
   * the buffer is non-empty, that three rows survive, and that the decoded
   * schema deep-equals `expected`.
   */
  async function assertRoundTrip (
    table: ReturnType<typeof makeArrowTable>,
    expected: Schema
  ): Promise<void> {
    const buf = await fromTableToBuffer(table)
    assert.isAbove(buf.byteLength, 0)

    const actual = tableFromIPC(buf)
    assert.equal(actual.numRows, 3)
    assert.deepEqual(actual.schema, expected)
  }

  it('customized schema', async function () {
    // Column types come from the explicitly provided schema.
    const schema = new Schema([
      new Field('a', new Int32()),
      new Field('b', new Float32()),
      new Field('c', new FixedSizeList(3, new Field('item', new Float16())))
    ])
    const records = [
      { a: 1, b: 2, c: [1, 2, 3] },
      { a: 4, b: 5, c: [4, 5, 6] },
      { a: 7, b: 8, c: [7, 8, 9] }
    ]

    await assertRoundTrip(makeArrowTable(records, { schema }), schema)
  })

  it('default vector column', async function () {
    // No options are passed; the expected schema has `vector` as a
    // FixedSizeList of Float32.
    const schema = new Schema([
      new Field('a', new Float64()),
      new Field('b', new Float64()),
      new Field(
        'vector',
        new FixedSizeList(3, new Field('item', new Float32()))
      )
    ])
    const records = [
      { a: 1, b: 2, vector: [1, 2, 3] },
      { a: 4, b: 5, vector: [4, 5, 6] },
      { a: 7, b: 8, vector: [7, 8, 9] }
    ]

    await assertRoundTrip(makeArrowTable(records), schema)
  })

  it('2 vector columns', async function () {
    // Both vector columns are declared through the `vectorColumns` option.
    const schema = new Schema([
      new Field('a', new Float64()),
      new Field('b', new Float64()),
      new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
      new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
    ])
    const records = [
      { a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] },
      { a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] },
      { a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] }
    ]
    const vectorColumns = {
      vec1: { type: new Float16() },
      vec2: { type: new Float16() }
    }

    await assertRoundTrip(makeArrowTable(records, { vectorColumns }), schema)
  })
})
|
||||||
@@ -1,10 +1,14 @@
|
|||||||
{
|
{
|
||||||
"include": ["src/**/*.ts"],
|
"include": [
|
||||||
|
"src/**/*.ts",
|
||||||
|
"src/*.ts"
|
||||||
|
],
|
||||||
"compilerOptions": {
|
"compilerOptions": {
|
||||||
"target": "es2016",
|
"target": "ES2020",
|
||||||
"module": "commonjs",
|
"module": "commonjs",
|
||||||
"declaration": true,
|
"declaration": true,
|
||||||
"outDir": "./dist",
|
"outDir": "./dist",
|
||||||
"strict": true
|
"strict": true,
|
||||||
|
// "esModuleInterop": true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user