mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
@@ -104,4 +104,4 @@ class LanceMergeInsertBuilder(object):
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
|
||||
return self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
|
||||
|
||||
@@ -2464,7 +2464,31 @@ class AsyncTable:
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
pass
|
||||
schema = await self.schema()
|
||||
if on_bad_vectors is None:
|
||||
on_bad_vectors = "error"
|
||||
if fill_value is None:
|
||||
fill_value = 0.0
|
||||
data, _ = _sanitize_data(
|
||||
new_data,
|
||||
schema,
|
||||
metadata=schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
if isinstance(data, pa.Table):
|
||||
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
|
||||
await self._inner.execute_merge_insert(
|
||||
data,
|
||||
dict(
|
||||
on=merge._on,
|
||||
when_matched_update_all=merge._when_matched_update_all,
|
||||
when_matched_update_all_condition=merge._when_matched_update_all_condition,
|
||||
when_not_matched_insert_all=merge._when_not_matched_insert_all,
|
||||
when_not_matched_by_source_delete=merge._when_not_matched_by_source_delete,
|
||||
when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition,
|
||||
),
|
||||
)
|
||||
|
||||
async def delete(self, where: str):
|
||||
"""Delete rows from the table.
|
||||
|
||||
@@ -636,11 +636,13 @@ def test_merge_insert(db):
|
||||
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
|
||||
# replace-range
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete(
|
||||
"a > 2"
|
||||
).execute(new_data)
|
||||
(
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete("a > 2")
|
||||
.execute(new_data)
|
||||
)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
@@ -658,6 +660,75 @@ def test_merge_insert(db):
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_insert_async(db_async: AsyncConnection):
|
||||
data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
|
||||
table = await db_async.create_table("some_table", data=data)
|
||||
assert await table.count_rows() == 3
|
||||
version = await table.version()
|
||||
|
||||
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
|
||||
# upsert
|
||||
await (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(new_data)
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
await table.checkout(version)
|
||||
await table.restore()
|
||||
|
||||
# conditional update
|
||||
await (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all(where="target.b = 'b'")
|
||||
.execute(new_data)
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
await table.checkout(version)
|
||||
await table.restore()
|
||||
|
||||
# insert-if-not-exists
|
||||
await table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
await table.checkout(version)
|
||||
await table.restore()
|
||||
|
||||
# replace-range
|
||||
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
await (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete("a > 2")
|
||||
.execute(new_data)
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
await table.checkout(version)
|
||||
await table.restore()
|
||||
|
||||
# replace-range no condition
|
||||
await (
|
||||
table.merge_insert("a")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.when_not_matched_by_source_delete()
|
||||
.execute(new_data)
|
||||
)
|
||||
expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
assert (await table.to_arrow()).sort_by("a") == expected
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
|
||||
@@ -9,7 +9,7 @@ use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods,
|
||||
types::{PyDict, PyDictMethods, PyString},
|
||||
Bound, PyAny, PyRef, PyResult, Python, ToPyObject,
|
||||
Bound, FromPyObject, PyAny, PyRef, PyResult, Python, ToPyObject,
|
||||
};
|
||||
use pyo3_asyncio_0_21::tokio::future_into_py;
|
||||
|
||||
@@ -331,6 +331,31 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn execute_merge_insert<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
data: Bound<'a, PyAny>,
|
||||
parameters: MergeInsertParams,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let batches: ArrowArrayStreamReader = ArrowArrayStreamReader::from_pyarrow_bound(&data)?;
|
||||
let on = parameters.on.iter().map(|s| s.as_str()).collect::<Vec<_>>();
|
||||
let mut builder = self_.inner_ref()?.merge_insert(&on);
|
||||
if parameters.when_matched_update_all {
|
||||
builder.when_matched_update_all(parameters.when_matched_update_all_condition);
|
||||
}
|
||||
if parameters.when_not_matched_insert_all {
|
||||
builder.when_not_matched_insert_all();
|
||||
}
|
||||
if parameters.when_not_matched_by_source_delete {
|
||||
builder
|
||||
.when_not_matched_by_source_delete(parameters.when_not_matched_by_source_condition);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
builder.execute(Box::new(batches)).await.infer_error()?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn uses_v2_manifest_paths(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
|
||||
let inner = self_.inner_ref()?.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
@@ -355,3 +380,14 @@ impl Table {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromPyObject)]
|
||||
#[pyo3(from_item_all)]
|
||||
pub struct MergeInsertParams {
|
||||
on: Vec<String>,
|
||||
when_matched_update_all: bool,
|
||||
when_matched_update_all_condition: Option<String>,
|
||||
when_not_matched_insert_all: bool,
|
||||
when_not_matched_by_source_delete: bool,
|
||||
when_not_matched_by_source_condition: Option<String>,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user