[Node] Create Table with WriteMode (#246)

Support `createTable(name, data, mode?)` in the Node client, to be consistent with the Python API.

Closes #242
This commit is contained in:
Lei Xu
2023-07-03 17:04:21 -07:00
committed by GitHub
parent a6bdffd75b
commit fc725c99f0
6 changed files with 86 additions and 23 deletions

View File

@@ -37,7 +37,7 @@ export async function connect (uri: string): Promise<Connection> {
}
/**
* A LanceDB connection that allows you to open tables and create new ones.
* A LanceDB Connection that allows you to open tables and create new ones.
*
* Connection could be local against filesystem or remote against a server.
*/
@@ -56,11 +56,14 @@ export interface Connection {
/**
* Creates a new Table and initialize it with new data.
*
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
* @param {string} name - The name of the table.
* @param data - Non-empty Array of Records to be inserted into the Table
*/
createTable: ((name: string, data: Array<Record<string, unknown>>) => Promise<Table>) & (<T>(name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>) => Promise<Table<T>>) & (<T>(name: string, data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>) => Promise<Table<T>>)
createTable: ((name: string, data: Array<Record<string, unknown>>, mode?: WriteMode) => Promise<Table>)
& ((name: string, data: Array<Record<string, unknown>>, mode: WriteMode) => Promise<Table>)
& (<T>(name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>) => Promise<Table<T>>)
& (<T>(name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>) => Promise<Table<T>>)
createTableArrow: (name: string, table: ArrowTable) => Promise<Table>
@@ -72,7 +75,7 @@ export interface Connection {
}
/**
* A LanceDB table that allows you to search and update a table.
* A LanceDB Table is the collection of Records. Each Record has one or more vector fields.
*/
export interface Table<T = number[]> {
name: string
@@ -169,19 +172,25 @@ export class LocalConnection implements Connection {
*
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
* @param mode The write mode to use when creating the table.
*/
async createTable (name: string, data: Array<Record<string, unknown>>, mode?: WriteMode): Promise<Table>
async createTable (name: string, data: Array<Record<string, unknown>>, mode: WriteMode): Promise<Table>
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table>
/**
* Creates a new Table and initialize it with new data.
*
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
* @param mode The write mode to use when creating the table.
* @param embeddings An embedding function to use on this Table
*/
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings))
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
if (mode === undefined) {
mode = WriteMode.Create
}
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase())
if (embeddings !== undefined) {
return new LocalTable(tbl, name, embeddings)
} else {
@@ -445,8 +454,15 @@ export class Query<T = number[]> {
}
}
/**
* Write mode for writing a table.
*
* Passed as the optional `mode` argument of `createTable`; when it is omitted,
* `createTable` falls back to {@link WriteMode.Create}. The string value of each
* member is lowercased and handed to the native binding.
*/
export enum WriteMode {
/**
* Create a new {@link Table}. Creation is rejected with an "already exists"
* error when a table with the same name is already there.
*/
Create = 'create',
/** Overwrite the existing {@link Table} if one is present. */
Overwrite = 'overwrite',
/**
* Append new data to the table.
* NOTE(review): the Python `create_table` only accepts "create" and "overwrite" -
* confirm that append is intended to be usable through `createTable`.
*/
Append = 'append'
}

View File

@@ -1,4 +1,4 @@
// Copyright 2023 Lance Developers.
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ import * as chai from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import * as lancedb from '../index'
import { type EmbeddingFunction, MetricType, Query } from '../index'
import { type EmbeddingFunction, MetricType, Query, WriteMode } from '../index'
const expect = chai.expect
const assert = chai.assert
@@ -118,6 +118,31 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
})
// Checks that re-creating an existing table fails by default and succeeds with WriteMode.Overwrite.
it('use overwrite flag to overwrite existing table', async function () {
  const tableName = 'overwrite'
  const dir = await track().mkdir('lancejs')
  const con = await lancedb.connect(dir)

  // Seed the table with two rows, asking explicitly for Create mode.
  const initialRows = [
    { id: 1, vector: [0.1, 0.2], price: 10 },
    { id: 2, vector: [1.1, 1.2], price: 50 }
  ]
  await con.createTable(tableName, initialRows, WriteMode.Create)

  const replacementRows = [
    { id: 1, vector: [0.1, 0.2], price: 10 },
    { id: 2, vector: [1.1, 1.2], price: 50 },
    { id: 3, vector: [1.1, 1.2], price: 50 }
  ]

  // Without a mode, creating a table whose name is taken must be rejected.
  const duplicateCreate = con.createTable(tableName, replacementRows)
  await expect(duplicateCreate).to.be.rejectedWith(Error, 'already exists')

  // With Overwrite, the same call succeeds and the table holds the new rows.
  const overwritten = await con.createTable(tableName, replacementRows, WriteMode.Overwrite)
  assert.equal(overwritten.name, tableName)
  const rowCount = await overwritten.countRows()
  assert.equal(rowCount, 3)
})
it('appends records to an existing table ', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
@@ -218,7 +243,7 @@ describe('LanceDB client', function () {
{ price: 10, name: 'foo' },
{ price: 50, name: 'bar' }
]
const table = await con.createTable('vectors', data, embeddings)
const table = await con.createTable('vectors', data, WriteMode.Create, embeddings)
const results = await table.search('foo').execute()
assert.equal(results.length, 2)
})

View File

@@ -170,7 +170,7 @@ class LanceDBConnection:
schema: pyarrow.Schema; optional
The schema of the table.
mode: str; default "create"
The mode to use when creating the table.
The mode to use when creating the table. Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
@@ -249,6 +249,9 @@ class LanceDBConnection:
lat: [[45.5,40.1]]
long: [[-122.7,-74.1]]
"""
if mode.lower() not in ["create", "overwrite"]:
raise ValueError("mode must be either 'create' or 'overwrite'")
if data is not None:
tbl = LanceTable.create(self, name, data, schema, mode=mode)
else:

View File

@@ -21,7 +21,7 @@ use arrow_array::{Float32Array, RecordBatchReader};
use arrow_ipc::writer::FileWriter;
use futures::{TryFutureExt, TryStreamExt};
use lance::arrow::RecordBatchBuffer;
use lance::dataset::WriteMode;
use lance::dataset::{WriteMode, WriteParams};
use lance::index::vector::MetricType;
use neon::prelude::*;
use neon::types::buffer::TypedArray;
@@ -234,6 +234,16 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
let buffer = cx.argument::<JsBuffer>(1)?;
let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
// Write mode: the third JS argument is the lowercase mode string sent by the
// Node binding; translate it into the corresponding Lance `WriteMode`.
let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() {
    "overwrite" => WriteMode::Overwrite,
    "append" => WriteMode::Append,
    "create" => WriteMode::Create,
    // Any other string is rejected. The message lists every mode accepted above
    // so it stays truthful about what this function supports.
    _ => return cx.throw_error("Table::create only supports 'overwrite', 'append' and 'create' modes")
};
// Start from the default write parameters and override only the mode.
let mut params = WriteParams::default();
params.mode = mode;
let rt = runtime(&mut cx)?;
let channel = cx.channel();
@@ -242,7 +252,7 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
rt.block_on(async move {
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(batches));
let table_rst = database.create_table(&table_name, batch_reader).await;
let table_rst = database.create_table(&table_name, batch_reader, Some(params)).await;
deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new(

View File

@@ -16,6 +16,7 @@ use std::fs::create_dir_all;
use std::path::Path;
use arrow_array::RecordBatchReader;
use lance::dataset::WriteParams;
use lance::io::object_store::ObjectStore;
use snafu::prelude::*;
@@ -90,12 +91,19 @@ impl Database {
Ok(f)
}
/// Create a new table in the database.
///
/// This forwards to [`Table::create`], supplying this database's URI as the
/// base location for the new table.
///
/// # Arguments
/// * `name` - The name of the table.
/// * `batches` - The initial data to write to the table.
/// * `params` - Optional [`WriteParams`] to create the table (for example its
///   write `mode`); passed through unchanged to [`Table::create`].
pub async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader>,
params: Option<WriteParams>,
) -> Result<Table> {
Table::create(&self.uri, name, batches).await
Table::create(&self.uri, name, batches, params).await
}
/// Open a table in the database.

View File

@@ -1,4 +1,4 @@
// Copyright 2023 Lance Developers.
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -117,6 +117,7 @@ impl Table {
base_uri: &str,
name: &str,
mut batches: Box<dyn RecordBatchReader>,
params: Option<WriteParams>,
) -> Result<Self> {
let base_path = Path::new(base_uri);
let table_uri = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
@@ -125,7 +126,7 @@ impl Table {
.to_str()
.context(InvalidTableNameSnafu { name })?
.to_string();
let dataset = Dataset::write(&mut batches, &uri, Some(WriteParams::default()))
let dataset = Dataset::write(&mut batches, &uri, params)
.await
.map_err(|e| match e {
lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -284,10 +285,10 @@ mod tests {
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let _ = batches.schema().clone();
Table::create(&uri, "test", batches).await.unwrap();
Table::create(&uri, "test", batches, None).await.unwrap();
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let result = Table::create(&uri, "test", batches).await;
let result = Table::create(&uri, "test", batches, None).await;
assert!(matches!(
result.unwrap_err(),
Error::TableAlreadyExists { .. }
@@ -301,7 +302,7 @@ mod tests {
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone();
let mut table = Table::create(&uri, "test", batches).await.unwrap();
let mut table = Table::create(&uri, "test", batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> =
@@ -323,7 +324,7 @@ mod tests {
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone();
let mut table = Table::create(uri, "test", batches).await.unwrap();
let mut table = Table::create(uri, "test", batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> =
@@ -453,7 +454,7 @@ mod tests {
.unwrap()]);
let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches);
let mut table = Table::create(uri, "test", reader).await.unwrap();
let mut table = Table::create(uri, "test", reader, None).await.unwrap();
let mut i = IvfPQIndexBuilder::new();