feat: add update to the async API (#1093)

This commit is contained in:
Weston Pace
2024-03-12 14:11:37 -07:00
committed by GitHub
parent 9bc320874a
commit d744972f2f
11 changed files with 315 additions and 75 deletions

View File

@@ -66,6 +66,23 @@ describe("Given a table", () => {
expect(table.isOpen()).toBe(false);
expect(table.countRows()).rejects.toThrow("Table some_table is closed");
});
it("should let me update values", async () => {
await table.add([{ id: 1 }]);
expect(await table.countRows("id == 1")).toBe(1);
expect(await table.countRows("id == 7")).toBe(0);
await table.update({ id: "7" });
expect(await table.countRows("id == 1")).toBe(0);
expect(await table.countRows("id == 7")).toBe(1);
await table.add([{ id: 2 }]);
// Test Map as input
await table.update(new Map(Object.entries({ id: "10" })), {
where: "id % 2 == 0",
});
expect(await table.countRows("id == 2")).toBe(0);
expect(await table.countRows("id == 7")).toBe(1);
expect(await table.countRows("id == 10")).toBe(1);
});
});
describe("When creating an index", () => {

View File

@@ -113,6 +113,7 @@ export class Table {
countRows(filter?: string | undefined | null): Promise<number>
delete(predicate: string): Promise<void>
createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void>
update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
query(): Query
addColumns(transforms: Array<AddColumnsSql>): Promise<void>
alterColumns(alterations: Array<ColumnAlteration>): Promise<void>

View File

@@ -33,6 +33,20 @@ export interface AddDataOptions {
mode: "append" | "overwrite";
}
export interface UpdateOptions {
/**
* A filter that limits the scope of the update.
*
* This should be an SQL filter expression.
*
* Only rows that satisfy the expression will be updated.
*
* For example, this could be 'my_col == 0' to replace all instances
* of 0 in a column with some other default value.
*/
where: string;
}
/**
* A Table is a collection of Records in a LanceDB Database.
*
@@ -93,6 +107,45 @@ export class Table {
await this.inner.add(buffer, mode);
}
/**
* Update existing records in the Table
*
* An update operation can be used to adjust existing values. Use the
* returned builder to specify which columns to update. The new value
* can be a literal value (e.g. replacing nulls with some default value)
* or an expression applied to the old value (e.g. incrementing a value)
*
* An optional condition can be specified (e.g. "only update if the old
* value is 0")
*
* Note: if your condition is something like "some_id_column == 7" and
* you are updating many rows (with different ids) then you will get
* better performance with a single [`merge_insert`] call instead of
* repeatedly calilng this method.
*
* @param updates the columns to update
*
* Keys in the map should specify the name of the column to update.
* Values in the map provide the new value of the column. These can
* be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions
* based on the row being updated (e.g. "my_col + 1")
*
* @param options additional options to control the update behavior
*/
async update(
updates: Map<string, string> | Record<string, string>,
options?: Partial<UpdateOptions>,
) {
const onlyIf = options?.where;
let columns: [string, string][];
if (updates instanceof Map) {
columns = Array.from(updates.entries());
} else {
columns = Object.entries(updates);
}
await this.inner.update(onlyIf, columns);
}
/** Count the total number of rows in the dataset. */
async countRows(filter?: string): Promise<number> {
return await this.inner.countRows(filter);

View File

@@ -150,6 +150,22 @@ impl Table {
builder.execute().await.default_error()
}
#[napi]
pub async fn update(
&self,
only_if: Option<String>,
columns: Vec<(String, String)>,
) -> napi::Result<()> {
let mut op = self.inner_ref()?.update();
if let Some(only_if) = only_if {
op = op.only_if(only_if);
}
for (column_name, value) in columns {
op = op.column(column_name, value);
}
op.execute().await.default_error()
}
#[napi]
pub fn query(&self) -> napi::Result<Query> {
Ok(Query::new(self.inner_ref()?.query()))

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, Optional
import pyarrow as pa
@@ -30,6 +30,7 @@ class Table:
def __repr__(self) -> str: ...
async def schema(self) -> pa.Schema: ...
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
async def count_rows(self, filter: Optional[str]) -> int: ...
async def create_index(
self, column: str, config: Optional[Index], replace: Optional[bool]

View File

@@ -2214,58 +2214,57 @@ class AsyncTable:
async def update(
self,
where: Optional[str] = None,
values: Optional[dict] = None,
updates: Optional[Dict[str, Any]] = None,
*,
values_sql: Optional[Dict[str, str]] = None,
where: Optional[str] = None,
updates_sql: Optional[Dict[str, str]] = None,
):
"""
This can be used to update zero to all rows depending on how many
rows match the where clause. If no where clause is provided, then
all rows will be updated.
This can be used to update zero to all rows in the table.
Either `values` or `values_sql` must be provided. You cannot provide
both.
If a filter is provided with `where` then only rows matching the
filter will be updated. Otherwise all rows will be updated.
Parameters
----------
updates: dict, optional
The updates to apply. The keys should be the name of the column to
update. The values should be the new values to assign. This is
required unless updates_sql is supplied.
where: str, optional
The SQL where clause to use when updating rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
values: dict, optional
The values to update. The keys are the column names and the values
are the values to set.
values_sql: dict, optional
The values to update, expressed as SQL expression strings. These can
reference existing columns. For example, {"x": "x + 1"} will increment
the x column by 1.
An SQL filter that controls which rows are updated. For example, 'x = 2'
or 'x IN (1, 2, 3)'. Only rows that satisfy this filter will be udpated.
updates_sql: dict, optional
The updates to apply, expressed as SQL expression strings. The keys should
be column names. The values should be SQL expressions. These can be SQL
literals (e.g. "7" or "'foo'") or they can be expressions based on the
previous value of the row (e.g. "x + 1" to increment the x column by 1)
Examples
--------
>>> import asyncio
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
>>> table.update(where="x = 2", values={"vector": [10, 10]})
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 3 [5.0, 6.0]
2 2 [10.0, 10.0]
>>> table.update(values_sql={"x": "x + 1"})
>>> table.to_pandas()
x vector
0 2 [1.0, 2.0]
1 4 [5.0, 6.0]
2 3 [10.0, 10.0]
>>> async def demo_update():
... data = pd.DataFrame({"x": [1, 2], "vector": [[1, 2], [3, 4]]})
... db = await lancedb.connect_async("./.lancedb")
... table = await db.create_table("my_table", data)
... # x is [1, 2], vector is [[1, 2], [3, 4]]
... await table.update({"vector": [10, 10]}, where="x = 2")
... # x is [1, 2], vector is [[1, 2], [10, 10]]
... await table.update(updates_sql={"x": "x + 1"})
... # x is [2, 3], vector is [[1, 2], [10, 10]]
>>> asyncio.run(demo_update())
"""
raise NotImplementedError
if updates is not None and updates_sql is not None:
raise ValueError("Only one of updates or updates_sql can be provided")
if updates is None and updates_sql is None:
raise ValueError("Either updates or updates_sql must be provided")
if updates is not None:
updates_sql = {k: value_to_sql(v) for k, v in updates.items()}
return await self._inner.update(updates_sql, where)
async def cleanup_old_versions(
self,

View File

@@ -85,6 +85,23 @@ async def test_close(db_async: AsyncConnection):
assert str(table) == "ClosedTable(some_table)"
@pytest.mark.asyncio
async def test_update_async(db_async: AsyncConnection):
table = await db_async.create_table("some_table", data=[{"id": 0}])
assert await table.count_rows("id == 0") == 1
assert await table.count_rows("id == 7") == 0
await table.update({"id": 7})
assert await table.count_rows("id == 7") == 1
assert await table.count_rows("id == 0") == 0
await table.add([{"id": 2}])
await table.update(where="id % 2 == 0", updates_sql={"id": "5"})
assert await table.count_rows("id == 7") == 1
assert await table.count_rows("id == 2") == 0
assert await table.count_rows("id == 5") == 1
await table.update({"id": 10}, where="id == 5")
assert await table.count_rows("id == 10") == 1
def test_create_table(db):
schema = pa.schema(
[

View File

@@ -5,7 +5,9 @@ use arrow::{
use lancedb::table::{AddDataMode, Table as LanceDbTable};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pymethods, PyAny, PyRef, PyResult, Python,
pyclass, pymethods,
types::{PyDict, PyString},
PyAny, PyRef, PyResult, Python,
};
use pyo3_asyncio::tokio::future_into_py;
@@ -74,6 +76,28 @@ impl Table {
})
}
pub fn update<'a>(
self_: PyRef<'a, Self>,
updates: &PyDict,
r#where: Option<String>,
) -> PyResult<&'a PyAny> {
let mut op = self_.inner_ref()?.update();
if let Some(only_if) = r#where {
op = op.only_if(only_if);
}
for (column_name, value) in updates.into_iter() {
let column_name: &PyString = column_name.downcast()?;
let column_name = column_name.to_str()?.to_string();
let value: &PyString = value.downcast()?;
let value = value.to_str()?.to_string();
op = op.column(column_name, value);
}
future_into_py(self_.py(), async move {
op.execute().await.infer_error()?;
Ok(())
})
}
pub fn count_rows(self_: PyRef<'_, Self>, filter: Option<String>) -> PyResult<&PyAny> {
let inner = self_.inner_ref()?.clone();
future_into_py(self_.py(), async move {

View File

@@ -297,11 +297,14 @@ impl JsTable {
let predicate = predicate.as_deref();
let update_result = table
.as_native()
.unwrap()
.update(predicate, updates_arg)
.await;
let mut update_op = table.update();
if let Some(predicate) = predicate {
update_op = update_op.only_if(predicate);
}
for (column, value) in updates_arg {
update_op = update_op.column(column, value);
}
let update_result = update_op.execute().await;
deferred.settle_with(&channel, move |mut cx| {
update_result.or_throw(&mut cx)?;
Ok(cx.boxed(Self::from(table)))

View File

@@ -9,7 +9,7 @@ use crate::{
query::Query,
table::{
merge::MergeInsertBuilder, AddDataBuilder, NativeTable, OptimizeAction, OptimizeStats,
TableInternal,
TableInternal, UpdateBuilder,
},
};
@@ -69,6 +69,9 @@ impl TableInternal for RemoteTable {
async fn query(&self, _query: &Query) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn update(&self, _update: UpdateBuilder) -> Result<()> {
todo!()
}
async fn delete(&self, _predicate: &str) -> Result<()> {
todo!()
}

View File

@@ -30,7 +30,9 @@ use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
pub use lance::dataset::ColumnAlteration;
pub use lance::dataset::NewColumnTransform;
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteMode, WriteParams};
use lance::dataset::{
Dataset, UpdateBuilder as LanceUpdateBuilder, WhenMatched, WriteMode, WriteParams,
};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::IndexType;
@@ -115,7 +117,8 @@ pub enum AddDataMode {
Overwrite,
}
/// A builder for configuring a [`Connection::create_table`] operation
/// A builder for configuring a [`crate::connection::Connection::create_table`] or [`Table::add`]
/// operation
pub struct AddDataBuilder {
parent: Arc<dyn TableInternal>,
pub(crate) data: Box<dyn RecordBatchReader + Send>,
@@ -149,6 +152,71 @@ impl AddDataBuilder {
}
}
/// A builder for configuring an [`Table::update`] operation
#[derive(Debug, Clone)]
pub struct UpdateBuilder {
parent: Arc<dyn TableInternal>,
pub(crate) filter: Option<String>,
pub(crate) columns: Vec<(String, String)>,
}
impl UpdateBuilder {
fn new(parent: Arc<dyn TableInternal>) -> Self {
Self {
parent,
filter: None,
columns: Vec::new(),
}
}
/// Limits the update operation to rows matching the given filter
///
/// If a row does not match the filter then it will be left unchanged.
pub fn only_if(mut self, filter: impl Into<String>) -> Self {
self.filter = Some(filter.into());
self
}
/// Specifies a column to update
///
/// This method may be called multiple times to update multiple columns
///
/// The `update_expr` should be an SQL expression explaining how to calculate
/// the new value for the column. The expression will be evaluated against the
/// previous row's value.
///
/// # Examples
///
/// ```
/// # use lancedb::Table;
/// # async fn doctest_helper(tbl: Table) {
/// let mut operation = tbl.update();
/// // Increments the `bird_count` value by 1
/// operation = operation.column("bird_count", "bird_count + 1");
/// operation.execute().await.unwrap();
/// # }
/// ```
pub fn column(
mut self,
column_name: impl Into<String>,
update_expr: impl Into<String>,
) -> Self {
self.columns.push((column_name.into(), update_expr.into()));
self
}
/// Executes the update operation
pub async fn execute(self) -> Result<()> {
if self.columns.is_empty() {
Err(Error::InvalidInput {
message: "at least one column must be specified in an update operation".to_string(),
})
} else {
self.parent.clone().update(self).await
}
}
}
#[async_trait]
pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Sync {
fn as_any(&self) -> &dyn std::any::Any;
@@ -163,6 +231,7 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
async fn add(&self, add: AddDataBuilder) -> Result<()>;
async fn query(&self, query: &Query) -> Result<DatasetRecordBatchStream>;
async fn delete(&self, predicate: &str) -> Result<()>;
async fn update(&self, update: UpdateBuilder) -> Result<()>;
async fn create_index(&self, index: IndexBuilder) -> Result<()>;
async fn merge_insert(
&self,
@@ -244,6 +313,24 @@ impl Table {
}
}
/// Update existing records in the Table
///
/// An update operation can be used to adjust existing values. Use the
/// returned builder to specify which columns to update. The new value
/// can be a literal value (e.g. replacing nulls with some default value)
/// or an expression applied to the old value (e.g. incrementing a value)
///
/// An optional condition can be specified (e.g. "only update if the old
/// value is 0")
///
/// Note: if your condition is something like "some_id_column == 7" and
/// you are updating many rows (with different ids) then you will get
/// better performance with a single [`merge_insert`] call instead of
/// repeatedly calilng this method.
pub fn update(&self) -> UpdateBuilder {
UpdateBuilder::new(self.inner.clone())
}
/// Delete the rows from table that match the predicate.
///
/// # Arguments
@@ -818,23 +905,6 @@ impl NativeTable {
Ok(())
}
pub async fn update(&self, predicate: Option<&str>, updates: Vec<(&str, &str)>) -> Result<()> {
let dataset = self.dataset.get().await?.clone();
let mut builder = UpdateBuilder::new(Arc::new(dataset));
if let Some(predicate) = predicate {
builder = builder.update_where(predicate)?;
}
for (column, value) in updates {
builder = builder.set(column, value)?;
}
let operation = builder.build()?;
let ds = operation.execute().await?;
self.dataset.set_latest(ds.as_ref().clone()).await;
Ok(())
}
/// Remove old versions of the dataset from disk.
///
/// # Arguments
@@ -1138,6 +1208,23 @@ impl TableInternal for NativeTable {
}
}
async fn update(&self, update: UpdateBuilder) -> Result<()> {
let dataset = self.dataset.get().await?.clone();
let mut builder = LanceUpdateBuilder::new(Arc::new(dataset));
if let Some(predicate) = update.filter {
builder = builder.update_where(&predicate)?;
}
for (column, value) in update.columns {
builder = builder.set(column, &value)?;
}
let operation = builder.build()?;
let ds = operation.execute().await?;
self.dataset.set_latest(ds.as_ref().clone()).await;
Ok(())
}
async fn query(&self, query: &Query) -> Result<DatasetRecordBatchStream> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
@@ -1566,9 +1653,10 @@ mod tests {
.unwrap();
table
.as_native()
.unwrap()
.update(Some("id > 5"), vec![("name", "'foo'")])
.update()
.only_if("id > 5")
.column("name", "'foo'")
.execute()
.await
.unwrap();
@@ -1718,13 +1806,11 @@ mod tests {
("vec_f64", "[1.0, 1.0]"),
];
// for (column, value) in test_cases {
table
.as_native()
.unwrap()
.update(None, updates)
.await
.unwrap();
let mut update_op = table.update();
for (column, value) in updates {
update_op = update_op.column(column, value);
}
update_op.execute().await.unwrap();
let mut batches = table
.query()
@@ -1808,6 +1894,26 @@ mod tests {
}
}
#[tokio::test]
async fn test_update_via_expr() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let conn = connect(uri)
.read_consistency_interval(Duration::from_secs(0))
.execute()
.await
.unwrap();
let tbl = conn
.create_table("my_table", Box::new(make_test_batches()))
.execute()
.await
.unwrap();
assert_eq!(1, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
tbl.update().column("i", "i+1").execute().await.unwrap();
assert_eq!(0, tbl.count_rows(Some("i == 0".to_string())).await.unwrap());
}
#[derive(Default, Debug)]
struct NoOpCacheWrapper {
called: AtomicBool,