From fc725c99f0629c61c6d5c68a770778085ac20dd8 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 3 Jul 2023 17:04:21 -0700 Subject: [PATCH] [Node] Create Table with WriteMode (#246) Support `createTable(name, data, mode?)` to be consistent with Python. Closes #242 --- node/src/index.ts | 34 +++++++++++++++++++++++++--------- node/src/test/test.ts | 31 ++++++++++++++++++++++++++++--- python/lancedb/db.py | 5 ++++- rust/ffi/node/src/lib.rs | 14 ++++++++++++-- rust/vectordb/src/database.rs | 10 +++++++++- rust/vectordb/src/table.rs | 15 ++++++++------- 6 files changed, 86 insertions(+), 23 deletions(-) diff --git a/node/src/index.ts b/node/src/index.ts index 375c68d3..dbffad39 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -37,7 +37,7 @@ export async function connect (uri: string): Promise { } /** - * A LanceDB connection that allows you to open tables and create new ones. + * A LanceDB Connection that allows you to open tables and create new ones. * * Connection could be local against filesystem or remote against a server. */ @@ -56,11 +56,14 @@ export interface Connection { /** * Creates a new Table and initialize it with new data. * - * @param name The name of the table. - * @param data Non-empty Array of Records to be inserted into the Table + * @param {string} name - The name of the table. + * @param data - Non-empty Array of Records to be inserted into the Table */ - createTable: ((name: string, data: Array>) => Promise) & ((name: string, data: Array>, embeddings: EmbeddingFunction) => Promise>) & ((name: string, data: Array>, embeddings?: EmbeddingFunction) => Promise>) + createTable: ((name: string, data: Array>, mode?: WriteMode) => Promise
) + & ((name: string, data: Array>, mode: WriteMode) => Promise
) + & ((name: string, data: Array>, mode: WriteMode, embeddings: EmbeddingFunction) => Promise>) + & ((name: string, data: Array>, mode: WriteMode, embeddings?: EmbeddingFunction) => Promise>) createTableArrow: (name: string, table: ArrowTable) => Promise
@@ -72,7 +75,7 @@ export interface Connection { } /** - * A LanceDB table that allows you to search and update a table. + * A LanceDB Table is the collection of Records. Each Record has one or more vector fields. */ export interface Table { name: string @@ -169,19 +172,25 @@ export class LocalConnection implements Connection { * * @param name The name of the table. * @param data Non-empty Array of Records to be inserted into the Table + * @param mode The write mode to use when creating the table. */ + async createTable (name: string, data: Array>, mode?: WriteMode): Promise
+ async createTable (name: string, data: Array>, mode: WriteMode): Promise
- async createTable (name: string, data: Array>): Promise
/** * Creates a new Table and initialize it with new data. * * @param name The name of the table. * @param data Non-empty Array of Records to be inserted into the Table + * @param mode The write mode to use when creating the table. * @param embeddings An embedding function to use on this Table */ - async createTable (name: string, data: Array>, embeddings: EmbeddingFunction): Promise> - async createTable (name: string, data: Array>, embeddings?: EmbeddingFunction): Promise> { - const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings)) + async createTable (name: string, data: Array>, mode: WriteMode, embeddings: EmbeddingFunction): Promise> + async createTable (name: string, data: Array>, mode: WriteMode, embeddings?: EmbeddingFunction): Promise> { + if (mode === undefined) { + mode = WriteMode.Create + } + const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()) if (embeddings !== undefined) { return new LocalTable(tbl, name, embeddings) } else { @@ -445,8 +454,15 @@ export class Query { } } +/** + * Write mode for writing a table. + */ export enum WriteMode { + /** Create a new {@link Table}. */ + Create = 'create', + /** Overwrite the existing {@link Table} if presented. */ Overwrite = 'overwrite', + /** Append new data to the table. */ Append = 'append' } diff --git a/node/src/test/test.ts b/node/src/test/test.ts index b10332b5..882055cf 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -1,4 +1,4 @@ -// Copyright 2023 Lance Developers. +// Copyright 2023 LanceDB Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ import * as chai from 'chai' import * as chaiAsPromised from 'chai-as-promised' import * as lancedb from '../index' -import { type EmbeddingFunction, MetricType, Query } from '../index' +import { type EmbeddingFunction, MetricType, Query, WriteMode } from '../index' const expect = chai.expect const assert = chai.assert @@ -118,6 +118,31 @@ describe('LanceDB client', function () { assert.equal(await table.countRows(), 2) }) + it('use overwrite flag to overwrite existing table', async function () { + const dir = await track().mkdir('lancejs') + const con = await lancedb.connect(dir) + + const data = [ + { id: 1, vector: [0.1, 0.2], price: 10 }, + { id: 2, vector: [1.1, 1.2], price: 50 } + ] + + const tableName = 'overwrite' + await con.createTable(tableName, data, WriteMode.Create) + + const newData = [ + { id: 1, vector: [0.1, 0.2], price: 10 }, + { id: 2, vector: [1.1, 1.2], price: 50 }, + { id: 3, vector: [1.1, 1.2], price: 50 } + ] + + await expect(con.createTable(tableName, newData)).to.be.rejectedWith(Error, 'already exists') + + const table = await con.createTable(tableName, newData, WriteMode.Overwrite) + assert.equal(table.name, tableName) + assert.equal(await table.countRows(), 3) + }) + it('appends records to an existing table ', async function () { const dir = await track().mkdir('lancejs') const con = await lancedb.connect(dir) @@ -218,7 +243,7 @@ describe('LanceDB client', function () { { price: 10, name: 'foo' }, { price: 50, name: 'bar' } ] - const table = await con.createTable('vectors', data, embeddings) + const table = await con.createTable('vectors', data, WriteMode.Create, embeddings) const results = await table.search('foo').execute() assert.equal(results.length, 2) }) diff --git a/python/lancedb/db.py b/python/lancedb/db.py index ded65562..a9a43eb8 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -170,7 +170,7 @@ class LanceDBConnection: schema: pyarrow.Schema; optional The schema of the table. mode: str; default "create" - The mode to use when creating the table. + The mode to use when creating the table. Can be either "create" or "overwrite". By default, if the table already exists, an exception is raised. If you want to overwrite the table, use mode="overwrite". @@ -249,6 +249,9 @@ class LanceDBConnection: lat: [[45.5,40.1]] long: [[-122.7,-74.1]] """ + if mode.lower() not in ["create", "overwrite"]: + raise ValueError("mode must be either 'create' or 'overwrite'") + if data is not None: tbl = LanceTable.create(self, name, data, schema, mode=mode) else: diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index a1f68b54..a325e6de 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -21,7 +21,7 @@ use arrow_array::{Float32Array, RecordBatchReader}; use arrow_ipc::writer::FileWriter; use futures::{TryFutureExt, TryStreamExt}; use lance::arrow::RecordBatchBuffer; -use lance::dataset::WriteMode; +use lance::dataset::{WriteMode, WriteParams}; use lance::index::vector::MetricType; use neon::prelude::*; use neon::types::buffer::TypedArray; @@ -234,6 +234,16 @@ fn table_create(mut cx: FunctionContext) -> JsResult { let buffer = cx.argument::(1)?; let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)); + // Write mode + let mode = match cx.argument::(2)?.value(&mut cx).as_str() { + "overwrite" => WriteMode::Overwrite, + "append" => WriteMode::Append, + "create" => WriteMode::Create, + _ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes") + }; + let mut params = WriteParams::default(); + params.mode = mode; + let rt = runtime(&mut cx)?; let channel = cx.channel(); @@ -242,7 +252,7 @@ fn table_create(mut cx: FunctionContext) -> JsResult { rt.block_on(async move { let batch_reader: Box = Box::new(RecordBatchBuffer::new(batches)); - let table_rst = database.create_table(&table_name, batch_reader).await; + let table_rst = database.create_table(&table_name, batch_reader, Some(params)).await; deferred.settle_with(&channel, move |mut cx| { let table = Arc::new(Mutex::new( diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/database.rs index 90713ed0..c759d9ab 100644 --- a/rust/vectordb/src/database.rs +++ b/rust/vectordb/src/database.rs @@ -16,6 +16,7 @@ use std::fs::create_dir_all; use std::path::Path; use arrow_array::RecordBatchReader; +use lance::dataset::WriteParams; use lance::io::object_store::ObjectStore; use snafu::prelude::*; @@ -90,12 +91,19 @@ impl Database { Ok(f) } + /// Create a new table in the database. + /// + /// # Arguments + /// * `name` - The name of the table. + /// * `batches` - The initial data to write to the table. + /// * `params` - Optional [`WriteParams`] to create the table. pub async fn create_table( &self, name: &str, batches: Box, + params: Option, ) -> Result
{ - Table::create(&self.uri, name, batches).await + Table::create(&self.uri, name, batches, params).await } /// Open a table in the database. diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs index bc8e9f83..a6c475a9 100644 --- a/rust/vectordb/src/table.rs +++ b/rust/vectordb/src/table.rs @@ -1,4 +1,4 @@ -// Copyright 2023 Lance Developers. +// Copyright 2023 LanceDB Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -117,6 +117,7 @@ impl Table { base_uri: &str, name: &str, mut batches: Box, + params: Option, ) -> Result { let base_path = Path::new(base_uri); let table_uri = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); @@ -125,7 +126,7 @@ impl Table { .to_str() .context(InvalidTableNameSnafu { name })? .to_string(); - let dataset = Dataset::write(&mut batches, &uri, Some(WriteParams::default())) + let dataset = Dataset::write(&mut batches, &uri, params) .await .map_err(|e| match e { lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists { @@ -284,10 +285,10 @@ mod tests { let batches: Box = Box::new(make_test_batches()); let _ = batches.schema().clone(); - Table::create(&uri, "test", batches).await.unwrap(); + Table::create(&uri, "test", batches, None).await.unwrap(); let batches: Box = Box::new(make_test_batches()); - let result = Table::create(&uri, "test", batches).await; + let result = Table::create(&uri, "test", batches, None).await; assert!(matches!( result.unwrap_err(), Error::TableAlreadyExists { .. } @@ -301,7 +302,7 @@ mod tests { let batches: Box = Box::new(make_test_batches()); let schema = batches.schema().clone(); - let mut table = Table::create(&uri, "test", batches).await.unwrap(); + let mut table = Table::create(&uri, "test", batches, None).await.unwrap(); assert_eq!(table.count_rows().await.unwrap(), 10); let new_batches: Box = @@ -323,7 +324,7 @@ mod tests { let batches: Box = Box::new(make_test_batches()); let schema = batches.schema().clone(); - let mut table = Table::create(uri, "test", batches).await.unwrap(); + let mut table = Table::create(uri, "test", batches, None).await.unwrap(); assert_eq!(table.count_rows().await.unwrap(), 10); let new_batches: Box = @@ -453,7 +454,7 @@ mod tests { .unwrap()]); let reader: Box = Box::new(batches); - let mut table = Table::create(uri, "test", reader).await.unwrap(); + let mut table = Table::create(uri, "test", reader, None).await.unwrap(); let mut i = IvfPQIndexBuilder::new();