Update in Node & Rust (#696)

Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
Bert
2023-12-13 14:53:06 -05:00
committed by Weston Pace
parent 3413e79b0f
commit e479acc1bd
10 changed files with 589 additions and 12 deletions

View File

@@ -5,10 +5,10 @@ exclude = ["python"]
resolver = "2"
[workspace.dependencies]
lance = { "version" = "=0.8.17", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.8.17" }
lance-linalg = { "version" = "=0.8.17" }
lance-testing = { "version" = "=0.8.17" }
lance = { "version" = "=0.8.20", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.8.20" }
lance-linalg = { "version" = "=0.8.20" }
lance-testing = { "version" = "=0.8.20" }
# Note that this one does not include pyarrow
arrow = { version = "47.0.0", optional = false }
arrow-array = "47.0"

View File

@@ -21,9 +21,10 @@ import type { EmbeddingFunction } from './embedding/embedding_function'
import { RemoteConnection } from './remote'
import { Query } from './query'
import { isEmbeddingFunction } from './embedding/embedding_function'
import { type Literal, toSQL } from './util'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableUpdate, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
export { Query }
export type { EmbeddingFunction }
@@ -261,6 +262,39 @@ export interface Table<T = number[]> {
*/
delete: (filter: string) => Promise<void>
/**
* Update rows in this table.
*
* This can be used to update a single row, many rows, all rows, or
* sometimes no rows (if your predicate matches nothing).
*
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
*
* @examples
*
* ```ts
* const con = await lancedb.connect("./.lancedb")
* const data = [
* {id: 1, vector: [3, 3], name: 'Ye'},
* {id: 2, vector: [4, 4], name: 'Mike'},
* ];
* const tbl = await con.createTable("my_table", data)
*
* await tbl.update({
* filter: "id = 2",
* updates: { vector: [2, 2], name: "Michael" },
* })
*
* let results = await tbl.search([1, 1]).execute();
* // Returns [
* // {id: 2, vector: [2, 2], name: 'Michael'}
* // {id: 1, vector: [3, 3], name: 'Ye'}
* // ]
* ```
*
*/
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
/**
* List the indicies on this table.
*/
@@ -272,6 +306,34 @@ export interface Table<T = number[]> {
indexStats: (indexUuid: string) => Promise<IndexStats>
}
export interface UpdateArgs {
/**
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
* in which case all rows will be updated.
*/
where?: string
/**
* A key-value map of updates. The keys are the column names, and the values are the
* new values to set
*/
values: Record<string, Literal>
}
export interface UpdateSqlArgs {
/**
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
* in which case all rows will be updated.
*/
where?: string
/**
* A key-value map of updates. The keys are the column names, and the values are the
* new values to set as SQL expressions.
*/
valuesSql: Record<string, string>
}
export interface VectorIndex {
columns: string[]
name: string
@@ -481,6 +543,31 @@ export class LocalTable<T = number[]> implements Table<T> {
return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable })
}
/**
* Update rows in this table.
*
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
*
* @returns
*/
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
let filter: string | null
let updates: Record<string, string>
if ('valuesSql' in args) {
filter = args.where ?? null
updates = args.valuesSql
} else {
filter = args.where ?? null
updates = {}
for (const [key, value] of Object.entries(args.values)) {
updates[key] = toSQL(value)
}
}
return tableUpdate.call(this._tbl, filter, updates).then((newTable: any) => { this._tbl = newTable })
}
/**
* Clean up old versions of the table, freeing disk space.
*

View File

@@ -16,7 +16,8 @@ import {
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
type WriteOptions,
type IndexStats
type IndexStats,
type UpdateArgs, type UpdateSqlArgs
} from '../index'
import { Query } from '../query'
@@ -246,6 +247,10 @@ export class RemoteTable<T = number[]> implements Table<T> {
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
}
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
throw new Error('Not implemented')
}
async listIndices (): Promise<VectorIndex[]> {
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
return results.data.indexes?.map((index: any) => ({

View File

@@ -260,6 +260,46 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
})
it('can update records in the table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ where: 'price = 10', valuesSql: { price: '100' } })
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 11)
})
it('can update the records using a literal value', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ where: 'price = 10', values: { price: 100 } })
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 11)
})
it('can update every record in the table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ valuesSql: { price: '100' } })
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 100)
})
it('can delete records from a table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
@@ -542,7 +582,7 @@ describe('Compact and cleanup', function () {
// should have no effect, but this validates the arguments are parsed.
await table.compactFiles({
targetRowsPerFragment: 1024 * 10,
targetRowsPerFragment: 102410,
maxRowsPerGroup: 1024,
materializeDeletions: true,
materializeDeletionsThreshold: 0.5,

45
node/src/test/util.ts Normal file
View File

@@ -0,0 +1,45 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { toSQL } from '../util'
import * as chai from 'chai'
const expect = chai.expect
describe('toSQL', function () {
it('should turn string to SQL expression', function () {
expect(toSQL('foo')).to.equal("'foo'")
})
it('should turn number to SQL expression', function () {
expect(toSQL(123)).to.equal('123')
})
it('should turn boolean to SQL expression', function () {
expect(toSQL(true)).to.equal('TRUE')
})
it('should turn null to SQL expression', function () {
expect(toSQL(null)).to.equal('NULL')
})
it('should turn Date to SQL expression', function () {
const date = new Date('05 October 2011 14:48 UTC')
expect(toSQL(date)).to.equal("'2011-10-05T14:48:00.000Z'")
})
it('should turn array to SQL expression', function () {
expect(toSQL(['foo', 'bar', true, 1])).to.equal("['foo', 'bar', TRUE, 1]")
})
})

44
node/src/util.ts Normal file
View File

@@ -0,0 +1,44 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
export type Literal = string | number | boolean | null | Date | Literal[]
export function toSQL (value: Literal): string {
if (typeof value === 'string') {
return `'${value}'`
}
if (typeof value === 'number') {
return value.toString()
}
if (typeof value === 'boolean') {
return value ? 'TRUE' : 'FALSE'
}
if (value === null) {
return 'NULL'
}
if (value instanceof Date) {
return `'${value.toISOString()}'`
}
if (Array.isArray(value)) {
return `[${value.map(toSQL).join(', ')}]`
}
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw new Error(`Unsupported value type: ${typeof value} value: (${value})`)
}

View File

@@ -237,6 +237,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("tableAdd", JsTable::js_add)?;
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
cx.export_function("tableDelete", JsTable::js_delete)?;
cx.export_function("tableUpdate", JsTable::js_update)?;
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
cx.export_function("tableListIndices", JsTable::js_list_indices)?;

View File

@@ -48,7 +48,9 @@ impl JsQuery {
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap());
let prefilter = query_obj.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?.value(&mut cx);
let prefilter = query_obj
.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
.value(&mut cx);
let is_electron = cx
.argument::<JsBoolean>(1)

View File

@@ -165,6 +165,69 @@ impl JsTable {
Ok(promise)
}
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let mut table = js_table.table.clone();
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
// create a vector of updates from the passed map
let updates_arg = cx.argument::<JsObject>(1)?;
let properties = updates_arg.get_own_property_names(&mut cx)?;
let mut updates: Vec<(String, String)> =
Vec::with_capacity(properties.len(&mut cx) as usize);
let len_properties = properties.len(&mut cx);
for i in 0..len_properties {
let property = properties
.get_value(&mut cx, i)?
.downcast_or_throw::<JsString, _>(&mut cx)?;
let value = updates_arg
.get_value(&mut cx, property.clone())?
.downcast_or_throw::<JsString, _>(&mut cx)?;
let property = property.value(&mut cx);
let value = value.value(&mut cx);
updates.push((property, value));
}
// get the filter/predicate if the user passed one
let predicate = cx.argument_opt(0);
let predicate = predicate.unwrap().downcast::<JsString, _>(&mut cx);
let predicate = match predicate {
Ok(_) => {
let val = predicate.map(|s| s.value(&mut cx)).unwrap();
Some(val)
}
Err(_) => {
// if the predicate is not string, check it's null otherwise an invalid
// type was passed
cx.argument::<JsNull>(0)?;
None
}
};
rt.spawn(async move {
let updates_arg = updates
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect::<Vec<_>>();
let predicate = predicate.as_ref().map(|s| s.as_str());
let update_result = table.update(predicate, updates_arg).await;
deferred.settle_with(&channel, move |mut cx| {
update_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
})
});
Ok(promise)
}
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;

View File

@@ -23,7 +23,7 @@ use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
use lance::dataset::{Dataset, WriteParams};
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::index::DatasetIndexExt;
use lance::io::object_store::WrappingObjectStore;
use std::path::Path;
@@ -338,6 +338,27 @@ impl Table {
Ok(())
}
pub async fn update(
&mut self,
predicate: Option<&str>,
updates: Vec<(&str, &str)>,
) -> Result<()> {
let mut builder = UpdateBuilder::new(self.dataset.clone());
if let Some(predicate) = predicate {
builder = builder.update_where(predicate)?;
}
for (column, value) in updates {
builder = builder.set(column, value)?;
}
let operation = builder.build()?;
let new_ds = operation.execute().await?;
self.dataset = new_ds;
Ok(())
}
/// Remove old versions of the dataset from disk.
///
/// # Arguments
@@ -413,11 +434,14 @@ mod tests {
use std::sync::Arc;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
RecordBatchReader,
Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array,
Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
UInt32Array,
};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field, Schema};
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use futures::TryStreamExt;
use lance::dataset::{Dataset, WriteMode};
use lance::index::vector::pq::PQBuildParams;
use lance::io::object_store::{ObjectStoreParams, WrappingObjectStore};
@@ -540,6 +564,272 @@ mod tests {
assert_eq!(table.name, "test");
}
#[tokio::test]
async fn test_update_with_predicate() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
Dataset::write(record_batch_iter, uri, None).await.unwrap();
let mut table = Table::open(uri).await.unwrap();
table
.update(Some("id > 5"), vec![("name", "'foo'")])
.await
.unwrap();
let ds_after = Dataset::open(uri).await.unwrap();
let mut batches = ds_after
.scan()
.project(&["id", "name"])
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
while let Some(batch) = batches.pop() {
let ids = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.collect::<Vec<_>>();
let names = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for (i, name) in names.iter().enumerate() {
let id = ids[i].unwrap();
let name = name.unwrap();
if id > 5 {
assert_eq!(name, "foo");
} else {
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
}
}
}
}
#[tokio::test]
async fn test_update_all_types() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("int32", DataType::Int32, false),
Field::new("int64", DataType::Int64, false),
Field::new("uint32", DataType::UInt32, false),
Field::new("string", DataType::Utf8, false),
Field::new("large_string", DataType::LargeUtf8, false),
Field::new("float32", DataType::Float32, false),
Field::new("float64", DataType::Float64, false),
Field::new("bool", DataType::Boolean, false),
Field::new("date32", DataType::Date32, false),
Field::new(
"timestamp_ns",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
Field::new(
"timestamp_ms",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"vec_f32",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
false,
),
Field::new(
"vec_f64",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
false,
),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(Int64Array::from_iter_values(0..10)),
Arc::new(UInt32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(LargeStringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(Float32Array::from_iter_values(
(0..10).into_iter().map(|i| i as f32),
)),
Arc::new(Float64Array::from_iter_values(
(0..10).into_iter().map(|i| i as f64),
)),
Arc::new(Into::<BooleanArray>::into(vec![
true, false, true, false, true, false, true, false, true, false,
])),
Arc::new(Date32Array::from_iter_values(0..10)),
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
Arc::new(
create_fixed_size_list(
Float32Array::from_iter_values((0..20).into_iter().map(|i| i as f32)),
2,
)
.unwrap(),
),
Arc::new(
create_fixed_size_list(
Float64Array::from_iter_values((0..20).into_iter().map(|i| i as f64)),
2,
)
.unwrap(),
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
Dataset::write(record_batch_iter, uri, None).await.unwrap();
let mut table = Table::open(uri).await.unwrap();
// check it can do update for each type
let updates: Vec<(&str, &str)> = vec![
("string", "'foo'"),
("large_string", "'large_foo'"),
("int32", "1"),
("int64", "1"),
("uint32", "1"),
("float32", "1.0"),
("float64", "1.0"),
("bool", "true"),
("date32", "1"),
("timestamp_ns", "1"),
("timestamp_ms", "1"),
("vec_f32", "[1.0, 1.0]"),
("vec_f64", "[1.0, 1.0]"),
];
// for (column, value) in test_cases {
table.update(None, updates).await.unwrap();
let ds_after = Dataset::open(uri).await.unwrap();
let mut batches = ds_after
.scan()
.project(&[
"string",
"large_string",
"int32",
"int64",
"uint32",
"float32",
"float64",
"bool",
"date32",
"timestamp_ns",
"timestamp_ms",
"vec_f32",
"vec_f64",
])
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = batches.pop().unwrap();
macro_rules! assert_column {
($column:expr, $array_type:ty, $expected:expr) => {
let array = $column
.as_any()
.downcast_ref::<$array_type>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
assert_eq!(v, Some($expected));
}
};
}
assert_column!(batch.column(0), StringArray, "foo");
assert_column!(batch.column(1), LargeStringArray, "large_foo");
assert_column!(batch.column(2), Int32Array, 1);
assert_column!(batch.column(3), Int64Array, 1);
assert_column!(batch.column(4), UInt32Array, 1);
assert_column!(batch.column(5), Float32Array, 1.0);
assert_column!(batch.column(6), Float64Array, 1.0);
assert_column!(batch.column(7), BooleanArray, true);
assert_column!(batch.column(8), Date32Array, 1);
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
let array = batch
.column(11)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
for v in f32array {
assert_eq!(v, Some(1.0));
}
}
let array = batch
.column(12)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
for v in f64array {
assert_eq!(v, Some(1.0));
}
}
}
#[tokio::test]
async fn test_search() {
let tmp_dir = tempdir().unwrap();