mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 05:42:58 +00:00
feat(node): Create empty tables / Arrow Tables (#399)
- Supports creating an empty table as long as an Arrow Schema is provided - Supports creating a table from an Arrow Table (can be passed as data) - Simplified some Arrow code in the TS/FFI side - removed createTableArrow method, it was never documented / tested.
This commit is contained in:
@@ -13,18 +13,19 @@
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
Field,
|
||||
Field, type FixedSizeListBuilder,
|
||||
Float32,
|
||||
List, type ListBuilder,
|
||||
makeBuilder,
|
||||
RecordBatchFileWriter,
|
||||
Table, Utf8,
|
||||
Utf8,
|
||||
type Vector,
|
||||
vectorFromArray
|
||||
FixedSizeList,
|
||||
vectorFromArray, type Schema, Table as ArrowTable
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from './index'
|
||||
|
||||
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table> {
|
||||
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
|
||||
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<ArrowTable> {
|
||||
if (data.length === 0) {
|
||||
throw new Error('At least one record needs to be provided')
|
||||
}
|
||||
@@ -34,8 +35,8 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
|
||||
|
||||
for (const columnsKey of columns) {
|
||||
if (columnsKey === 'vector') {
|
||||
const listBuilder = newVectorListBuilder()
|
||||
const vectorSize = (data[0].vector as any[]).length
|
||||
const listBuilder = newVectorBuilder(vectorSize)
|
||||
for (const datum of data) {
|
||||
if ((datum[columnsKey] as any[]).length !== vectorSize) {
|
||||
throw new Error(`Invalid vector size, expected ${vectorSize}`)
|
||||
@@ -52,9 +53,7 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
|
||||
|
||||
if (columnsKey === embeddings?.sourceColumn) {
|
||||
const vectors = await embeddings.embed(values as T[])
|
||||
const listBuilder = newVectorListBuilder()
|
||||
vectors.map(v => listBuilder.append(v))
|
||||
records.vector = listBuilder.finish().toVector()
|
||||
records.vector = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
}
|
||||
|
||||
if (typeof values[0] === 'string') {
|
||||
@@ -66,20 +65,47 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
|
||||
}
|
||||
}
|
||||
|
||||
return new Table(records)
|
||||
return new ArrowTable(records)
|
||||
}
|
||||
|
||||
// Creates a new Arrow ListBuilder that stores a Vector column
|
||||
function newVectorListBuilder (): ListBuilder<Float32, any> {
|
||||
const children = new Field<Float32>('item', new Float32())
|
||||
const list = new List(children)
|
||||
function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
|
||||
return makeBuilder({
|
||||
type: list
|
||||
type: newVectorType(dim)
|
||||
})
|
||||
}
|
||||
|
||||
// Creates the Arrow Type for a Vector column with dimension `dim`
|
||||
function newVectorType (dim: number): FixedSizeList<Float32> {
|
||||
const children = new Field<Float32>('item', new Float32())
|
||||
return new FixedSizeList(dim, children)
|
||||
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC format
|
||||
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
const table = await convertToTable(data, embeddings)
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC format
|
||||
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
if (source === null) {
|
||||
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
|
||||
}
|
||||
|
||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Creates an empty Arrow Table
|
||||
export function createEmptyTable (schema: Schema): ArrowTable {
|
||||
return new ArrowTable(schema)
|
||||
}
|
||||
|
||||
@@ -13,10 +13,10 @@
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
RecordBatchFileWriter,
|
||||
type Table as ArrowTable
|
||||
type Schema,
|
||||
Table as ArrowTable
|
||||
} from 'apache-arrow'
|
||||
import { fromRecordsToBuffer } from './arrow'
|
||||
import { createEmptyTable, fromRecordsToBuffer, fromTableToBuffer } from './arrow'
|
||||
import type { EmbeddingFunction } from './embedding/embedding_function'
|
||||
import { RemoteConnection } from './remote'
|
||||
import { Query } from './query'
|
||||
@@ -51,6 +51,23 @@ export interface ConnectionOptions {
|
||||
hostOverride?: string
|
||||
}
|
||||
|
||||
export interface CreateTableOptions<T> {
|
||||
// Name of Table
|
||||
name: string
|
||||
|
||||
// Data to insert into the Table
|
||||
data?: Array<Record<string, unknown>> | ArrowTable | undefined
|
||||
|
||||
// Optional Arrow Schema for this table
|
||||
schema?: Schema | undefined
|
||||
|
||||
// Optional embedding function used to create embeddings
|
||||
embeddingFunction?: EmbeddingFunction<T> | undefined
|
||||
|
||||
// WriteOptions for this operation
|
||||
writeOptions?: WriteOptions | undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI
|
||||
* @param uri The uri of the database.
|
||||
@@ -97,6 +114,17 @@ export interface Connection {
|
||||
*/
|
||||
openTable<T>(name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
|
||||
/**
|
||||
* Creates a new Table, optionally initializing it with new data.
|
||||
*
|
||||
* @param {string} name - The name of the table.
|
||||
* @param data - Array of Records to be inserted into the table
|
||||
* @param schema - An Arrow Schema that describe this table columns
|
||||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table
|
||||
* @param {WriteOptions} writeOptions - The write options to use when creating the table.
|
||||
*/
|
||||
createTable<T> ({ name, data, schema, embeddingFunction, writeOptions }: CreateTableOptions<T>): Promise<Table<T>>
|
||||
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
@@ -132,8 +160,6 @@ export interface Connection {
|
||||
*/
|
||||
createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>, options: WriteOptions): Promise<Table<T>>
|
||||
|
||||
createTableArrow(name: string, table: ArrowTable): Promise<Table>
|
||||
|
||||
/**
|
||||
* Drop an existing table.
|
||||
* @param name The name of the table to drop.
|
||||
@@ -256,59 +282,80 @@ export class LocalConnection implements Connection {
|
||||
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
const tbl = await databaseOpenTable.call(this._db, name, ...this.awsParams())
|
||||
if (embeddings !== undefined) {
|
||||
return new LocalTable(tbl, name, this._options(), embeddings)
|
||||
} else {
|
||||
return new LocalTable(tbl, name, this._options())
|
||||
}
|
||||
}
|
||||
|
||||
async createTable<T> (name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
|
||||
if (typeof name === 'string') {
|
||||
let writeOptions: WriteOptions = new DefaultWriteOptions()
|
||||
if (opt !== undefined && isWriteOptions(opt)) {
|
||||
writeOptions = opt
|
||||
} else if (optsOrEmbedding !== undefined && isWriteOptions(optsOrEmbedding)) {
|
||||
writeOptions = optsOrEmbedding
|
||||
}
|
||||
|
||||
let embeddings: undefined | EmbeddingFunction<T>
|
||||
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) {
|
||||
embeddings = optsOrEmbedding
|
||||
}
|
||||
return await this.createTableImpl({ name, data, embeddingFunction: embeddings, writeOptions })
|
||||
}
|
||||
return await this.createTableImpl(name)
|
||||
}
|
||||
|
||||
private async createTableImpl<T> ({ name, data, schema, embeddingFunction, writeOptions = new DefaultWriteOptions() }: {
|
||||
name: string
|
||||
data?: Array<Record<string, unknown>> | ArrowTable | undefined
|
||||
schema?: Schema | undefined
|
||||
embeddingFunction?: EmbeddingFunction<T> | undefined
|
||||
writeOptions?: WriteOptions | undefined
|
||||
}): Promise<Table<T>> {
|
||||
let buffer: Buffer
|
||||
|
||||
function isEmpty (data: Array<Record<string, unknown>> | ArrowTable<any>): boolean {
|
||||
if (data instanceof ArrowTable) {
|
||||
return data.data.length === 0
|
||||
}
|
||||
return data.length === 0
|
||||
}
|
||||
|
||||
if ((data === undefined) || isEmpty(data)) {
|
||||
if (schema === undefined) {
|
||||
throw new Error('Either data or schema needs to defined')
|
||||
}
|
||||
buffer = await fromTableToBuffer(createEmptyTable(schema))
|
||||
} else if (data instanceof ArrowTable) {
|
||||
buffer = await fromTableToBuffer(data, embeddingFunction)
|
||||
} else {
|
||||
// data is Array<Record<...>>
|
||||
buffer = await fromRecordsToBuffer(data, embeddingFunction)
|
||||
}
|
||||
|
||||
const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...this.awsParams())
|
||||
if (embeddingFunction !== undefined) {
|
||||
return new LocalTable(tbl, name, this._options(), embeddingFunction)
|
||||
} else {
|
||||
return new LocalTable(tbl, name, this._options())
|
||||
}
|
||||
}
|
||||
|
||||
private awsParams (): any[] {
|
||||
// TODO: move this thing into rust
|
||||
const callArgs = [this._db, name]
|
||||
const awsCredentials = this._options().awsCredentials
|
||||
const params = []
|
||||
if (awsCredentials !== undefined) {
|
||||
callArgs.push(awsCredentials.accessKeyId)
|
||||
callArgs.push(awsCredentials.secretKey)
|
||||
params.push(awsCredentials.accessKeyId)
|
||||
params.push(awsCredentials.secretKey)
|
||||
if (awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(awsCredentials.sessionToken)
|
||||
params.push(awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
const tbl = await databaseOpenTable.call(...callArgs)
|
||||
if (embeddings !== undefined) {
|
||||
return new LocalTable(tbl, name, this._options(), embeddings)
|
||||
} else {
|
||||
return new LocalTable(tbl, name, this._options())
|
||||
}
|
||||
}
|
||||
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
|
||||
let writeOptions: WriteOptions = new DefaultWriteOptions()
|
||||
if (opt !== undefined && isWriteOptions(opt)) {
|
||||
writeOptions = opt
|
||||
} else if (optsOrEmbedding !== undefined && isWriteOptions(optsOrEmbedding)) {
|
||||
writeOptions = optsOrEmbedding
|
||||
}
|
||||
|
||||
let embeddings: undefined | EmbeddingFunction<T>
|
||||
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) {
|
||||
embeddings = optsOrEmbedding
|
||||
}
|
||||
const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), writeOptions.writeMode?.toString()]
|
||||
const awsCredentials = this._options().awsCredentials
|
||||
if (awsCredentials !== undefined) {
|
||||
createArgs.push(awsCredentials.accessKeyId)
|
||||
createArgs.push(awsCredentials.secretKey)
|
||||
if (awsCredentials.sessionToken !== undefined) {
|
||||
createArgs.push(awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
const tbl = await tableCreate.call(...createArgs)
|
||||
|
||||
if (embeddings !== undefined) {
|
||||
return new LocalTable(tbl, name, this._options(), embeddings)
|
||||
} else {
|
||||
return new LocalTable(tbl, name, this._options())
|
||||
}
|
||||
}
|
||||
|
||||
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
await tableCreate.call(this._db, name, Buffer.from(await writer.toUint8Array()))
|
||||
return await this.openTable(name)
|
||||
return params
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
|
||||
import {
|
||||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
||||
type ConnectionOptions
|
||||
type ConnectionOptions, type CreateTableOptions, type WriteOptions
|
||||
} from '../index'
|
||||
import { Query } from '../query'
|
||||
|
||||
import { type Table as ArrowTable, Vector } from 'apache-arrow'
|
||||
import { Vector } from 'apache-arrow'
|
||||
import { HttpLancedbClient } from './client'
|
||||
|
||||
/**
|
||||
@@ -66,13 +66,7 @@ export class RemoteConnection implements Connection {
|
||||
}
|
||||
}
|
||||
|
||||
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table>
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
throw new Error('Not implemented')
|
||||
}
|
||||
|
||||
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
|
||||
async createTable<T> (name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
|
||||
throw new Error('Not implemented')
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +47,9 @@ describe('LanceDB S3 client', function () {
|
||||
}
|
||||
}
|
||||
const table = await createTestDB(opts, 2, 20)
|
||||
console.log(table)
|
||||
const con = await lancedb.connect(opts)
|
||||
console.log(con)
|
||||
assert.equal(con.uri, opts.uri)
|
||||
|
||||
const results = await table.search([0.1, 0.3]).limit(5).execute()
|
||||
@@ -70,5 +72,5 @@ async function createTestDB (opts: ConnectionOptions, numDimensions: number = 2,
|
||||
data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector })
|
||||
}
|
||||
|
||||
return await con.createTable('vectors', data)
|
||||
return await con.createTable('vectors_2', data)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import * as chaiAsPromised from 'chai-as-promised'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index'
|
||||
import { Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray } from 'apache-arrow'
|
||||
|
||||
const expect = chai.expect
|
||||
const assert = chai.assert
|
||||
@@ -119,6 +120,45 @@ describe('LanceDB client', function () {
|
||||
})
|
||||
|
||||
describe('when creating a new dataset', function () {
|
||||
it('create an empty table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('id', new Int32()), new Field('name', new Utf8())]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema })
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
it('create a table with a empty data array', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('id', new Int32()), new Field('name', new Utf8())]
|
||||
)
|
||||
const table = await con.createTable({ name: 'vectors', schema, data: [] })
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
it('create a table from an Arrow Table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const i32s = new Int32Array(new Array<number>(10))
|
||||
const i32 = makeVector(i32s)
|
||||
|
||||
const data = new ArrowTable({ vector: i32 })
|
||||
|
||||
const table = await con.createTable({ name: 'vectors', data })
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.equal(await table.countRows(), 10)
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
it('creates a new table from javascript objects', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
@@ -291,6 +331,20 @@ describe('LanceDB client', function () {
|
||||
const results = await table.search('foo').execute()
|
||||
assert.equal(results.length, 2)
|
||||
})
|
||||
|
||||
it('should create embeddings for Arrow Table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
const embeddingFunction = new TextEmbedding('name')
|
||||
|
||||
const names = vectorFromArray(['foo', 'bar'], new Utf8())
|
||||
const data = new ArrowTable({ name: names })
|
||||
|
||||
const table = await con.createTable({ name: 'vectors', data, embeddingFunction })
|
||||
assert.equal(table.name, 'vectors')
|
||||
const results = await table.search('foo').execute()
|
||||
assert.equal(results.length, 2)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user