nodejs append records api (#85)

gsilvestrin
2023-05-18 15:13:57 -07:00
committed by GitHub
parent 61b9479bd9
commit e28fe7b468
8 changed files with 283 additions and 73 deletions

View File

@@ -6,7 +6,7 @@
"types": "dist/index.d.ts",
"scripts": {
"tsc": "tsc -b",
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json-render-diagnostics",
"build": "cargo-cp-artifact --artifact cdylib vectordb-node darwin_arm64.node -- cargo build --message-format=json-render-diagnostics",
"build-release": "npm run build -- --release",
"test": "mocha -recursive dist/test",
"lint": "eslint src --ext .js,.ts"

node/src/arrow.ts (new file, 66 lines)
View File

@@ -0,0 +1,66 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import {
Field,
Float32,
List,
makeBuilder,
RecordBatchFileWriter,
Table,
type Vector,
vectorFromArray
} from 'apache-arrow'
export function convertToTable (data: Array<Record<string, unknown>>): Table {
if (data.length === 0) {
throw new Error('At least one record needs to be provided')
}
const columns = Object.keys(data[0])
const records: Record<string, Vector> = {}
for (const columnsKey of columns) {
if (columnsKey === 'vector') {
const children = new Field<Float32>('item', new Float32())
const list = new List(children)
const listBuilder = makeBuilder({
type: list
})
const vectorSize = (data[0].vector as any[]).length
for (const datum of data) {
if ((datum[columnsKey] as any[]).length !== vectorSize) {
throw new Error(`Invalid vector size, expected ${vectorSize}`)
}
listBuilder.append(datum[columnsKey])
}
records[columnsKey] = listBuilder.finish().toVector()
} else {
const values = []
for (const datum of data) {
values.push(datum[columnsKey])
}
records[columnsKey] = vectorFromArray(values)
}
}
return new Table(records)
}
export async function fromRecordsToBuffer (data: Array<Record<string, unknown>>): Promise<Buffer> {
const table = convertToTable(data)
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
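The two helpers above are the core of the JS-to-native hand-off: convertToTable builds an Arrow Table from plain records, giving the vector column an explicit List<Float32> type, and fromRecordsToBuffer serializes that table to Arrow IPC file bytes. A minimal usage sketch (the record shape and the demo wrapper are illustrative, not part of the commit):

import { convertToTable, fromRecordsToBuffer } from './arrow'

async function demo (): Promise<void> {
  // 'vector' gets the List<Float32> treatment; every other column goes
  // through vectorFromArray's type inference.
  const records = [
    { id: 1, vector: [0.1, 0.2] },
    { id: 2, vector: [1.1, 1.2] }
  ]

  const table = convertToTable(records)
  console.log(table.numRows) // 2

  // Arrow IPC file bytes, ready to hand to the native module
  const buffer = await fromRecordsToBuffer(records)
  console.log(buffer.length > 0) // true
}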

View File

@@ -13,19 +13,15 @@
// limitations under the License.
import {
Field,
- Float32,
- List,
- makeBuilder,
- RecordBatchFileWriter,
- Table as ArrowTable,
+ type Table as ArrowTable,
tableFromIPC,
- Vector,
- vectorFromArray
+ Vector
} from 'apache-arrow'
+ import { fromRecordsToBuffer } from './arrow'
// eslint-disable-next-line @typescript-eslint/no-var-requires
- const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch } = require('../native.js')
+ const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd } = require('../native.js')
/**
* Connect to a LanceDB instance at the given URI
@@ -68,40 +64,7 @@ export class Connection {
}
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table> {
- if (data.length === 0) {
- throw new Error('At least one record needs to be provided')
- }
- const columns = Object.keys(data[0])
- const records: Record<string, Vector> = {}
- for (const columnsKey of columns) {
- if (columnsKey === 'vector') {
- const children = new Field<Float32>('item', new Float32())
- const list = new List(children)
- const listBuilder = makeBuilder({
- type: list
- })
- const vectorSize = (data[0].vector as any[]).length
- for (const datum of data) {
- if ((datum[columnsKey] as any[]).length !== vectorSize) {
- throw new Error(`Invalid vector size, expected ${vectorSize}`)
- }
- listBuilder.append(datum[columnsKey])
- }
- records[columnsKey] = listBuilder.finish().toVector()
- } else {
- const values = []
- for (const datum of data) {
- values.push(datum[columnsKey])
- }
- records[columnsKey] = vectorFromArray(values)
- }
- }
- const table = new ArrowTable(records)
- await this.createTableArrow(name, table)
+ await tableCreate.call(this._db, name, await fromRecordsToBuffer(data))
return await this.openTable(name)
}
@@ -135,6 +98,21 @@ export class Table {
search (queryVector: number[]): Query {
return new Query(this._tbl, queryVector)
}
+ /**
+ * Insert records into this Table, appending to any existing data
+ * @param data Records to be inserted into the Table
+ * @return The number of rows added to the table
+ */
+ async add (data: Array<Record<string, unknown>>): Promise<number> {
+ return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString())
+ }
+ async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
+ return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString())
+ }
}
/**
@@ -194,3 +172,8 @@ export class Query {
})
}
}
+ export enum WriteMode {
+ Overwrite = 'overwrite',
+ Append = 'append'
+ }
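With the WriteMode enum in place, add() and overwrite() are thin wrappers over the same native tableAdd call, differing only in the mode string they pass. A usage sketch (the vectordb module name and URI are assumptions, mirroring the lancedb handle used in the tests below):

import * as lancedb from 'vectordb' // module name assumed

async function demo (): Promise<void> {
  const con = await lancedb.connect('/tmp/lancedb-demo')
  const table = await con.createTable('vectors', [
    { id: 1, vector: [0.1, 0.2], price: 10 }
  ])

  // add() always sends WriteMode.Append to the native layer
  await table.add([{ id: 2, vector: [1.1, 1.2], price: 50 }])

  // overwrite() sends WriteMode.Overwrite, replacing all existing rows
  await table.overwrite([{ id: 3, vector: [2.1, 2.2], price: 30 }])
}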

View File

@@ -90,6 +90,45 @@ describe('LanceDB client', function () {
const results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 2)
})
+ it('appends records to an existing table', async function () {
+ const dir = await track().mkdir('lancejs')
+ const con = await lancedb.connect(dir)
+ const data = [
+ { id: 1, vector: [0.1, 0.2], price: 10 },
+ { id: 2, vector: [1.1, 1.2], price: 50 }
+ ]
+ const table = await con.createTable('vectors', data)
+ const results = await table.search([0.1, 0.3]).execute()
+ assert.equal(results.length, 2)
+ const dataAdd = [
+ { id: 3, vector: [2.1, 2.2], price: 10 },
+ { id: 4, vector: [3.1, 3.2], price: 50 }
+ ]
+ await table.add(dataAdd)
+ const resultsAdd = await table.search([0.1, 0.3]).execute()
+ assert.equal(resultsAdd.length, 4)
+ })
+ it('overwrites all records in a table', async function () {
+ const uri = await createTestDB()
+ const con = await lancedb.connect(uri)
+ const table = await con.openTable('vectors')
+ const results = await table.search([0.1, 0.3]).execute()
+ assert.equal(results.length, 2)
+ const dataOver = [
+ { vector: [2.1, 2.2], price: 10, name: 'foo' },
+ { vector: [3.1, 3.2], price: 50, name: 'bar' }
+ ]
+ await table.overwrite(dataOver)
+ const resultsOver = await table.search([0.1, 0.3]).execute()
+ assert.equal(resultsOver.length, 2)
+ })
})
})

View File

@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+ use std::io::Cursor;
use std::ops::Deref;
use std::sync::Arc;
use arrow_array::cast::as_list_array;
use arrow_array::{Array, FixedSizeListArray, RecordBatch};
+ use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, Field, Schema};
use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
@@ -45,3 +47,14 @@ pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
}
new_batch
}
+ pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Vec<RecordBatch> {
+ let mut batches: Vec<RecordBatch> = Vec::new();
+ let fr = FileReader::try_new(Cursor::new(slice), None);
+ let file_reader = fr.unwrap();
+ for b in file_reader {
+ let record_batch = convert_record_batch(b.unwrap());
+ batches.push(record_batch);
+ }
+ batches
+ }
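This helper is the receiving end of fromRecordsToBuffer on the TypeScript side: the Node process writes an Arrow IPC file into a Buffer, and arrow_buffer_to_record_batch reads it back with FileReader, passing each batch through convert_record_batch (which, judging by the imports above, normalizes the vector column into a FixedSizeListArray). Extracting the loop here lets tableCreate and the new tableAdd share one decoding path.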

View File

@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
- use std::io::Cursor;
+ use std::collections::HashMap;
use std::ops::Deref;
- use std::sync::Arc;
+ use std::sync::{Arc, Mutex};
- use arrow_array::{Float32Array, RecordBatch, RecordBatchReader};
- use arrow_ipc::reader::FileReader;
+ use arrow_array::{Float32Array, RecordBatchReader};
use arrow_ipc::writer::FileWriter;
use futures::{TryFutureExt, TryStreamExt};
use lance::arrow::RecordBatchBuffer;
+ use lance::dataset::WriteMode;
use neon::prelude::*;
use neon::types::buffer::TypedArray;
use once_cell::sync::OnceCell;
@@ -30,7 +30,7 @@ use vectordb::database::Database;
use vectordb::error::Error;
use vectordb::table::Table;
- use crate::arrow::convert_record_batch;
+ use crate::arrow::arrow_buffer_to_record_batch;
mod arrow;
mod convert;
@@ -40,7 +40,7 @@ struct JsDatabase {
}
struct JsTable {
- table: Arc<Table>,
+ table: Arc<Mutex<Table>>,
}
impl Finalize for JsDatabase {}
@@ -87,7 +87,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let table_rst = database.open_table(table_name).await;
deferred.settle_with(&channel, move |mut cx| {
- let table = Arc::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?);
+ let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
Ok(cx.boxed(JsTable { table }))
});
});
@@ -109,6 +109,8 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
rt.spawn(async move {
let builder = table
+ .lock()
+ .unwrap()
.search(Float32Array::from(query))
.limit(limit as usize)
.filter(filter);
@@ -149,15 +151,7 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let buffer = cx.argument::<JsBuffer>(1)?;
- let slice = buffer.as_slice(&mut cx);
- let mut batches: Vec<RecordBatch> = Vec::new();
- let fr = FileReader::try_new(Cursor::new(slice), None);
- let file_reader = fr.unwrap();
- for b in file_reader {
- let record_batch = convert_record_batch(b.unwrap());
- batches.push(record_batch);
- }
+ let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
let rt = runtime(&mut cx)?;
let channel = cx.channel();
@@ -170,13 +164,47 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
let table_rst = database.create_table(table_name, batch_reader).await;
deferred.settle_with(&channel, move |mut cx| {
- let table = Arc::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?);
+ let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
Ok(cx.boxed(JsTable { table }))
});
});
Ok(promise)
}
+ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
+ let write_mode_map: HashMap<&str, WriteMode> = HashMap::from([
+ ("create", WriteMode::Create),
+ ("append", WriteMode::Append),
+ ("overwrite", WriteMode::Overwrite),
+ ]);
+ let js_table = cx
+ .this()
+ .downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+ let buffer = cx.argument::<JsBuffer>(0)?;
+ let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
+ let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
+ let rt = runtime(&mut cx)?;
+ let channel = cx.channel();
+ let (deferred, promise) = cx.promise();
+ let table = js_table.table.clone();
+ let write_mode = write_mode_map.get(write_mode.as_str()).cloned();
+ rt.block_on(async move {
+ let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(batches));
+ let add_result = table.lock().unwrap().add(batch_reader, write_mode).await;
+ deferred.settle_with(&channel, move |mut cx| {
+ let added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
+ Ok(cx.number(added as f64))
+ });
+ });
+ Ok(promise)
+ }
#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("databaseNew", database_new)?;
@@ -184,5 +212,6 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("databaseOpenTable", database_open_table)?;
cx.export_function("tableSearch", table_search)?;
cx.export_function("tableCreate", table_create)?;
cx.export_function("tableAdd", table_add)?;
Ok(())
}
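Note why JsTable switched from Arc<Table> to Arc<Mutex<Table>>: the new Table::add takes &mut self and replaces the underlying Dataset after each write, so a handle shared across async callbacks needs interior mutability; table_search pays for that with the extra .lock().unwrap() in its builder chain.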

View File

@@ -91,7 +91,7 @@ impl Database {
///
/// * A [Table] object.
pub async fn open_table(&self, name: String) -> Result<Table> {
- Table::new(self.path.clone(), name).await
+ Table::open(self.path.clone(), name).await
}
}

View File

@@ -16,7 +16,7 @@ use std::path::PathBuf;
use std::sync::Arc;
use arrow_array::{Float32Array, RecordBatchReader};
- use lance::dataset::{Dataset, WriteParams};
+ use lance::dataset::{Dataset, WriteMode, WriteParams};
use crate::error::{Error, Result};
use crate::query::Query;
@@ -28,11 +28,12 @@ pub const LANCE_FILE_EXTENSION: &str = "lance";
/// A table in a LanceDB database.
pub struct Table {
name: String,
+ path: String,
dataset: Arc<Dataset>,
}
impl Table {
- /// Creates a new Table object
+ /// Opens an existing Table
///
/// # Arguments
///
@@ -42,7 +43,7 @@ impl Table {
/// # Returns
///
/// * A [Table] object.
- pub async fn new(base_path: Arc<PathBuf>, name: String) -> Result<Self> {
+ pub async fn open(base_path: Arc<PathBuf>, name: String) -> Result<Self> {
let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let ds_uri = ds_path
.to_str()
@@ -50,24 +51,57 @@ impl Table {
let dataset = Dataset::open(ds_uri).await?;
let table = Table {
name,
+ path: ds_uri.to_string(),
dataset: Arc::new(dataset),
};
Ok(table)
}
+ /// Creates a new Table
+ ///
+ /// # Arguments
+ ///
+ /// * `base_path` - The base path where the table is located
+ /// * `name` - The Table name
+ /// * `batches` - RecordBatch to be saved in the database
+ ///
+ /// # Returns
+ ///
+ /// * A [Table] object.
pub async fn create(
base_path: Arc<PathBuf>,
name: String,
mut batches: Box<dyn RecordBatchReader>,
) -> Result<Self> {
let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
- let ds_uri = ds_path
+ let path = ds_path
.to_str()
.ok_or(Error::IO(format!("Unable to find table {}", name)))?;
let dataset =
- Arc::new(Dataset::write(&mut batches, ds_uri, Some(WriteParams::default())).await?);
- Ok(Table { name, dataset })
+ Arc::new(Dataset::write(&mut batches, path, Some(WriteParams::default())).await?);
+ Ok(Table { name, path: path.to_string(), dataset })
}
+ /// Insert records into this Table
+ ///
+ /// # Arguments
+ ///
+ /// * `batches` - RecordBatch to be saved in the Table
+ /// * `write_mode` - Append / Overwrite existing records. Default: Append
+ ///
+ /// # Returns
+ ///
+ /// * The number of rows added
+ pub async fn add(
+ &mut self,
+ mut batches: Box<dyn RecordBatchReader>,
+ write_mode: Option<WriteMode>,
+ ) -> Result<usize> {
+ let mut params = WriteParams::default();
+ params.mode = write_mode.unwrap_or(WriteMode::Append);
+ self.dataset = Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
+ Ok(batches.count())
+ }
/// Creates a new Query object that can be executed.
@@ -82,6 +116,11 @@ impl Table {
pub fn search(&self, query_vector: Float32Array) -> Query {
Query::new(self.dataset.clone(), query_vector)
}
+ /// Returns the number of rows in this Table
+ pub async fn count_rows(&self) -> Result<usize> {
+ Ok(self.dataset.count_rows().await?)
+ }
}
#[cfg(test)]
@@ -89,7 +128,7 @@ mod tests {
use arrow_array::{Float32Array, Int32Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema};
use lance::arrow::RecordBatchBuffer;
- use lance::dataset::Dataset;
+ use lance::dataset::{Dataset, WriteMode};
use std::sync::Arc;
use tempfile::tempdir;
@@ -100,12 +139,12 @@ mod tests {
let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path();
- let table = Table::new(Arc::new(path_buf), "test".to_string()).await;
+ let table = Table::open(Arc::new(path_buf), "test".to_string()).await;
assert!(table.is_err());
}
#[tokio::test]
- async fn test_new() {
+ async fn test_open() {
let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path();
@@ -118,13 +157,54 @@ mod tests {
.await
.unwrap();
- let table = Table::new(Arc::new(path_buf), "test".to_string())
+ let table = Table::open(Arc::new(path_buf), "test".to_string())
.await
.unwrap();
assert_eq!(table.name, "test")
}
+ #[tokio::test]
+ async fn test_add() {
+ let tmp_dir = tempdir().unwrap();
+ let path_buf = tmp_dir.into_path();
+ let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
+ let schema = batches.schema().clone();
+ let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
+ assert_eq!(table.count_rows().await.unwrap(), 10);
+ let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
+ schema,
+ vec![Arc::new(Int32Array::from_iter_values(100..110))],
+ )
+ .unwrap()]));
+ table.add(new_batches, None).await.unwrap();
+ assert_eq!(table.count_rows().await.unwrap(), 20);
+ assert_eq!(table.name, "test");
+ }
+ #[tokio::test]
+ async fn test_add_overwrite() {
+ let tmp_dir = tempdir().unwrap();
+ let path_buf = tmp_dir.into_path();
+ let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
+ let schema = batches.schema().clone();
+ let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
+ assert_eq!(table.count_rows().await.unwrap(), 10);
+ let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
+ schema,
+ vec![Arc::new(Int32Array::from_iter_values(100..110))],
+ ).unwrap()]));
+ table.add(new_batches, Some(WriteMode::Overwrite)).await.unwrap();
+ assert_eq!(table.count_rows().await.unwrap(), 10);
+ assert_eq!(table.name, "test");
+ }
#[tokio::test]
async fn test_search() {
let tmp_dir = tempdir().unwrap();
@@ -139,7 +219,7 @@ mod tests {
.await
.unwrap();
- let table = Table::new(Arc::new(path_buf), "test".to_string())
+ let table = Table::open(Arc::new(path_buf), "test".to_string())
.await
.unwrap();
@@ -152,7 +232,7 @@ mod tests {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema.clone(),
- vec![Arc::new(Int32Array::from_iter_values(0..20))],
+ vec![Arc::new(Int32Array::from_iter_values(0..10))],
)
.unwrap()])
}
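The final tweak to make_test_batches (0..20 becomes 0..10) keeps the new tests' arithmetic consistent: create writes 10 rows, the append in test_add brings the count to 20, and the Overwrite write in test_add_overwrite lands back at 10.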