diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index e9e465ec..17417c37 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -66,6 +66,23 @@ describe("Given a table", () => { expect(table.isOpen()).toBe(false); expect(table.countRows()).rejects.toThrow("Table some_table is closed"); }); + + it("should let me update values", async () => { + await table.add([{ id: 1 }]); + expect(await table.countRows("id == 1")).toBe(1); + expect(await table.countRows("id == 7")).toBe(0); + await table.update({ id: "7" }); + expect(await table.countRows("id == 1")).toBe(0); + expect(await table.countRows("id == 7")).toBe(1); + await table.add([{ id: 2 }]); + // Test Map as input + await table.update(new Map(Object.entries({ id: "10" })), { + where: "id % 2 == 0", + }); + expect(await table.countRows("id == 2")).toBe(0); + expect(await table.countRows("id == 7")).toBe(1); + expect(await table.countRows("id == 10")).toBe(1); + }); }); describe("When creating an index", () => { diff --git a/nodejs/lancedb/native.d.ts b/nodejs/lancedb/native.d.ts index f208b66d..86eb1b09 100644 --- a/nodejs/lancedb/native.d.ts +++ b/nodejs/lancedb/native.d.ts @@ -113,6 +113,7 @@ export class Table { countRows(filter?: string | undefined | null): Promise delete(predicate: string): Promise createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise + update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise query(): Query addColumns(transforms: Array): Promise alterColumns(alterations: Array): Promise diff --git a/nodejs/lancedb/table.ts b/nodejs/lancedb/table.ts index 85bf22e8..ef6d5aaa 100644 --- a/nodejs/lancedb/table.ts +++ b/nodejs/lancedb/table.ts @@ -33,6 +33,20 @@ export interface AddDataOptions { mode: "append" | "overwrite"; } +export interface UpdateOptions { + /** + * A filter that limits the scope of the update. + * + * This should be an SQL filter expression. 
+ * + * Only rows that satisfy the expression will be updated. + * + * For example, this could be 'my_col == 0' to replace all instances + * of 0 in a column with some other default value. + */ + where: string; +} + /** * A Table is a collection of Records in a LanceDB Database. * @@ -93,6 +107,45 @@ export class Table { await this.inner.add(buffer, mode); } + /** + * Update existing records in the Table + * + * An update operation can be used to adjust existing values. Use the + * returned builder to specify which columns to update. The new value + * can be a literal value (e.g. replacing nulls with some default value) + * or an expression applied to the old value (e.g. incrementing a value) + * + * An optional condition can be specified (e.g. "only update if the old + * value is 0") + * + * Note: if your condition is something like "some_id_column == 7" and + * you are updating many rows (with different ids) then you will get + * better performance with a single [`merge_insert`] call instead of + * repeatedly calling this method. + * + * @param updates the columns to update + * + * Keys in the map should specify the name of the column to update. + * Values in the map provide the new value of the column. These can + * be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions + * based on the row being updated (e.g. "my_col + 1") + * + * @param options additional options to control the update behavior + */ + async update( + updates: Map | Record, + options?: Partial, + ) { + const onlyIf = options?.where; + let columns: [string, string][]; + if (updates instanceof Map) { + columns = Array.from(updates.entries()); + } else { + columns = Object.entries(updates); + } + await this.inner.update(onlyIf, columns); + } + /** Count the total number of rows in the dataset. 
*/ async countRows(filter?: string): Promise { return await this.inner.countRows(filter); diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs index 9bfbb4c8..0d2e2102 100644 --- a/nodejs/src/table.rs +++ b/nodejs/src/table.rs @@ -150,6 +150,22 @@ impl Table { builder.execute().await.default_error() } + #[napi] + pub async fn update( + &self, + only_if: Option, + columns: Vec<(String, String)>, + ) -> napi::Result<()> { + let mut op = self.inner_ref()?.update(); + if let Some(only_if) = only_if { + op = op.only_if(only_if); + } + for (column_name, value) in columns { + op = op.column(column_name, value); + } + op.execute().await.default_error() + } + #[napi] pub fn query(&self) -> napi::Result { Ok(Query::new(self.inner_ref()?.query())) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 6605c934..6613aa6e 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Dict, Optional import pyarrow as pa @@ -30,6 +30,7 @@ class Table: def __repr__(self) -> str: ... async def schema(self) -> pa.Schema: ... async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ... + async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ... async def count_rows(self, filter: Optional[str]) -> int: ... 
async def create_index( self, column: str, config: Optional[Index], replace: Optional[bool] diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 3454ebe5..93bc27d5 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -2214,58 +2214,57 @@ class AsyncTable: async def update( self, - where: Optional[str] = None, - values: Optional[dict] = None, + updates: Optional[Dict[str, Any]] = None, *, - values_sql: Optional[Dict[str, str]] = None, + where: Optional[str] = None, + updates_sql: Optional[Dict[str, str]] = None, ): """ - This can be used to update zero to all rows depending on how many - rows match the where clause. If no where clause is provided, then - all rows will be updated. + This can be used to update zero to all rows in the table. - Either `values` or `values_sql` must be provided. You cannot provide - both. + If a filter is provided with `where` then only rows matching the + filter will be updated. Otherwise all rows will be updated. Parameters ---------- + updates: dict, optional + The updates to apply. The keys should be the name of the column to + update. The values should be the new values to assign. This is + required unless updates_sql is supplied. where: str, optional - The SQL where clause to use when updating rows. For example, 'x = 2' - or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error. - values: dict, optional - The values to update. The keys are the column names and the values - are the values to set. - values_sql: dict, optional - The values to update, expressed as SQL expression strings. These can - reference existing columns. For example, {"x": "x + 1"} will increment - the x column by 1. + An SQL filter that controls which rows are updated. For example, 'x = 2' + or 'x IN (1, 2, 3)'. Only rows that satisfy this filter will be updated. + updates_sql: dict, optional + The updates to apply, expressed as SQL expression strings. The keys should + be column names. 
The values should be SQL expressions. These can be SQL + literals (e.g. "7" or "'foo'") or they can be expressions based on the + previous value of the row (e.g. "x + 1" to increment the x column by 1) Examples -------- + >>> import asyncio >>> import lancedb >>> import pandas as pd - >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) - >>> db = lancedb.connect("./.lancedb") - >>> table = db.create_table("my_table", data) - >>> table.to_pandas() - x vector - 0 1 [1.0, 2.0] - 1 2 [3.0, 4.0] - 2 3 [5.0, 6.0] - >>> table.update(where="x = 2", values={"vector": [10, 10]}) - >>> table.to_pandas() - x vector - 0 1 [1.0, 2.0] - 1 3 [5.0, 6.0] - 2 2 [10.0, 10.0] - >>> table.update(values_sql={"x": "x + 1"}) - >>> table.to_pandas() - x vector - 0 2 [1.0, 2.0] - 1 4 [5.0, 6.0] - 2 3 [10.0, 10.0] + >>> async def demo_update(): + ... data = pd.DataFrame({"x": [1, 2], "vector": [[1, 2], [3, 4]]}) + ... db = await lancedb.connect_async("./.lancedb") + ... table = await db.create_table("my_table", data) + ... # x is [1, 2], vector is [[1, 2], [3, 4]] + ... await table.update({"vector": [10, 10]}, where="x = 2") + ... # x is [1, 2], vector is [[1, 2], [10, 10]] + ... await table.update(updates_sql={"x": "x + 1"}) + ... 
# x is [2, 3], vector is [[1, 2], [10, 10]] + >>> asyncio.run(demo_update()) """ - raise NotImplementedError + if updates is not None and updates_sql is not None: + raise ValueError("Only one of updates or updates_sql can be provided") + if updates is None and updates_sql is None: + raise ValueError("Either updates or updates_sql must be provided") + + if updates is not None: + updates_sql = {k: value_to_sql(v) for k, v in updates.items()} + + return await self._inner.update(updates_sql, where) async def cleanup_old_versions( self, diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index d04261c2..518b19e1 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -85,6 +85,23 @@ async def test_close(db_async: AsyncConnection): assert str(table) == "ClosedTable(some_table)" +@pytest.mark.asyncio +async def test_update_async(db_async: AsyncConnection): + table = await db_async.create_table("some_table", data=[{"id": 0}]) + assert await table.count_rows("id == 0") == 1 + assert await table.count_rows("id == 7") == 0 + await table.update({"id": 7}) + assert await table.count_rows("id == 7") == 1 + assert await table.count_rows("id == 0") == 0 + await table.add([{"id": 2}]) + await table.update(where="id % 2 == 0", updates_sql={"id": "5"}) + assert await table.count_rows("id == 7") == 1 + assert await table.count_rows("id == 2") == 0 + assert await table.count_rows("id == 5") == 1 + await table.update({"id": 10}, where="id == 5") + assert await table.count_rows("id == 10") == 1 + + def test_create_table(db): schema = pa.schema( [ diff --git a/python/src/table.rs b/python/src/table.rs index 5231204f..11fab442 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -5,7 +5,9 @@ use arrow::{ use lancedb::table::{AddDataMode, Table as LanceDbTable}; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, - pyclass, pymethods, PyAny, PyRef, PyResult, Python, + pyclass, pymethods, + types::{PyDict, 
PyString}, + PyAny, PyRef, PyResult, Python, }; use pyo3_asyncio::tokio::future_into_py; @@ -74,6 +76,28 @@ impl Table { }) } + pub fn update<'a>( + self_: PyRef<'a, Self>, + updates: &PyDict, + r#where: Option, + ) -> PyResult<&'a PyAny> { + let mut op = self_.inner_ref()?.update(); + if let Some(only_if) = r#where { + op = op.only_if(only_if); + } + for (column_name, value) in updates.into_iter() { + let column_name: &PyString = column_name.downcast()?; + let column_name = column_name.to_str()?.to_string(); + let value: &PyString = value.downcast()?; + let value = value.to_str()?.to_string(); + op = op.column(column_name, value); + } + future_into_py(self_.py(), async move { + op.execute().await.infer_error()?; + Ok(()) + }) + } + pub fn count_rows(self_: PyRef<'_, Self>, filter: Option) -> PyResult<&PyAny> { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index 9d0013e7..4e35825b 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -297,11 +297,14 @@ impl JsTable { let predicate = predicate.as_deref(); - let update_result = table - .as_native() - .unwrap() - .update(predicate, updates_arg) - .await; + let mut update_op = table.update(); + if let Some(predicate) = predicate { + update_op = update_op.only_if(predicate); + } + for (column, value) in updates_arg { + update_op = update_op.column(column, value); + } + let update_result = update_op.execute().await; deferred.settle_with(&channel, move |mut cx| { update_result.or_throw(&mut cx)?; Ok(cx.boxed(Self::from(table))) diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 32779fb5..68d67274 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -9,7 +9,7 @@ use crate::{ query::Query, table::{ merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats, - TableInternal, + TableInternal, UpdateBuilder, 
}, }; @@ -69,6 +69,9 @@ impl TableInternal for RemoteTable { async fn query(&self, _query: &Query) -> Result { todo!() } + async fn update(&self, _update: UpdateBuilder) -> Result<()> { + todo!() + } async fn delete(&self, _predicate: &str) -> Result<()> { todo!() } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index a51144dc..20c2ed2a 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -30,7 +30,9 @@ use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; pub use lance::dataset::ColumnAlteration; pub use lance::dataset::NewColumnTransform; pub use lance::dataset::ReadParams; -use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteMode, WriteParams}; +use lance::dataset::{ + Dataset, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams, +}; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; use lance::io::WrappingObjectStore; use lance_index::IndexType; @@ -115,7 +117,8 @@ pub enum AddDataMode { Overwrite, } -/// A builder for configuring a [`Connection::create_table`] operation +/// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`] +/// operation pub struct AddDataBuilder { parent: Arc, pub(crate) data: Box, @@ -149,6 +152,71 @@ impl AddDataBuilder { } } +/// A builder for configuring an [`Table::update`] operation +#[derive(Debug, Clone)] +pub struct UpdateBuilder { + parent: Arc, + pub(crate) filter: Option, + pub(crate) columns: Vec<(String, String)>, +} + +impl UpdateBuilder { + fn new(parent: Arc) -> Self { + Self { + parent, + filter: None, + columns: Vec::new(), + } + } + + /// Limits the update operation to rows matching the given filter + /// + /// If a row does not match the filter then it will be left unchanged. 
+ pub fn only_if(mut self, filter: impl Into) -> Self { + self.filter = Some(filter.into()); + self + } + + /// Specifies a column to update + /// + /// This method may be called multiple times to update multiple columns + /// + /// The `update_expr` should be an SQL expression explaining how to calculate + /// the new value for the column. The expression will be evaluated against the + /// previous row's value. + /// + /// # Examples + /// + /// ``` + /// # use lancedb::Table; + /// # async fn doctest_helper(tbl: Table) { + /// let mut operation = tbl.update(); + /// // Increments the `bird_count` value by 1 + /// operation = operation.column("bird_count", "bird_count + 1"); + /// operation.execute().await.unwrap(); + /// # } + /// ``` + pub fn column( + mut self, + column_name: impl Into, + update_expr: impl Into, + ) -> Self { + self.columns.push((column_name.into(), update_expr.into())); + self + } + + /// Executes the update operation + pub async fn execute(self) -> Result<()> { + if self.columns.is_empty() { + Err(Error::InvalidInput { + message: "at least one column must be specified in an update operation".to_string(), + }) + } else { + self.parent.clone().update(self).await + } + } +} + #[async_trait] pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync { fn as_any(&self) -> &dyn std::any::Any; @@ -163,6 +231,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn async fn add(&self, add: AddDataBuilder) -> Result<()>; async fn query(&self, query: &Query) -> Result; async fn delete(&self, predicate: &str) -> Result<()>; + async fn update(&self, update: UpdateBuilder) -> Result<()>; async fn create_index(&self, index: IndexBuilder) -> Result<()>; async fn merge_insert( &self, @@ -244,6 +313,24 @@ impl Table { } } + /// Update existing records in the Table + /// + /// An update operation can be used to adjust existing values. Use the + /// returned builder to specify which columns to update. 
The new value + /// can be a literal value (e.g. replacing nulls with some default value) + /// or an expression applied to the old value (e.g. incrementing a value) + /// + /// An optional condition can be specified (e.g. "only update if the old + /// value is 0") + /// + /// Note: if your condition is something like "some_id_column == 7" and + /// you are updating many rows (with different ids) then you will get + /// better performance with a single [`merge_insert`] call instead of + /// repeatedly calling this method. + pub fn update(&self) -> UpdateBuilder { + UpdateBuilder::new(self.inner.clone()) + } + /// Delete the rows from table that match the predicate. /// /// # Arguments @@ -818,23 +905,6 @@ impl NativeTable { Ok(()) } - pub async fn update(&self, predicate: Option<&str>, updates: Vec<(&str, &str)>) -> Result<()> { - let dataset = self.dataset.get().await?.clone(); - let mut builder = UpdateBuilder::new(Arc::new(dataset)); - if let Some(predicate) = predicate { - builder = builder.update_where(predicate)?; - } - - for (column, value) in updates { - builder = builder.set(column, value)?; - } - - let operation = builder.build()?; - let ds = operation.execute().await?; - self.dataset.set_latest(ds.as_ref().clone()).await; - Ok(()) - } - /// Remove old versions of the dataset from disk. 
/// /// # Arguments @@ -1138,6 +1208,23 @@ impl TableInternal for NativeTable { } } + async fn update(&self, update: UpdateBuilder) -> Result<()> { + let dataset = self.dataset.get().await?.clone(); + let mut builder = LanceUpdateBuilder::new(Arc::new(dataset)); + if let Some(predicate) = update.filter { + builder = builder.update_where(&predicate)?; + } + + for (column, value) in update.columns { + builder = builder.set(column, &value)?; + } + + let operation = builder.build()?; + let ds = operation.execute().await?; + self.dataset.set_latest(ds.as_ref().clone()).await; + Ok(()) + } + async fn query(&self, query: &Query) -> Result { let ds_ref = self.dataset.get().await?; let mut scanner: Scanner = ds_ref.scan(); @@ -1566,9 +1653,10 @@ mod tests { .unwrap(); table - .as_native() - .unwrap() - .update(Some("id > 5"), vec![("name", "'foo'")]) + .update() + .only_if("id > 5") + .column("name", "'foo'") + .execute() .await .unwrap(); @@ -1718,13 +1806,11 @@ mod tests { ("vec_f64", "[1.0, 1.0]"), ]; - // for (column, value) in test_cases { - table - .as_native() - .unwrap() - .update(None, updates) - .await - .unwrap(); + let mut update_op = table.update(); + for (column, value) in updates { + update_op = update_op.column(column, value); + } + update_op.execute().await.unwrap(); let mut batches = table .query() @@ -1808,6 +1894,26 @@ mod tests { } } + #[tokio::test] + async fn test_update_via_expr() { + let tmp_dir = tempdir().unwrap(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = dataset_path.to_str().unwrap(); + let conn = connect(uri) + .read_consistency_interval(Duration::from_secs(0)) + .execute() + .await + .unwrap(); + let tbl = conn + .create_table("my_table", Box::new(make_test_batches())) + .execute() + .await + .unwrap(); + assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap()); + tbl.update().column("i", "i+1").execute().await.unwrap(); + assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap()); + } + 
#[derive(Default, Debug)] struct NoOpCacheWrapper { called: AtomicBool,