/**
 * Send an update request for this table to the remote server.
 *
 * Two argument shapes are accepted:
 * - `values`: each value is converted to a SQL string with `toSQL`.
 * - `valuesSql`: SQL expression strings, forwarded unchanged.
 *
 * The request is POSTed to `/v1/table/<name>/update/` with a body of
 * `{ predicate, updates }`, where `updates` is a list of
 * `[column, sqlExpression]` pairs.
 *
 * @param args - Update arguments. `where` is optional; when it is absent
 *   the predicate is sent as `null`.
 * @returns A promise that settles once the POST has completed.
 */
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
  // Both argument shapes expose the same optional `where`, so resolve it once.
  const predicate: string | null = args.where ?? null

  // Build the [column, sql] pairs directly instead of filling a scratch
  // object and re-enumerating it afterwards.
  const updates: Array<[string, string]> = 'valuesSql' in args
    ? Object.entries(args.valuesSql)
    : Object.entries(args.values).map(
      ([column, value]): [string, string] => [column, toSQL(value)]
    )

  await this._client.post(`/v1/table/${this._name}/update/`, {
    predicate,
    updates
  })
}
def update(
    self,
    where: Optional[str] = None,
    values: Optional[dict] = None,
    *,
    values_sql: Optional[Dict[str, str]] = None,
):
    """
    Send an update request for this table to the remote server.

    Depending on how many rows match ``where``, anywhere from zero to all
    rows can be affected.

    Parameters
    ----------
    where: str, optional
        The SQL where clause selecting the rows to update, for example
        'x = 2' or 'x IN (1, 2, 3)'. The filter is expected to be
        non-empty (an empty filter errors); this method forwards it
        as the request predicate without validating it.
    values: dict, optional
        Mapping of column name to the new value. Each value is converted
        with ``value_to_sql`` before being sent.
    values_sql: dict, optional
        Mapping of column name to a SQL expression string. Expressions may
        reference existing columns, e.g. {"x": "x + 1"} increments the
        x column by 1. They are sent unchanged.

    Raises
    ------
    ValueError
        If both ``values`` and ``values_sql`` are given, or if neither is.

    Examples
    --------
    >>> import lancedb
    >>> data = [
    ...    {"x": 1, "vector": [1, 2]},
    ...    {"x": 2, "vector": [3, 4]},
    ...    {"x": 3, "vector": [5, 6]}
    ... ]
    >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
    >>> table = db.create_table("my_table", data) # doctest: +SKIP
    >>> table.to_pandas() # doctest: +SKIP
       x      vector # doctest: +SKIP
    0  1  [1.0, 2.0] # doctest: +SKIP
    1  2  [3.0, 4.0] # doctest: +SKIP
    2  3  [5.0, 6.0] # doctest: +SKIP
    >>> table.update(where="x = 2", values={"vector": [10, 10]}) # doctest: +SKIP
    >>> table.to_pandas() # doctest: +SKIP
       x        vector # doctest: +SKIP
    0  1    [1.0, 2.0] # doctest: +SKIP
    1  3    [5.0, 6.0] # doctest: +SKIP
    2  2  [10.0, 10.0] # doctest: +SKIP

    """
    has_literals = values is not None
    has_expressions = values_sql is not None

    # Exactly one of the two value sources must be supplied.
    if has_literals and has_expressions:
        raise ValueError("Only one of values or values_sql can be provided")
    if not (has_literals or has_expressions):
        raise ValueError("Either values or values_sql must be provided")

    # The request carries updates as a list of [column, sql_expression] pairs.
    if has_literals:
        pairs = [[column, value_to_sql(literal)] for column, literal in values.items()]
    else:
        pairs = [[column, expression] for column, expression in values_sql.items()]

    self._conn._loop.run_until_complete(
        self._conn._client.post(
            f"/v1/table/{self._name}/update/",
            data={"predicate": where, "updates": pairs},
        )
    )