mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat: add merge_insert to the node and rust APIs (#915)
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Iterable, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .common import DATA
|
||||
@@ -25,7 +25,7 @@ class LanceMergeInsertBuilder(object):
|
||||
more context
|
||||
"""
|
||||
|
||||
def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821
|
||||
def __init__(self, table: "Table", on: List[str]): # noqa: F821
|
||||
# Do not put a docstring here. This method should be hidden
|
||||
# from API docs. Users should use merge_insert to create
|
||||
# this object.
|
||||
@@ -77,10 +77,27 @@ class LanceMergeInsertBuilder(object):
|
||||
self._when_not_matched_by_source_condition = condition
|
||||
return self
|
||||
|
||||
def execute(self, new_data: DATA):
|
||||
def execute(
|
||||
self,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Executes the merge insert operation
|
||||
|
||||
Nothing is returned but the [`Table`][lancedb.table.Table] is updated
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_data: DATA
|
||||
New records which will be matched against the existing records
|
||||
to potentially insert or update into the table. This parameter
|
||||
can be anything you use for [`add`][lancedb.table.Table.add]
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
self._table._do_merge(self, new_data)
|
||||
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
|
||||
|
||||
@@ -19,6 +19,7 @@ import pyarrow as pa
|
||||
from lance import json_to_schema
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from lancedb.merge import LanceMergeInsertBuilder
|
||||
|
||||
from ..query import LanceVectorQueryBuilder
|
||||
from ..table import Query, Table, _sanitize_data
|
||||
@@ -244,9 +245,46 @@ class RemoteTable(Table):
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return result.to_arrow()
|
||||
|
||||
def _do_merge(self, *_args):
|
||||
"""_do_merge() is not supported on the LanceDB cloud yet"""
|
||||
return NotImplementedError("_do_merge() is not supported on the LanceDB cloud")
|
||||
def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
data = _sanitize_data(
|
||||
new_data,
|
||||
self.schema,
|
||||
metadata=None,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
payload = to_ipc_binary(data)
|
||||
|
||||
params = {}
|
||||
if len(merge._on) != 1:
|
||||
raise ValueError(
|
||||
"RemoteTable only supports a single on key in merge_insert"
|
||||
)
|
||||
params["on"] = merge._on[0]
|
||||
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
|
||||
params["when_not_matched_insert_all"] = str(
|
||||
merge._when_not_matched_insert_all
|
||||
).lower()
|
||||
params["when_not_matched_by_source_delete"] = str(
|
||||
merge._when_not_matched_by_source_delete
|
||||
).lower()
|
||||
if merge._when_not_matched_by_source_condition is not None:
|
||||
params[
|
||||
"when_not_matched_by_source_delete_filt"
|
||||
] = merge._when_not_matched_by_source_condition
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/merge_insert/",
|
||||
data=payload,
|
||||
params=params,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
def delete(self, predicate: str):
|
||||
"""Delete rows from the table.
|
||||
|
||||
@@ -390,6 +390,8 @@ class Table(ABC):
|
||||
2 3 y
|
||||
3 4 z
|
||||
"""
|
||||
on = [on] if isinstance(on, str) else list(on.iter())
|
||||
|
||||
return LanceMergeInsertBuilder(self, on)
|
||||
|
||||
@abstractmethod
|
||||
@@ -479,8 +481,8 @@ class Table(ABC):
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
*,
|
||||
schema: Optional[pa.Schema] = None,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -1305,7 +1307,20 @@ class LanceTable(Table):
|
||||
with_row_id=query.with_row_id,
|
||||
)
|
||||
|
||||
def _do_merge(self, merge: LanceMergeInsertBuilder, new_data: DATA, *, schema=None):
|
||||
def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
new_data = _sanitize_data(
|
||||
new_data,
|
||||
self.schema,
|
||||
metadata=self.schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
ds = self.to_lance()
|
||||
builder = ds.merge_insert(merge._on)
|
||||
if merge._when_matched_update_all:
|
||||
@@ -1315,7 +1330,7 @@ class LanceTable(Table):
|
||||
if merge._when_not_matched_by_source_delete:
|
||||
cond = merge._when_not_matched_by_source_condition
|
||||
builder.when_not_matched_by_source_delete(cond)
|
||||
builder.execute(new_data, schema=schema)
|
||||
builder.execute(new_data)
|
||||
|
||||
def cleanup_old_versions(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user