feat: add merge_insert to the node and rust APIs (#915)

This commit is contained in:
Weston Pace
2024-02-02 13:16:51 -08:00
parent 2e75b16403
commit 18f7bad3dd
11 changed files with 565 additions and 18 deletions

View File

@@ -12,7 +12,7 @@
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, Optional
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from .common import DATA
@@ -25,7 +25,7 @@ class LanceMergeInsertBuilder(object):
more context
"""
def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821
def __init__(self, table: "Table", on: List[str]): # noqa: F821
# Do not put a docstring here. This method should be hidden
# from API docs. Users should use merge_insert to create
# this object.
@@ -77,10 +77,27 @@ class LanceMergeInsertBuilder(object):
self._when_not_matched_by_source_condition = condition
return self
def execute(self, new_data: DATA):
def execute(
self,
new_data: DATA,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
"""
Executes the merge insert operation
Nothing is returned but the [`Table`][lancedb.table.Table] is updated
Parameters
----------
new_data: DATA
New records which will be matched against the existing records
to potentially insert or update into the table. This parameter
can be anything you use for [`add`][lancedb.table.Table.add]
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
self._table._do_merge(self, new_data)
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)

View File

@@ -19,6 +19,7 @@ import pyarrow as pa
from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data
@@ -244,9 +245,46 @@ class RemoteTable(Table):
result = self._conn._client.query(self._name, query)
return result.to_arrow()
def _do_merge(self, *_args):
"""_do_merge() is not supported on the LanceDB cloud yet"""
return NotImplementedError("_do_merge() is not supported on the LanceDB cloud")
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
data = _sanitize_data(
new_data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params[
"when_not_matched_by_source_delete_filt"
] = merge._when_not_matched_by_source_condition
self._conn._client.post(
f"/v1/table/{self._name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
def delete(self, predicate: str):
"""Delete rows from the table.

View File

@@ -390,6 +390,8 @@ class Table(ABC):
2 3 y
3 4 z
"""
on = [on] if isinstance(on, str) else list(on.iter())
return LanceMergeInsertBuilder(self, on)
@abstractmethod
@@ -479,8 +481,8 @@ class Table(ABC):
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
*,
schema: Optional[pa.Schema] = None,
on_bad_vectors: str,
fill_value: float,
):
pass
@@ -1305,7 +1307,20 @@ class LanceTable(Table):
with_row_id=query.with_row_id,
)
def _do_merge(self, merge: LanceMergeInsertBuilder, new_data: DATA, *, schema=None):
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
new_data = _sanitize_data(
new_data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
ds = self.to_lance()
builder = ds.merge_insert(merge._on)
if merge._when_matched_update_all:
@@ -1315,7 +1330,7 @@ class LanceTable(Table):
if merge._when_not_matched_by_source_delete:
cond = merge._when_not_matched_by_source_condition
builder.when_not_matched_by_source_delete(cond)
builder.execute(new_data, schema=schema)
builder.execute(new_data)
def cleanup_old_versions(
self,