diff --git a/node/package.json b/node/package.json
index 2b41f498..fc746055 100644
--- a/node/package.json
+++ b/node/package.json
@@ -6,7 +6,7 @@
   "types": "dist/index.d.ts",
   "scripts": {
     "tsc": "tsc -b",
-    "build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json-render-diagnostics",
+    "build": "cargo-cp-artifact --artifact cdylib vectordb-node darwin_arm64.node -- cargo build --message-format=json-render-diagnostics",
     "build-release": "npm run build -- --release",
     "test": "mocha -recursive dist/test",
     "lint": "eslint src --ext .js,.ts"
diff --git a/node/src/arrow.ts b/node/src/arrow.ts
new file mode 100644
index 00000000..0a593088
--- /dev/null
+++ b/node/src/arrow.ts
@@ -0,0 +1,66 @@
+// Copyright 2023 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import {
+  Field,
+  Float32,
+  List,
+  makeBuilder,
+  RecordBatchFileWriter,
+  Table,
+  type Vector,
+  vectorFromArray
+} from 'apache-arrow'
+
+export function convertToTable (data: Array<Record<string, unknown>>): Table {
+  if (data.length === 0) {
+    throw new Error('At least one record needs to be provided')
+  }
+
+  const columns = Object.keys(data[0])
+  const records: Record<string, Vector> = {}
+
+  for (const columnsKey of columns) {
+    if (columnsKey === 'vector') {
+      const children = new Field('item', new Float32())
+      const list = new List(children)
+      const listBuilder = makeBuilder({
+        type: list
+      })
+      const vectorSize = (data[0].vector as any[]).length
+      for (const datum of data) {
+        if ((datum[columnsKey] as any[]).length !== vectorSize) {
+          throw new Error(`Invalid vector size, expected ${vectorSize}`)
+        }
+
+        listBuilder.append(datum[columnsKey])
+      }
+      records[columnsKey] = listBuilder.finish().toVector()
+    } else {
+      const values = []
+      for (const datum of data) {
+        values.push(datum[columnsKey])
+      }
+      records[columnsKey] = vectorFromArray(values)
+    }
+  }
+
+  return new Table(records)
+}
+
+export async function fromRecordsToBuffer (data: Array<Record<string, unknown>>): Promise<Buffer> {
+  const table = convertToTable(data)
+  const writer = RecordBatchFileWriter.writeAll(table)
+  return Buffer.from(await writer.toUint8Array())
+}
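The conversion logic that used to live inline in `Connection.createTable` now lives in `node/src/arrow.ts`, where both `createTable` and the new `add`/`overwrite` paths can reuse it. A minimal sketch of what the helper produces; the `tableFromIPC` round trip is illustrative only (it is the same `apache-arrow` function `index.ts` already imports), and the records are made up:

```ts
import { tableFromIPC } from 'apache-arrow'
import { fromRecordsToBuffer } from './arrow'

async function demo (): Promise<void> {
  // The 'vector' column is built as List<Float32>; all vectors must share a length
  const buffer = await fromRecordsToBuffer([
    { id: 1, vector: [0.1, 0.2], price: 10 },
    { id: 2, vector: [1.1, 1.2], price: 50 }
  ])

  // The Buffer holds the Arrow IPC file format, so it reads straight back
  const table = tableFromIPC(buffer)
  console.log(table.numRows) // 2
  console.log(table.schema.fields.map(f => f.name)) // [ 'id', 'vector', 'price' ]
}

demo().catch(console.error)
```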
diff --git a/node/src/index.ts b/node/src/index.ts
index 55d65574..07b96c23 100644
--- a/node/src/index.ts
+++ b/node/src/index.ts
@@ -13,19 +13,15 @@
 // limitations under the License.
 
 import {
-  Field,
-  Float32,
-  List,
-  makeBuilder,
   RecordBatchFileWriter,
-  Table as ArrowTable,
+  type Table as ArrowTable,
   tableFromIPC,
-  Vector,
-  vectorFromArray
+  Vector
 } from 'apache-arrow'
+import { fromRecordsToBuffer } from './arrow'
 
 // eslint-disable-next-line @typescript-eslint/no-var-requires
-const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch } = require('../native.js')
+const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd } = require('../native.js')
 
 /**
  * Connect to a LanceDB instance at the given URI
@@ -68,40 +64,7 @@ export class Connection {
   }
 
   async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table> {
-    if (data.length === 0) {
-      throw new Error('At least one record needs to be provided')
-    }
-
-    const columns = Object.keys(data[0])
-    const records: Record<string, Vector> = {}
-
-    for (const columnsKey of columns) {
-      if (columnsKey === 'vector') {
-        const children = new Field('item', new Float32())
-        const list = new List(children)
-        const listBuilder = makeBuilder({
-          type: list
-        })
-        const vectorSize = (data[0].vector as any[]).length
-        for (const datum of data) {
-          if ((datum[columnsKey] as any[]).length !== vectorSize) {
-            throw new Error(`Invalid vector size, expected ${vectorSize}`)
-          }
-
-          listBuilder.append(datum[columnsKey])
-        }
-        records[columnsKey] = listBuilder.finish().toVector()
-      } else {
-        const values = []
-        for (const datum of data) {
-          values.push(datum[columnsKey])
-        }
-        records[columnsKey] = vectorFromArray(values)
-      }
-    }
-
-    const table = new ArrowTable(records)
-    await this.createTableArrow(name, table)
+    await tableCreate.call(this._db, name, await fromRecordsToBuffer(data))
     return await this.openTable(name)
   }
 
@@ -135,6 +98,21 @@ export class Table {
   search (queryVector: number[]): Query {
     return new Query(this._tbl, queryVector)
   }
+
+  /**
+   * Insert records into this Table
+   *
+   * @param data Records to be inserted into the Table
+   * @return The number of rows added to the table
+   */
+  async add (data: Array<Record<string, unknown>>): Promise<number> {
+    return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString())
+  }
+
+  /** Overwrite all records in this Table with the given records */
+  async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
+    return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString())
+  }
 }
 
 /**
@@ -194,3 +172,8 @@ export class Query {
     })
   }
 }
+
+export enum WriteMode {
+  Overwrite = 'overwrite',
+  Append = 'append'
+}
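With `fromRecordsToBuffer` in place, `createTable`, `add`, and `overwrite` all reduce to one IPC buffer handed to the native module. A usage sketch of the new `Table` methods; the `vectordb` import name is an assumption (the tests below import the local build directly), the rest is the API from this diff:

```ts
import * as lancedb from 'vectordb' // assumed package name

async function main (): Promise<void> {
  const con = await lancedb.connect('/tmp/lancedb-demo')
  const table = await con.createTable('vectors', [
    { id: 1, vector: [0.1, 0.2], price: 10 }
  ])

  // add() appends; per the doc comment it resolves to the number of rows added
  const added = await table.add([
    { id: 2, vector: [1.1, 1.2], price: 50 }
  ])
  console.log(added)

  // overwrite() uses WriteMode.Overwrite and replaces the table contents
  await table.overwrite([
    { id: 3, vector: [2.1, 2.2], price: 30 }
  ])
}

main().catch(console.error)
```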
diff --git a/node/src/test/test.ts b/node/src/test/test.ts
index c7e61e5d..9ab570e7 100644
--- a/node/src/test/test.ts
+++ b/node/src/test/test.ts
@@ -90,6 +90,45 @@ describe('LanceDB client', function () {
       const results = await table.search([0.1, 0.3]).execute()
       assert.equal(results.length, 2)
     })
+
+    it('appends records to an existing table', async function () {
+      const dir = await track().mkdir('lancejs')
+      const con = await lancedb.connect(dir)
+
+      const data = [
+        { id: 1, vector: [0.1, 0.2], price: 10 },
+        { id: 2, vector: [1.1, 1.2], price: 50 }
+      ]
+
+      const table = await con.createTable('vectors', data)
+      const results = await table.search([0.1, 0.3]).execute()
+      assert.equal(results.length, 2)
+
+      const dataAdd = [
+        { id: 3, vector: [2.1, 2.2], price: 10 },
+        { id: 4, vector: [3.1, 3.2], price: 50 }
+      ]
+      await table.add(dataAdd)
+      const resultsAdd = await table.search([0.1, 0.3]).execute()
+      assert.equal(resultsAdd.length, 4)
+    })
+
+    it('overwrites all records in a table', async function () {
+      const uri = await createTestDB()
+      const con = await lancedb.connect(uri)
+
+      const table = await con.openTable('vectors')
+      const results = await table.search([0.1, 0.3]).execute()
+      assert.equal(results.length, 2)
+
+      const dataOver = [
+        { vector: [2.1, 2.2], price: 10, name: 'foo' },
+        { vector: [3.1, 3.2], price: 50, name: 'bar' }
+      ]
+      await table.overwrite(dataOver)
+      const resultsAdd = await table.search([0.1, 0.3]).execute()
+      assert.equal(resultsAdd.length, 2)
+    })
   })
 })
diff --git a/rust/ffi/node/src/arrow.rs b/rust/ffi/node/src/arrow.rs
index 599a354f..f16ea60c 100644
--- a/rust/ffi/node/src/arrow.rs
+++ b/rust/ffi/node/src/arrow.rs
@@ -12,11 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::io::Cursor;
 use std::ops::Deref;
 use std::sync::Arc;
 
 use arrow_array::cast::as_list_array;
 use arrow_array::{Array, FixedSizeListArray, RecordBatch};
+use arrow_ipc::reader::FileReader;
 use arrow_schema::{DataType, Field, Schema};
 
 use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
@@ -45,3 +47,14 @@ pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
     }
     new_batch
 }
+
+pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Vec<RecordBatch> {
+    let mut batches: Vec<RecordBatch> = Vec::new();
+    let fr = FileReader::try_new(Cursor::new(slice), None);
+    let file_reader = fr.unwrap();
+    for b in file_reader {
+        let record_batch = convert_record_batch(b.unwrap());
+        batches.push(record_batch);
+    }
+    batches
+}
diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs
index 2fbeeb1f..d4fc814d 100644
--- a/rust/ffi/node/src/lib.rs
+++ b/rust/ffi/node/src/lib.rs
@@ -12,15 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::io::Cursor;
+use std::collections::HashMap;
 use std::ops::Deref;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
-use arrow_array::{Float32Array, RecordBatch, RecordBatchReader};
-use arrow_ipc::reader::FileReader;
+use arrow_array::{Float32Array, RecordBatchReader};
 use arrow_ipc::writer::FileWriter;
 use futures::{TryFutureExt, TryStreamExt};
 use lance::arrow::RecordBatchBuffer;
+use lance::dataset::WriteMode;
 use neon::prelude::*;
 use neon::types::buffer::TypedArray;
 use once_cell::sync::OnceCell;
@@ -30,7 +30,7 @@ use vectordb::database::Database;
 use vectordb::error::Error;
 use vectordb::table::Table;
 
-use crate::arrow::convert_record_batch;
+use crate::arrow::arrow_buffer_to_record_batch;
 
 mod arrow;
 mod convert;
@@ -40,7 +40,7 @@ struct JsDatabase {
 }
 
 struct JsTable {
-    table: Arc<Table>,
+    table: Arc<Mutex<Table>>,
 }
 
 impl Finalize for JsDatabase {}
@@ -87,7 +87,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
         let table_rst = database.open_table(table_name).await;
 
         deferred.settle_with(&channel, move |mut cx| {
-            let table = Arc::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?);
+            let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
             Ok(cx.boxed(JsTable { table }))
         });
     });
@@ -109,6 +109,8 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
 
     rt.spawn(async move {
         let builder = table
+            .lock()
+            .unwrap()
             .search(Float32Array::from(query))
             .limit(limit as usize)
             .filter(filter);
@@ -149,15 +151,7 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
         .downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
     let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
     let buffer = cx.argument::<JsBuffer>(1)?;
-    let slice = buffer.as_slice(&mut cx);
-
-    let mut batches: Vec<RecordBatch> = Vec::new();
-    let fr = FileReader::try_new(Cursor::new(slice), None);
-    let file_reader = fr.unwrap();
-    for b in file_reader {
-        let record_batch = convert_record_batch(b.unwrap());
-        batches.push(record_batch);
-    }
+    let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
 
     let rt = runtime(&mut cx)?;
     let channel = cx.channel();
@@ -170,13 +164,47 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
         let table_rst = database.create_table(table_name, batch_reader).await;
 
         deferred.settle_with(&channel, move |mut cx| {
-            let table = Arc::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?);
+            let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
             Ok(cx.boxed(JsTable { table }))
         });
     });
     Ok(promise)
 }
 
+fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
+    let write_mode_map: HashMap<&str, WriteMode> = HashMap::from([
+        ("create", WriteMode::Create),
+        ("append", WriteMode::Append),
+        ("overwrite", WriteMode::Overwrite),
+    ]);
+
+    let js_table = cx
+        .this()
+        .downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+    let buffer = cx.argument::<JsBuffer>(0)?;
+    let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
+    let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
+
+    let rt = runtime(&mut cx)?;
+    let channel = cx.channel();
+
+    let (deferred, promise) = cx.promise();
+    let table = js_table.table.clone();
+    let write_mode = write_mode_map.get(write_mode.as_str()).cloned();
+
+    rt.block_on(async move {
+        let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(batches));
+        let add_result = table.lock().unwrap().add(batch_reader, write_mode).await;
+
+        deferred.settle_with(&channel, move |mut cx| {
+            let added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
+            Ok(cx.number(added as f64))
+        });
+    });
+    Ok(promise)
+}
+
 #[neon::main]
 fn main(mut cx: ModuleContext) -> NeonResult<()> {
     cx.export_function("databaseNew", database_new)?;
@@ -184,5 +212,6 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
     cx.export_function("databaseOpenTable", database_open_table)?;
     cx.export_function("tableSearch", table_search)?;
     cx.export_function("tableCreate", table_create)?;
+    cx.export_function("tableAdd", table_add)?;
     Ok(())
 }
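Across the FFI boundary the write mode travels as a plain string; `table_add` looks it up in `write_mode_map`, and `Table::add` falls back to `WriteMode::Append` when the lookup misses. Seen from TypeScript, the contract of the native binding is small enough to state directly; a hedged sketch of calling it by hand (normally `Table.add`/`Table.overwrite` do this for you, and `rawAdd` is a hypothetical helper):

```ts
import { fromRecordsToBuffer } from './arrow'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { tableAdd } = require('../native.js')

// mode is one of 'create' | 'append' | 'overwrite' (the lowercase WriteMode
// values); any other string maps to None in Rust and behaves like 'append'.
async function rawAdd (tbl: any, data: Array<Record<string, unknown>>, mode: string): Promise<number> {
  return tableAdd.call(tbl, await fromRecordsToBuffer(data), mode)
}
```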
diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/database.rs
index 4dd238fa..de65a991 100644
--- a/rust/vectordb/src/database.rs
+++ b/rust/vectordb/src/database.rs
@@ -91,7 +91,7 @@ impl Database {
     ///
     /// * A [Table] object.
     pub async fn open_table(&self, name: String) -> Result<Table> {
-        Table::new(self.path.clone(), name).await
+        Table::open(self.path.clone(), name).await
     }
 }
diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs
index a8a9d3e6..bab07d77 100644
--- a/rust/vectordb/src/table.rs
+++ b/rust/vectordb/src/table.rs
@@ -16,7 +16,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 
 use arrow_array::{Float32Array, RecordBatchReader};
-use lance::dataset::{Dataset, WriteParams};
+use lance::dataset::{Dataset, WriteMode, WriteParams};
 
 use crate::error::{Error, Result};
 use crate::query::Query;
@@ -28,11 +28,12 @@ pub const LANCE_FILE_EXTENSION: &str = "lance";
 /// A table in a LanceDB database.
 pub struct Table {
     name: String,
+    path: String,
     dataset: Arc<Dataset>,
 }
 
 impl Table {
-    /// Creates a new Table object
+    /// Opens an existing Table
     ///
     /// # Arguments
     ///
@@ -42,7 +43,7 @@ impl Table {
     /// # Returns
     ///
     /// * A [Table] object.
-    pub async fn new(base_path: Arc<PathBuf>, name: String) -> Result<Self> {
+    pub async fn open(base_path: Arc<PathBuf>, name: String) -> Result<Self> {
         let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
         let ds_uri = ds_path
             .to_str()
@@ -50,24 +51,57 @@ impl Table {
         let dataset = Dataset::open(ds_uri).await?;
         let table = Table {
             name,
+            path: ds_uri.to_string(),
             dataset: Arc::new(dataset),
         };
         Ok(table)
     }
 
+    /// Creates a new Table
+    ///
+    /// # Arguments
+    ///
+    /// * `base_path` - The base path where the table is located
+    /// * `name` - The Table name
+    /// * `batches` - The [RecordBatch]es to be saved in the database
+    ///
+    /// # Returns
+    ///
+    /// * A [Table] object.
     pub async fn create(
         base_path: Arc<PathBuf>,
         name: String,
         mut batches: Box<dyn RecordBatchReader>,
     ) -> Result<Self> {
         let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
-        let ds_uri = ds_path
+        let path = ds_path
             .to_str()
             .ok_or(Error::IO(format!("Unable to find table {}", name)))?;
         let dataset =
-            Arc::new(Dataset::write(&mut batches, ds_uri, Some(WriteParams::default())).await?);
-        Ok(Table { name, dataset })
+            Arc::new(Dataset::write(&mut batches, path, Some(WriteParams::default())).await?);
+        Ok(Table { name, path: path.to_string(), dataset })
+    }
+
+    /// Insert records into this Table
+    ///
+    /// # Arguments
+    ///
+    /// * `batches` - The [RecordBatch]es to be saved in the Table
+    /// * `write_mode` - Append to or overwrite existing records. Default: Append
+    ///
+    /// # Returns
+    ///
+    /// * The number of rows added
+    pub async fn add(
+        &mut self,
+        mut batches: Box<dyn RecordBatchReader>,
+        write_mode: Option<WriteMode>,
+    ) -> Result<usize> {
+        let mut params = WriteParams::default();
+        params.mode = write_mode.unwrap_or(WriteMode::Append);
+
+        self.dataset = Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
+        Ok(batches.count())
     }
 
     /// Creates a new Query object that can be executed.
@@ -82,6 +116,11 @@ impl Table {
     pub fn search(&self, query_vector: Float32Array) -> Query {
         Query::new(self.dataset.clone(), query_vector)
     }
+
+    /// Returns the number of rows in this Table
+    pub async fn count_rows(&self) -> Result<usize> {
+        Ok(self.dataset.count_rows().await?)
+    }
 }
 
 #[cfg(test)]
@@ -89,7 +128,7 @@ mod tests {
     use arrow_array::{Float32Array, Int32Array, RecordBatch, RecordBatchReader};
     use arrow_schema::{DataType, Field, Schema};
     use lance::arrow::RecordBatchBuffer;
-    use lance::dataset::Dataset;
+    use lance::dataset::{Dataset, WriteMode};
     use std::sync::Arc;
     use tempfile::tempdir;
 
@@ -100,12 +139,12 @@ mod tests {
         let tmp_dir = tempdir().unwrap();
         let path_buf = tmp_dir.into_path();
 
-        let table = Table::new(Arc::new(path_buf), "test".to_string()).await;
+        let table = Table::open(Arc::new(path_buf), "test".to_string()).await;
         assert!(table.is_err());
     }
 
     #[tokio::test]
-    async fn test_new() {
+    async fn test_open() {
         let tmp_dir = tempdir().unwrap();
         let path_buf = tmp_dir.into_path();
 
@@ -118,13 +157,54 @@ mod tests {
             .await
             .unwrap();
 
-        let table = Table::new(Arc::new(path_buf), "test".to_string())
+        let table = Table::open(Arc::new(path_buf), "test".to_string())
             .await
             .unwrap();
 
         assert_eq!(table.name, "test")
     }
 
+    #[tokio::test]
+    async fn test_add() {
+        let tmp_dir = tempdir().unwrap();
+        let path_buf = tmp_dir.into_path();
+
+        let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
+        let schema = batches.schema().clone();
+        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
+        assert_eq!(table.count_rows().await.unwrap(), 10);
+
+        let new_batches: Box<dyn RecordBatchReader> =
+            Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
+                schema,
+                vec![Arc::new(Int32Array::from_iter_values(100..110))],
+            )
+            .unwrap()]));
+
+        table.add(new_batches, None).await.unwrap();
+        assert_eq!(table.count_rows().await.unwrap(), 20);
+        assert_eq!(table.name, "test");
+    }
+
+    #[tokio::test]
+    async fn test_add_overwrite() {
+        let tmp_dir = tempdir().unwrap();
+        let path_buf = tmp_dir.into_path();
+
+        let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
+        let schema = batches.schema().clone();
+        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
+        assert_eq!(table.count_rows().await.unwrap(), 10);
+
+        let new_batches: Box<dyn RecordBatchReader> =
+            Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
+                schema,
+                vec![Arc::new(Int32Array::from_iter_values(100..110))],
+            ).unwrap()]));
+
+        table.add(new_batches, Some(WriteMode::Overwrite)).await.unwrap();
+        assert_eq!(table.count_rows().await.unwrap(), 10);
+        assert_eq!(table.name, "test");
+    }
+
     #[tokio::test]
     async fn test_search() {
         let tmp_dir = tempdir().unwrap();
@@ -139,7 +219,7 @@ mod tests {
             .await
             .unwrap();
 
-        let table = Table::new(Arc::new(path_buf), "test".to_string())
+        let table = Table::open(Arc::new(path_buf), "test".to_string())
             .await
             .unwrap();
 
@@ -152,7 +232,7 @@ mod tests {
         let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
         RecordBatchBuffer::new(vec![RecordBatch::try_new(
             schema.clone(),
-            vec![Arc::new(Int32Array::from_iter_values(0..20))],
+            vec![Arc::new(Int32Array::from_iter_values(0..10))],
         )
         .unwrap()])
     }
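End to end, the semantics the Rust tests assert (`test_add` grows the row count, `test_add_overwrite` resets it) are observable from JavaScript as well. A compact sketch under the same assumptions as the earlier examples (assumed `vectordb` package name and a throwaway path); `search` stands in for a row count since the tables stay under the query limit:

```ts
import * as lancedb from 'vectordb' // assumed package name, as above

async function writeModes (): Promise<void> {
  const con = await lancedb.connect('/tmp/lancedb-writemodes')
  const table = await con.createTable('vectors', [
    { id: 1, vector: [0.1, 0.2] },
    { id: 2, vector: [1.1, 1.2] }
  ])

  await table.add([{ id: 3, vector: [2.1, 2.2] }])
  console.log((await table.search([0.1, 0.3]).execute()).length) // 3: appended

  await table.overwrite([{ id: 9, vector: [9.1, 9.2] }])
  console.log((await table.search([0.1, 0.3]).execute()).length) // 1: replaced
}

writeModes().catch(console.error)
```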