mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-02 03:40:40 +00:00
add embedding functions to the nodejs client (#95)
This commit is contained in:
@@ -15,15 +15,16 @@
|
||||
import {
|
||||
Field,
|
||||
Float32,
|
||||
List,
|
||||
List, type ListBuilder,
|
||||
makeBuilder,
|
||||
RecordBatchFileWriter,
|
||||
Table, Utf8,
|
||||
type Vector,
|
||||
vectorFromArray
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from './index'
|
||||
|
||||
export function convertToTable (data: Array<Record<string, unknown>>): Table {
|
||||
export function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Table {
|
||||
if (data.length === 0) {
|
||||
throw new Error('At least one record needs to be provided')
|
||||
}
|
||||
@@ -33,11 +34,7 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
|
||||
|
||||
for (const columnsKey of columns) {
|
||||
if (columnsKey === 'vector') {
|
||||
const children = new Field<Float32>('item', new Float32())
|
||||
const list = new List(children)
|
||||
const listBuilder = makeBuilder({
|
||||
type: list
|
||||
})
|
||||
const listBuilder = newVectorListBuilder()
|
||||
const vectorSize = (data[0].vector as any[]).length
|
||||
for (const datum of data) {
|
||||
if ((datum[columnsKey] as any[]).length !== vectorSize) {
|
||||
@@ -52,6 +49,14 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
|
||||
for (const datum of data) {
|
||||
values.push(datum[columnsKey])
|
||||
}
|
||||
|
||||
if (columnsKey === embeddings?.sourceColumn) {
|
||||
const vectors = embeddings.embed(values as T[])
|
||||
const listBuilder = newVectorListBuilder()
|
||||
vectors.map(v => listBuilder.append(v))
|
||||
records.vector = listBuilder.finish().toVector()
|
||||
}
|
||||
|
||||
if (typeof values[0] === 'string') {
|
||||
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
|
||||
records[columnsKey] = vectorFromArray(values, new Utf8())
|
||||
@@ -64,8 +69,17 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
|
||||
return new Table(records)
|
||||
}
|
||||
|
||||
export async function fromRecordsToBuffer (data: Array<Record<string, unknown>>): Promise<Buffer> {
|
||||
const table = convertToTable(data)
|
||||
// Creates a new Arrow ListBuilder that stores a Vector column
|
||||
function newVectorListBuilder (): ListBuilder<Float32, any> {
|
||||
const children = new Field<Float32>('item', new Float32())
|
||||
const list = new List(children)
|
||||
return makeBuilder({
|
||||
type: list
|
||||
})
|
||||
}
|
||||
|
||||
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
const table = convertToTable(data, embeddings)
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user