Compare commits


1 Commit

Author: qzhu
SHA1: c3be2e3962
Message: small fix for the guides/table page
Date: 2024-02-01 14:41:00 -08:00
17 changed files with 24 additions and 716 deletions

View File

@@ -11,10 +11,10 @@ license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"
[workspace.dependencies]
lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.12" }
lance-linalg = { "version" = "=0.9.12" }
lance-testing = { "version" = "=0.9.12" }
lance = { "version" = "=0.9.10", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.10" }
lance-linalg = { "version" = "=0.9.10" }
lance-testing = { "version" = "=0.9.10" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"

View File

@@ -100,9 +100,7 @@ This guide will show how to create tables, insert data into them, and update the
db["my_table"].head()
```
!!! info "Note"
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
The **`vector`** column needs to be a [Vector](../python/pydantic.md#vector-field) (defined as [pyarrow.FixedSizeList](https://arrow.apache.org/docs/python/generated/pyarrow.list_.html)) type.
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
```python
custom_schema = pa.schema([

View File

@@ -37,7 +37,6 @@ const {
tableCountRows,
tableDelete,
tableUpdate,
tableMergeInsert,
tableCleanupOldVersions,
tableCompactFiles,
tableListIndices,
@@ -441,38 +440,6 @@ export interface Table<T = number[]> {
*/
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
/**
* Runs a "merge insert" operation on the table
*
* This operation can add rows, update rows, and remove rows all in a single
* transaction. It is a very generic tool that can be used to create
* behaviors like "insert if not exists", "update or insert (i.e. upsert)",
* or even replace a portion of existing data with new data (e.g. replace
* all data where month="january")
*
* The merge insert operation works by combining new data from a
* **source table** with existing data in a **target table** by using a
* join. There are three categories of records.
*
* "Matched" records are records that exist in both the source table and
* the target table. "Not matched" records exist only in the source table
* (e.g. these are new data) "Not matched by source" records exist only
* in the target table (this is old data)
*
* The MergeInsertArgs can be used to customize what should happen for
* each category of data.
*
* Please note that the data may appear to be reordered as part of this
* operation. This is because updated rows will be deleted from the
* dataset and then reinserted at the end with the new values.
*
* @param on a column to join on. This is how records from the source
* table and target table are matched.
* @param data the new data to insert
* @param args parameters controlling how the operation should behave
*/
mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>
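For context on the API being removed here, a minimal sketch of how this `mergeInsert` method was invoked from the node client, based on the doc comment above and on the test deleted further down in this diff. The package import and the toy `id`/`age` rows are illustrative assumptions, not something this commit pins down:

```typescript
// Sketch only: assumes the node package (published as 'vectordb' at the time)
// is imported as `lancedb`, the way the removed test uses it.
import * as lancedb from 'vectordb'

async function upsertExample (dir: string): Promise<void> {
  const con = await lancedb.connect(dir)
  const table = await con.createTable('my_table', [{ id: 1, age: 1 }, { id: 2, age: 1 }])

  // "Upsert": rows whose `id` matches an existing row replace it,
  // rows with an unseen `id` are inserted as new rows.
  await table.mergeInsert('id', [{ id: 2, age: 2 }, { id: 3, age: 2 }], {
    whenMatchedUpdateAll: true,
    whenNotMatchedInsertAll: true
  })
}
```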
/**
* List the indicies on this table.
*/
@@ -516,36 +483,6 @@ export interface UpdateSqlArgs {
valuesSql: Record<string, string>
}
export interface MergeInsertArgs {
/**
* If true then rows that exist in both the source table (new data) and
* the target table (old data) will be updated, replacing the old row
* with the corresponding matching row.
*
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*/
whenMatchedUpdateAll?: boolean
/**
* If true then rows that exist only in the source table (new data)
* will be inserted into the target table.
*/
whenNotMatchedInsertAll?: boolean
/**
* If true then rows that exist only in the target table (old data)
* will be deleted.
*
* If this is a string then it will be treated as an SQL filter and
* only rows that both do not match any row in the source table and
* match the given filter will be deleted.
*
* This can be used to replace a selection of existing data with
* new data.
*/
whenNotMatchedBySourceDelete?: string | boolean
}
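To make the `whenNotMatchedBySourceDelete` option concrete: when it is a string it is applied as an SQL filter to the old data, which lets a merge insert replace a slice of the table in one transaction. A hedged continuation of the sketch above (same hypothetical `table` and async context; the filter and row values mirror the test removed below):

```typescript
// Matching ids are updated, unseen ids are inserted, and any old row that
// has no match in the new batch *and* satisfies `age < 3` is deleted.
await table.mergeInsert('id', [{ id: 5, age: 4 }], {
  whenMatchedUpdateAll: true,
  whenNotMatchedInsertAll: true,
  whenNotMatchedBySourceDelete: 'age < 3'
})

// Passing `true` instead of a filter deletes every unmatched old row,
// effectively replacing the table contents with the new batch.
```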
export interface VectorIndex {
columns: string[]
name: string
@@ -884,38 +821,6 @@ export class LocalTable<T = number[]> implements Table<T> {
})
}
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
let whenNotMatchedBySourceDelete = false
let whenNotMatchedBySourceDeleteFilt = null
if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) {
whenNotMatchedBySourceDelete = true
if (args.whenNotMatchedBySourceDelete !== true) {
whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete
}
}
const schema = await this.schema
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, { schema })
}
const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)
this._tbl = await tableMergeInsert.call(
this._tbl,
on,
whenMatchedUpdateAll,
whenNotMatchedInsertAll,
whenNotMatchedBySourceDelete,
whenNotMatchedBySourceDeleteFilt,
buffer
)
}
/**
* Clean up old versions of the table, freeing disk space.
*

View File

@@ -24,8 +24,7 @@ import {
type IndexStats,
type UpdateArgs,
type UpdateSqlArgs,
makeArrowTable,
type MergeInsertArgs
makeArrowTable
} from '../index'
import { Query } from '../query'
@@ -275,52 +274,6 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error('Not implemented')
}
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, await this.schema)
}
const queryParams: any = {
on
}
if (args.whenMatchedUpdateAll ?? false) {
queryParams.when_matched_update_all = 'true'
} else {
queryParams.when_matched_update_all = 'false'
}
if (args.whenNotMatchedInsertAll ?? false) {
queryParams.when_not_matched_insert_all = 'true'
} else {
queryParams.when_not_matched_insert_all = 'false'
}
if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) {
queryParams.when_not_matched_by_source_delete = 'true'
if (typeof args.whenNotMatchedBySourceDelete === 'string') {
queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete
}
} else {
queryParams.when_not_matched_by_source_delete = 'false'
}
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/merge_insert/`,
buffer,
queryParams,
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
}
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {

View File

@@ -531,44 +531,6 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
})
it('can merge insert records into the table', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
const table = await con.createTable('my_table', data)
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true
})
assert.equal(await table.countRows(), 3)
assert.equal((await table.filter('age = 2').execute()).length, 1)
newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true
})
assert.equal(await table.countRows(), 4)
assert.equal((await table.filter('age = 3').execute()).length, 2)
newData = [{ id: 5, age: 4 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: 'age < 3'
})
assert.equal(await table.countRows(), 3)
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: true
})
assert.equal(await table.countRows(), 1)
})
it('can update records in the table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.5.2
current_version = 0.5.1
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -12,7 +12,7 @@
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Iterable, Optional
if TYPE_CHECKING:
from .common import DATA
@@ -25,7 +25,7 @@ class LanceMergeInsertBuilder(object):
more context
"""
def __init__(self, table: "Table", on: List[str]): # noqa: F821
def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821
# Do not put a docstring here. This method should be hidden
# from API docs. Users should use merge_insert to create
# this object.
@@ -77,27 +77,10 @@ class LanceMergeInsertBuilder(object):
self._when_not_matched_by_source_condition = condition
return self
def execute(
self,
new_data: DATA,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
def execute(self, new_data: DATA):
"""
Executes the merge insert operation
Nothing is returned but the [`Table`][lancedb.table.Table] is updated
Parameters
----------
new_data: DATA
New records which will be matched against the existing records
to potentially insert or update into the table. This parameter
can be anything you use for [`add`][lancedb.table.Table.add]
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
self._table._do_merge(self, new_data)

View File

@@ -13,8 +13,6 @@
import functools
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin
@@ -22,8 +20,6 @@ import attrs
import pyarrow as pa
import requests
from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
@@ -61,10 +57,6 @@ class RestfulLanceDBClient:
@functools.cached_property
def session(self) -> requests.Session:
sess = requests.Session()
retry_adapter_instance = retry_adapter(retry_adapter_options())
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
adapter_class = LanceDBClientHTTPAdapterFactory()
sess.mount("https://", adapter_class())
return sess
@@ -178,72 +170,3 @@ class RestfulLanceDBClient:
"""Query a table."""
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl)
def mount_retry_adapter_for_table(self, table_name: str) -> None:
"""
Adds an http adapter to session that will retry retryable requests to the table.
"""
retry_options = retry_adapter_options(methods=["GET", "POST"])
retry_adapter_instance = retry_adapter(retry_options)
session = self.session
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
retry_adapter_instance,
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
retry_adapter_instance,
)
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
return {
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
"backoff_factor": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
),
"backoff_jitter": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
),
"statuses": [
int(i.strip())
for i in os.environ.get(
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
).split(",")
],
"methods": methods,
}
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
total_retries = options["retries"]
connect_retries = options["connect_retries"]
read_retries = options["read_retries"]
backoff_factor = options["backoff_factor"]
backoff_jitter = options["backoff_jitter"]
statuses = options["statuses"]
methods = frozenset(options["methods"])
logging.debug(
f"Setting up retry adapter with {total_retries} retries," # noqa G003
+ f"connect retries {connect_retries}, read retries {read_retries},"
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
+ f"methods {methods}"
)
return HTTPAdapter(
max_retries=Retry(
total=total_retries,
connect=connect_retries,
read=read_retries,
backoff_factor=backoff_factor,
backoff_jitter=backoff_jitter,
status_forcelist=statuses,
allowed_methods=methods,
)
)

View File

@@ -95,8 +95,6 @@ class RemoteDBConnection(DBConnection):
"""
from .table import RemoteTable
self._client.mount_retry_adapter_for_table(name)
# check if table exists
try:
self._client.post(f"/v1/table/{name}/describe/")

View File

@@ -19,7 +19,6 @@ import pyarrow as pa
from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data
@@ -245,46 +244,9 @@ class RemoteTable(Table):
result = self._conn._client.query(self._name, query)
return result.to_arrow()
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
data = _sanitize_data(
new_data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params[
"when_not_matched_by_source_delete_filt"
] = merge._when_not_matched_by_source_condition
self._conn._client.post(
f"/v1/table/{self._name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
def _do_merge(self, *_args):
"""_do_merge() is not supported on the LanceDB cloud yet"""
return NotImplementedError("_do_merge() is not supported on the LanceDB cloud")
def delete(self, predicate: str):
"""Delete rows from the table.
@@ -397,18 +359,6 @@ class RemoteTable(Table):
payload = {"predicate": where, "updates": updates}
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
raise NotImplementedError(
"cleanup_old_versions() is not supported on the LanceDB cloud"
)
def compact_files(self, *_):
"""compact_files() is not supported on the LanceDB cloud"""
raise NotImplementedError(
"compact_files() is not supported on the LanceDB cloud"
)
def add_index(tbl: pa.Table, i: int) -> pa.Table:
return tbl.add_column(

View File

@@ -391,8 +391,6 @@ class Table(ABC):
2 3 y
3 4 z
"""
on = [on] if isinstance(on, str) else list(on.iter())
return LanceMergeInsertBuilder(self, on)
@abstractmethod
@@ -440,8 +438,6 @@ class Table(ABC):
the table
vector_column_name: str
The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
*default "vector"*
query_type: str
*default "auto"*.
@@ -482,8 +478,8 @@ class Table(ABC):
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
*,
schema: Optional[pa.Schema] = None,
):
pass
@@ -594,52 +590,6 @@ class Table(ABC):
"""
raise NotImplementedError
@abstractmethod
def cleanup_old_versions(
self,
older_than: Optional[timedelta] = None,
*,
delete_unverified: bool = False,
) -> CleanupStats:
"""
Clean up old versions of the table, freeing disk space.
Note: This function is not available in LanceDb Cloud (since LanceDb
Cloud manages cleanup for you automatically)
Parameters
----------
older_than: timedelta, default None
The minimum age of the version to delete. If None, then this defaults
to two weeks.
delete_unverified: bool, default False
Because they may be part of an in-progress transaction, files newer
than 7 days old are not deleted by default. If you are sure that
there are no in-progress transactions, then you can set this to True
to delete all files older than `older_than`.
Returns
-------
CleanupStats
The stats of the cleanup operation, including how many bytes were
freed.
"""
@abstractmethod
def compact_files(self, *args, **kwargs):
"""
Run the compaction process on the table.
Note: This function is not available in LanceDb Cloud (since LanceDb
Cloud manages compaction for you automatically)
This can be run after making several small appends to optimize the table
for faster reads.
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
For most cases, the default should be fine.
"""
class LanceTable(Table):
"""
@@ -1315,20 +1265,7 @@ class LanceTable(Table):
with_row_id=query.with_row_id,
)
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
new_data = _sanitize_data(
new_data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
def _do_merge(self, merge: LanceMergeInsertBuilder, new_data: DATA, *, schema=None):
ds = self.to_lance()
builder = ds.merge_insert(merge._on)
if merge._when_matched_update_all:
@@ -1338,7 +1275,7 @@ class LanceTable(Table):
if merge._when_not_matched_by_source_delete:
cond = merge._when_not_matched_by_source_condition
builder.when_not_matched_by_source_delete(cond)
builder.execute(new_data)
builder.execute(new_data, schema=schema)
def cleanup_old_versions(
self,
@@ -1377,9 +1314,8 @@ class LanceTable(Table):
This can be run after making several small appends to optimize the table
for faster reads.
Arguments are passed onto `lance.dataset.DatasetOptimizer.compact_files`.
(see Lance documentation for more details) For most cases, the default
should be fine.
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
For most cases, the default should be fine.
"""
return self.to_lance().optimize.compact_files(*args, **kwargs)
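The `cleanup_old_versions` and `compact_files` docstrings above describe the same maintenance operations the node bindings in this diff expose via `tableCleanupOldVersions` and `tableCompactFiles`. A rough sketch of how they are typically chained, kept in TypeScript for consistency with the earlier examples; the method names on the table handle and their no-argument defaults are assumptions drawn from those bindings, not something this diff spells out:

```typescript
// `table` is the handle returned by createTable/openTable in the earlier
// sketch; the structural type below just states the assumed methods.
async function maintainTable (
  table: { compactFiles: () => Promise<unknown>, cleanupOldVersions: () => Promise<unknown> }
): Promise<void> {
  // Rewrite many small data files into fewer, larger ones for faster reads.
  await table.compactFiles()
  // Then drop table versions that are no longer referenced, freeing disk space.
  await table.cleanupOldVersions()
}
```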

View File

@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.5.2"
version = "0.5.1"
dependencies = [
"deprecation",
"pylance==0.9.12",
"pylance==0.9.11",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",

View File

@@ -29,9 +29,6 @@ class FakeLanceDBClient:
def post(self, path: str):
pass
def mount_retry_adapter_for_table(self, table_name: str):
pass
def test_remote_db():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")

View File

@@ -260,7 +260,6 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
cx.export_function("tableDelete", JsTable::js_delete)?;
cx.export_function("tableUpdate", JsTable::js_update)?;
cx.export_function("tableMergeInsert", JsTable::js_merge_insert)?;
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
cx.export_function("tableListIndices", JsTable::js_list_indices)?;

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::optimize::CompactionOptions;
use lance::dataset::{WriteMode, WriteParams};
@@ -168,53 +166,6 @@ impl JsTable {
Ok(promise)
}
pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let table = js_table.table.clone();
let key = cx.argument::<JsString>(0)?.value(&mut cx);
let mut builder = table.merge_insert(&[&key]);
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
builder.when_matched_update_all();
}
if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
builder.when_not_matched_insert_all();
}
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
if let Some(filter) = cx.argument_opt(4) {
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_not_matched_by_source_delete(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_not_matched_by_source_delete(Some(filter));
}
} else {
builder.when_not_matched_by_source_delete(None);
}
}
let buffer = cx.argument::<JsBuffer>(5)?;
let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
rt.spawn(async move {
let new_data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let merge_insert_result = builder.execute(Box::new(new_data)).await;
deferred.settle_with(&channel, move |mut cx| {
merge_insert_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
})
});
Ok(promise)
}
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let table = js_table.table.clone();

View File

@@ -19,7 +19,6 @@ use std::sync::{Arc, Mutex};
use arrow_array::RecordBatchReader;
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use chrono::Duration;
use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats;
@@ -28,7 +27,6 @@ use lance::dataset::optimize::{
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
@@ -40,10 +38,6 @@ use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;
use self::merge::{MergeInsert, MergeInsertBuilder};
pub mod merge;
/// Optimize the dataset.
///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to
@@ -176,71 +170,6 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// ```
fn create_index(&self, column: &[&str]) -> IndexBuilder;
/// Create a builder for a merge insert operation
///
/// This operation can add rows, update rows, and remove rows all in a single
/// transaction. It is a very generic tool that can be used to create
/// behaviors like "insert if not exists", "update or insert (i.e. upsert)",
/// or even replace a portion of existing data with new data (e.g. replace
/// all data where month="january")
///
/// The merge insert operation works by combining new data from a
/// **source table** with existing data in a **target table** by using a
/// join. There are three categories of records.
///
/// "Matched" records are records that exist in both the source table and
/// the target table. "Not matched" records exist only in the source table
/// (e.g. these are new data) "Not matched by source" records exist only
/// in the target table (this is old data)
///
/// The builder returned by this method can be used to customize what
/// should happen for each category of data.
///
/// Please note that the data may appear to be reordered as part of this
/// operation. This is because updated rows will be deleted from the
/// dataset and then reinserted at the end with the new values.
///
/// # Arguments
///
/// * `on` One or more columns to join on. This is how records from the
/// source table and target table are matched. Typically this is some
/// kind of key or id column.
///
/// # Examples
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("idx_test").await.unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let new_data = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// // Perform an upsert operation
/// let mut merge_insert = tbl.merge_insert(&["id"]);
/// merge_insert.when_matched_update_all()
/// .when_not_matched_insert_all();
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
/// # });
/// ```
fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder;
/// Search the table with a given query vector.
///
/// This is a convenience method for preparing an ANN query.
@@ -664,42 +593,6 @@ impl NativeTable {
}
}
#[async_trait]
impl MergeInsert for NativeTable {
async fn do_merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let dataset = Arc::new(self.clone_inner_dataset());
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
if params.when_matched_update_all {
builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
} else {
builder.when_matched(lance::dataset::WhenMatched::DoNothing);
}
if params.when_not_matched_insert_all {
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
} else {
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
}
if params.when_not_matched_by_source_delete {
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
} else {
WhenNotMatchedBySource::Delete
};
builder.when_not_matched_by_source(behavior);
} else {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
}
let job = builder.try_build()?;
let new_dataset = job.execute_reader(new_data).await?;
self.reset_dataset((*new_dataset).clone());
Ok(())
}
}
#[async_trait::async_trait]
impl Table for NativeTable {
fn as_any(&self) -> &dyn std::any::Any {
@@ -744,11 +637,6 @@ impl Table for NativeTable {
Ok(())
}
fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder {
let on = Vec::from_iter(on.iter().map(|key| key.to_string()));
MergeInsertBuilder::new(Arc::new(self.clone()), on)
}
fn create_index(&self, columns: &[&str]) -> IndexBuilder {
IndexBuilder::new(Arc::new(self.clone()), columns)
}
@@ -914,38 +802,6 @@ mod tests {
assert_eq!(table.name, "test");
}
#[tokio::test]
async fn test_merge_insert() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
// Create a dataset with i=0..10
let batches = make_test_batches_with_offset(0);
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
// Create new data with i=5..15
let new_batches = Box::new(make_test_batches_with_offset(5));
// Perform a "insert if not exists"
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_not_matched_insert_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// Only 5 rows should actually be inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
// Create new data with i=15..25 (no id matches)
let new_batches = Box::new(make_test_batches_with_offset(15));
// Perform a "bulk update" (should not affect anything)
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// No new rows should have been inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
}
#[tokio::test]
async fn test_add_overwrite() {
let tmp_dir = tempdir().unwrap();
@@ -1292,25 +1148,17 @@ mod tests {
assert!(wrapper.called());
}
fn make_test_batches_with_offset(
offset: i32,
) -> impl RecordBatchReader + Send + Sync + 'static {
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(
offset..(offset + 10),
))],
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)],
schema,
)
}
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
make_test_batches_with_offset(0)
}
#[tokio::test]
async fn test_create_index() {
use arrow_array::RecordBatch;

View File

@@ -1,95 +0,0 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
use crate::Result;
#[async_trait]
pub(super) trait MergeInsert: Send + Sync {
async fn do_merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()>;
}
/// A builder used to create and run a merge insert operation
///
/// See [`super::Table::merge_insert`] for more context
pub struct MergeInsertBuilder {
table: Arc<dyn MergeInsert>,
pub(super) on: Vec<String>,
pub(super) when_matched_update_all: bool,
pub(super) when_not_matched_insert_all: bool,
pub(super) when_not_matched_by_source_delete: bool,
pub(super) when_not_matched_by_source_delete_filt: Option<String>,
}
impl MergeInsertBuilder {
pub(super) fn new(table: Arc<dyn MergeInsert>, on: Vec<String>) -> Self {
Self {
table,
on,
when_matched_update_all: false,
when_not_matched_insert_all: false,
when_not_matched_by_source_delete: false,
when_not_matched_by_source_delete_filt: None,
}
}
/// Rows that exist in both the source table (new data) and
/// the target table (old data) will be updated, replacing
/// the old row with the corresponding matching row.
///
/// If there are multiple matches then the behavior is undefined.
/// Currently this causes multiple copies of the row to be created
/// but that behavior is subject to change.
pub fn when_matched_update_all(&mut self) -> &mut Self {
self.when_matched_update_all = true;
self
}
/// Rows that exist only in the source table (new data) should
/// be inserted into the target table.
pub fn when_not_matched_insert_all(&mut self) -> &mut Self {
self.when_not_matched_insert_all = true;
self
}
/// Rows that exist only in the target table (old data) will be
/// deleted. An optional condition can be provided to limit what
/// data is deleted.
///
/// # Arguments
///
/// * `condition` - If None then all such rows will be deleted.
/// Otherwise the condition will be used as an SQL filter to
/// limit what rows are deleted.
pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
self.when_not_matched_by_source_delete = true;
self.when_not_matched_by_source_delete_filt = filter;
self
}
/// Executes the merge insert operation
///
/// Nothing is returned but the [`super::Table`] is updated
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
self.table.clone().do_merge_insert(self, new_data).await
}
}