mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00

Compare commits: small-doc-... → python-v0. (6 commits)
| SHA1 |
|---|
| ce2242e06d |
| 778339388a |
| 7f8637a0b4 |
| 09cd08222d |
| a248d7feec |
| cc9473a94a |
@@ -11,10 +11,10 @@ license = "Apache-2.0"
 repository = "https://github.com/lancedb/lancedb"
 
 [workspace.dependencies]
-lance = { "version" = "=0.9.10", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.9.10" }
-lance-linalg = { "version" = "=0.9.10" }
-lance-testing = { "version" = "=0.9.10" }
+lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.9.12" }
+lance-linalg = { "version" = "=0.9.12" }
+lance-testing = { "version" = "=0.9.12" }
 # Note that this one does not include pyarrow
 arrow = { version = "50.0", optional = false }
 arrow-array = "50.0"

@@ -84,7 +84,7 @@ This guide will show how to create tables, insert data into them, and update the
 const table = await con.createTable(tableName, data, { writeMode: WriteMode.Overwrite })
 ```
 
 ### From a Pandas DataFrame
 
 ```python
 import pandas as pd

@@ -100,7 +100,9 @@ This guide will show how to create tables, insert data into them, and update the
 db["my_table"].head()
 ```
 !!! info "Note"
     Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
 
+    The **`vector`** column needs to be a [Vector](../python/pydantic.md#vector-field) (defined as [pyarrow.FixedSizeList](https://arrow.apache.org/docs/python/generated/pyarrow.list_.html)) type.
+
 ```python
 custom_schema = pa.schema([

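The snippet above is cut off before the schema fields appear. As a hedged sketch of the pattern the note describes (the `vector`/`item` field names and the list size are illustrative, not recovered from the original), an explicit PyArrow schema with a fixed-size-list vector column can be passed straight to `create_table`:

```python
import lancedb
import pyarrow as pa

# pa.list_ with an explicit length yields a FixedSizeList type,
# which is what LanceDB expects for the vector column.
custom_schema = pa.schema([
    pa.field("vector", pa.list_(pa.float32(), 2)),
    pa.field("item", pa.string()),
])

db = lancedb.connect("./.lancedb")
data = [
    {"vector": [1.1, 1.2], "item": "foo"},
    {"vector": [0.2, 1.8], "item": "bar"},
]
tbl = db.create_table("my_table", data, schema=custom_schema)
```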
@@ -37,6 +37,7 @@ const {
   tableCountRows,
   tableDelete,
   tableUpdate,
+  tableMergeInsert,
   tableCleanupOldVersions,
   tableCompactFiles,
   tableListIndices,

@@ -440,6 +441,38 @@ export interface Table<T = number[]> {
    */
   update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
 
+  /**
+   * Runs a "merge insert" operation on the table
+   *
+   * This operation can add rows, update rows, and remove rows all in a single
+   * transaction. It is a very generic tool that can be used to create
+   * behaviors like "insert if not exists", "update or insert (i.e. upsert)",
+   * or even replace a portion of existing data with new data (e.g. replace
+   * all data where month="january")
+   *
+   * The merge insert operation works by combining new data from a
+   * **source table** with existing data in a **target table** by using a
+   * join. There are three categories of records.
+   *
+   * "Matched" records are records that exist in both the source table and
+   * the target table. "Not matched" records exist only in the source table
+   * (e.g. these are new data). "Not matched by source" records exist only
+   * in the target table (this is old data).
+   *
+   * The MergeInsertArgs can be used to customize what should happen for
+   * each category of data.
+   *
+   * Please note that the data may appear to be reordered as part of this
+   * operation. This is because updated rows will be deleted from the
+   * dataset and then reinserted at the end with the new values.
+   *
+   * @param on a column to join on. This is how records from the source
+   *           table and target table are matched.
+   * @param data the new data to insert
+   * @param args parameters controlling how the operation should behave
+   */
+  mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>
+
   /**
    * List the indicies on this table.
    */

@@ -483,6 +516,36 @@ export interface UpdateSqlArgs {
   valuesSql: Record<string, string>
 }
 
+export interface MergeInsertArgs {
+  /**
+   * If true then rows that exist in both the source table (new data) and
+   * the target table (old data) will be updated, replacing the old row
+   * with the corresponding matching row.
+   *
+   * If there are multiple matches then the behavior is undefined.
+   * Currently this causes multiple copies of the row to be created
+   * but that behavior is subject to change.
+   */
+  whenMatchedUpdateAll?: boolean
+  /**
+   * If true then rows that exist only in the source table (new data)
+   * will be inserted into the target table.
+   */
+  whenNotMatchedInsertAll?: boolean
+  /**
+   * If true then rows that exist only in the target table (old data)
+   * will be deleted.
+   *
+   * If this is a string then it will be treated as an SQL filter and
+   * only rows that both do not match any row in the source table and
+   * match the given filter will be deleted.
+   *
+   * This can be used to replace a selection of existing data with
+   * new data.
+   */
+  whenNotMatchedBySourceDelete?: string | boolean
+}
+
 export interface VectorIndex {
   columns: string[]
   name: string

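For readers tracking the Python side of this comparison, the three `MergeInsertArgs` fields map one-to-one onto the `LanceMergeInsertBuilder` methods added further down. A minimal sketch of an upsert that also prunes rows missing from the new data (table and column names are illustrative):

```python
import lancedb

db = lancedb.connect("./.lancedb")
tbl = db.create_table("users", [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}])

new_users = [{"id": 2, "name": "bobby"}, {"id": 3, "name": "carol"}]

# whenMatchedUpdateAll / whenNotMatchedInsertAll / whenNotMatchedBySourceDelete
# correspond to these three builder calls:
(
    tbl.merge_insert("id")
    .when_matched_update_all()            # update rows whose "id" matches
    .when_not_matched_insert_all()        # insert rows only in new_users
    .when_not_matched_by_source_delete()  # delete rows absent from new_users
    .execute(new_users)
)
```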
@@ -821,6 +884,38 @@ export class LocalTable<T = number[]> implements Table<T> {
     })
   }
 
+  async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
+    const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
+    const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
+    let whenNotMatchedBySourceDelete = false
+    let whenNotMatchedBySourceDeleteFilt = null
+    if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) {
+      whenNotMatchedBySourceDelete = true
+      if (args.whenNotMatchedBySourceDelete !== true) {
+        whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete
+      }
+    }
+
+    const schema = await this.schema
+    let tbl: ArrowTable
+    if (data instanceof ArrowTable) {
+      tbl = data
+    } else {
+      tbl = makeArrowTable(data, { schema })
+    }
+    const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)
+
+    this._tbl = await tableMergeInsert.call(
+      this._tbl,
+      on,
+      whenMatchedUpdateAll,
+      whenNotMatchedInsertAll,
+      whenNotMatchedBySourceDelete,
+      whenNotMatchedBySourceDeleteFilt,
+      buffer
+    )
+  }
+
   /**
    * Clean up old versions of the table, freeing disk space.
    *

@@ -24,7 +24,8 @@ import {
   type IndexStats,
   type UpdateArgs,
   type UpdateSqlArgs,
-  makeArrowTable
+  makeArrowTable,
+  type MergeInsertArgs
 } from '../index'
 import { Query } from '../query'
 

@@ -274,6 +275,52 @@ export class RemoteTable<T = number[]> implements Table<T> {
     throw new Error('Not implemented')
   }
 
+  async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
+    let tbl: ArrowTable
+    if (data instanceof ArrowTable) {
+      tbl = data
+    } else {
+      tbl = makeArrowTable(data, await this.schema)
+    }
+
+    const queryParams: any = {
+      on
+    }
+    if (args.whenMatchedUpdateAll ?? false) {
+      queryParams.when_matched_update_all = 'true'
+    } else {
+      queryParams.when_matched_update_all = 'false'
+    }
+    if (args.whenNotMatchedInsertAll ?? false) {
+      queryParams.when_not_matched_insert_all = 'true'
+    } else {
+      queryParams.when_not_matched_insert_all = 'false'
+    }
+    if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) {
+      queryParams.when_not_matched_by_source_delete = 'true'
+      if (typeof args.whenNotMatchedBySourceDelete === 'string') {
+        queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete
+      }
+    } else {
+      queryParams.when_not_matched_by_source_delete = 'false'
+    }
+
+    const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
+    const res = await this._client.post(
+      `/v1/table/${this._name}/merge_insert/`,
+      buffer,
+      queryParams,
+      'application/vnd.apache.arrow.stream'
+    )
+    if (res.status !== 200) {
+      throw new Error(
+        `Server Error, status: ${res.status}, ` +
+        // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
+        `message: ${res.statusText}: ${res.data}`
+      )
+    }
+  }
+
   async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
     let tbl: ArrowTable
     if (data instanceof ArrowTable) {

@@ -531,6 +531,44 @@ describe('LanceDB client', function () {
       assert.equal(await table.countRows(), 2)
     })
 
+    it('can merge insert records into the table', async function () {
+      const dir = await track().mkdir('lancejs')
+      const con = await lancedb.connect(dir)
+
+      const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
+      const table = await con.createTable('my_table', data)
+
+      let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
+      await table.mergeInsert('id', newData, {
+        whenNotMatchedInsertAll: true
+      })
+      assert.equal(await table.countRows(), 3)
+      assert.equal((await table.filter('age = 2').execute()).length, 1)
+
+      newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
+      await table.mergeInsert('id', newData, {
+        whenNotMatchedInsertAll: true,
+        whenMatchedUpdateAll: true
+      })
+      assert.equal(await table.countRows(), 4)
+      assert.equal((await table.filter('age = 3').execute()).length, 2)
+
+      newData = [{ id: 5, age: 4 }]
+      await table.mergeInsert('id', newData, {
+        whenNotMatchedInsertAll: true,
+        whenMatchedUpdateAll: true,
+        whenNotMatchedBySourceDelete: 'age < 3'
+      })
+      assert.equal(await table.countRows(), 3)
+
+      await table.mergeInsert('id', newData, {
+        whenNotMatchedInsertAll: true,
+        whenMatchedUpdateAll: true,
+        whenNotMatchedBySourceDelete: true
+      })
+      assert.equal(await table.countRows(), 1)
+    })
+
     it('can update records in the table', async function () {
       const uri = await createTestDB()
       const con = await lancedb.connect(uri)

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.5.1
+current_version = 0.5.2
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True

@@ -12,7 +12,7 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Iterable, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 if TYPE_CHECKING:
     from .common import DATA

@@ -25,7 +25,7 @@ class LanceMergeInsertBuilder(object):
     more context
     """
 
-    def __init__(self, table: "Table", on: Iterable[str]):  # noqa: F821
+    def __init__(self, table: "Table", on: List[str]):  # noqa: F821
        # Do not put a docstring here. This method should be hidden
        # from API docs. Users should use merge_insert to create
        # this object.

@@ -77,10 +77,27 @@ class LanceMergeInsertBuilder(object):
         self._when_not_matched_by_source_condition = condition
         return self
 
-    def execute(self, new_data: DATA):
+    def execute(
+        self,
+        new_data: DATA,
+        on_bad_vectors: str = "error",
+        fill_value: float = 0.0,
+    ):
         """
         Executes the merge insert operation
 
         Nothing is returned but the [`Table`][lancedb.table.Table] is updated
+
+        Parameters
+        ----------
+        new_data: DATA
+            New records which will be matched against the existing records
+            to potentially insert or update into the table. This parameter
+            can be anything you use for [`add`][lancedb.table.Table.add]
+        on_bad_vectors: str, default "error"
+            What to do if any of the vectors are not the same size or contain NaNs.
+            One of "error", "drop", "fill".
+        fill_value: float, default 0.
+            The value to use when filling vectors. Only used if on_bad_vectors="fill".
         """
-        self._table._do_merge(self, new_data)
+        self._table._do_merge(self, new_data, on_bad_vectors, fill_value)

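A hedged usage sketch of the widened `execute()` signature (the table and records are illustrative): a deliberately short vector is routed through the new `on_bad_vectors`/`fill_value` handling instead of raising.

```python
import lancedb

db = lancedb.connect("./.lancedb")
tbl = db.create_table(
    "items",
    [{"id": 1, "vector": [1.0, 2.0]}, {"id": 2, "vector": [3.0, 4.0]}],
)

# The id=3 row has a wrong-sized vector; with on_bad_vectors="fill" it
# is replaced using fill_value rather than erroring out mid-merge.
new_data = [{"id": 2, "vector": [9.0, 9.0]}, {"id": 3, "vector": [5.0]}]
(
    tbl.merge_insert("id")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(new_data, on_bad_vectors="fill", fill_value=0.0)
)
```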
@@ -13,6 +13,8 @@
 
 
 import functools
+import logging
+import os
 from typing import Any, Callable, Dict, List, Optional, Union
 from urllib.parse import urljoin
 

@@ -20,6 +22,8 @@ import attrs
 import pyarrow as pa
 import requests
 from pydantic import BaseModel
+from requests.adapters import HTTPAdapter
+from urllib3 import Retry
 
 from lancedb.common import Credential
 from lancedb.remote import VectorQuery, VectorQueryResult

@@ -57,6 +61,10 @@ class RestfulLanceDBClient:
     @functools.cached_property
     def session(self) -> requests.Session:
         sess = requests.Session()
+
+        retry_adapter_instance = retry_adapter(retry_adapter_options())
+        sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
+
         adapter_class = LanceDBClientHTTPAdapterFactory()
         sess.mount("https://", adapter_class())
         return sess

@@ -170,3 +178,72 @@ class RestfulLanceDBClient:
         """Query a table."""
         tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
         return VectorQueryResult(tbl)
+
+    def mount_retry_adapter_for_table(self, table_name: str) -> None:
+        """
+        Adds an http adapter to session that will retry retryable requests to the table.
+        """
+        retry_options = retry_adapter_options(methods=["GET", "POST"])
+        retry_adapter_instance = retry_adapter(retry_options)
+        session = self.session
+
+        session.mount(
+            urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
+        )
+        session.mount(
+            urljoin(self.url, f"/v1/table/{table_name}/describe/"),
+            retry_adapter_instance,
+        )
+        session.mount(
+            urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
+            retry_adapter_instance,
+        )
+
+
+def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
+    return {
+        "retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
+        "connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
+        "read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
+        "backoff_factor": float(
+            os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
+        ),
+        "backoff_jitter": float(
+            os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
+        ),
+        "statuses": [
+            int(i.strip())
+            for i in os.environ.get(
+                "LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
+            ).split(",")
+        ],
+        "methods": methods,
+    }
+
+
+def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
+    total_retries = options["retries"]
+    connect_retries = options["connect_retries"]
+    read_retries = options["read_retries"]
+    backoff_factor = options["backoff_factor"]
+    backoff_jitter = options["backoff_jitter"]
+    statuses = options["statuses"]
+    methods = frozenset(options["methods"])
+    logging.debug(
+        f"Setting up retry adapter with {total_retries} retries,"  # noqa G003
+        + f"connect retries {connect_retries}, read retries {read_retries},"
+        + f"backoff factor {backoff_factor}, statuses {statuses}, "
+        + f"methods {methods}"
+    )
+
+    return HTTPAdapter(
+        max_retries=Retry(
+            total=total_retries,
+            connect=connect_retries,
+            read=read_retries,
+            backoff_factor=backoff_factor,
+            backoff_jitter=backoff_jitter,
+            status_forcelist=statuses,
+            allowed_methods=methods,
+        )
+    )

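Every knob in `retry_adapter_options` falls back to an environment variable, so the remote client's retry policy can be tuned without touching code. A minimal sketch, assuming a LanceDB Cloud URI and API key (both hypothetical here):

```python
import os

import lancedb

# Variables read by retry_adapter_options(); set them before connecting.
os.environ["LANCE_CLIENT_MAX_RETRIES"] = "5"
os.environ["LANCE_CLIENT_RETRY_BACKOFF_FACTOR"] = "0.5"
os.environ["LANCE_CLIENT_RETRY_STATUSES"] = "429, 500, 502, 503, 504"

db = lancedb.connect("db://my-project", api_key="sk-...")  # hypothetical
tbl = db.open_table("users")  # query/describe/index-list calls now retry
```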
@@ -95,6 +95,8 @@ class RemoteDBConnection(DBConnection):
         """
         from .table import RemoteTable
 
+        self._client.mount_retry_adapter_for_table(name)
+
         # check if table exists
         try:
             self._client.post(f"/v1/table/{name}/describe/")

@@ -19,6 +19,7 @@ import pyarrow as pa
 from lance import json_to_schema
 
 from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
+from lancedb.merge import LanceMergeInsertBuilder
 
 from ..query import LanceVectorQueryBuilder
 from ..table import Query, Table, _sanitize_data

@@ -244,9 +245,46 @@ class RemoteTable(Table):
         result = self._conn._client.query(self._name, query)
         return result.to_arrow()
 
-    def _do_merge(self, *_args):
-        """_do_merge() is not supported on the LanceDB cloud yet"""
-        return NotImplementedError("_do_merge() is not supported on the LanceDB cloud")
+    def _do_merge(
+        self,
+        merge: LanceMergeInsertBuilder,
+        new_data: DATA,
+        on_bad_vectors: str,
+        fill_value: float,
+    ):
+        data = _sanitize_data(
+            new_data,
+            self.schema,
+            metadata=None,
+            on_bad_vectors=on_bad_vectors,
+            fill_value=fill_value,
+        )
+        payload = to_ipc_binary(data)
+
+        params = {}
+        if len(merge._on) != 1:
+            raise ValueError(
+                "RemoteTable only supports a single on key in merge_insert"
+            )
+        params["on"] = merge._on[0]
+        params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
+        params["when_not_matched_insert_all"] = str(
+            merge._when_not_matched_insert_all
+        ).lower()
+        params["when_not_matched_by_source_delete"] = str(
+            merge._when_not_matched_by_source_delete
+        ).lower()
+        if merge._when_not_matched_by_source_condition is not None:
+            params[
+                "when_not_matched_by_source_delete_filt"
+            ] = merge._when_not_matched_by_source_condition
+
+        self._conn._client.post(
+            f"/v1/table/{self._name}/merge_insert/",
+            data=payload,
+            params=params,
+            content_type=ARROW_STREAM_CONTENT_TYPE,
+        )
+
     def delete(self, predicate: str):
         """Delete rows from the table.

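Note the tighter contract on the remote path: unlike the local table, `_do_merge` here rejects composite join keys. A hedged sketch of what succeeds and what raises (cloud URI, key, and table are illustrative):

```python
import lancedb

db = lancedb.connect("db://my-project", api_key="sk-...")  # hypothetical
tbl = db.open_table("users")

new_data = [{"id": 1, "name": "alice"}]

# Fine: exactly one join key.
tbl.merge_insert("id").when_matched_update_all().execute(new_data)

# Raises ValueError on a RemoteTable: only a single "on" key is supported.
tbl.merge_insert(["id", "name"]).when_matched_update_all().execute(new_data)
```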
@@ -359,6 +397,18 @@ class RemoteTable(Table):
         payload = {"predicate": where, "updates": updates}
         self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
 
+    def cleanup_old_versions(self, *_):
+        """cleanup_old_versions() is not supported on the LanceDB cloud"""
+        raise NotImplementedError(
+            "cleanup_old_versions() is not supported on the LanceDB cloud"
+        )
+
+    def compact_files(self, *_):
+        """compact_files() is not supported on the LanceDB cloud"""
+        raise NotImplementedError(
+            "compact_files() is not supported on the LanceDB cloud"
+        )
+
 
 def add_index(tbl: pa.Table, i: int) -> pa.Table:
     return tbl.add_column(

@@ -391,6 +391,8 @@ class Table(ABC):
         2  3  y
         3  4  z
         """
+        on = [on] if isinstance(on, str) else list(on.iter())
+
         return LanceMergeInsertBuilder(self, on)
 
     @abstractmethod

@@ -438,6 +440,8 @@ class Table(ABC):
             the table
         vector_column_name: str
             The name of the vector column to search.
+
+            The vector column needs to be a pyarrow fixed size list type
             *default "vector"*
         query_type: str
             *default "auto"*.

@@ -478,8 +482,8 @@ class Table(ABC):
         self,
         merge: LanceMergeInsertBuilder,
         new_data: DATA,
-        *,
-        schema: Optional[pa.Schema] = None,
+        on_bad_vectors: str,
+        fill_value: float,
     ):
         pass
 

@@ -590,6 +594,52 @@ class Table(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def cleanup_old_versions(
+        self,
+        older_than: Optional[timedelta] = None,
+        *,
+        delete_unverified: bool = False,
+    ) -> CleanupStats:
+        """
+        Clean up old versions of the table, freeing disk space.
+
+        Note: This function is not available in LanceDb Cloud (since LanceDb
+        Cloud manages cleanup for you automatically)
+
+        Parameters
+        ----------
+        older_than: timedelta, default None
+            The minimum age of the version to delete. If None, then this defaults
+            to two weeks.
+        delete_unverified: bool, default False
+            Because they may be part of an in-progress transaction, files newer
+            than 7 days old are not deleted by default. If you are sure that
+            there are no in-progress transactions, then you can set this to True
+            to delete all files older than `older_than`.
+
+        Returns
+        -------
+        CleanupStats
+            The stats of the cleanup operation, including how many bytes were
+            freed.
+        """
+
+    @abstractmethod
+    def compact_files(self, *args, **kwargs):
+        """
+        Run the compaction process on the table.
+
+        Note: This function is not available in LanceDb Cloud (since LanceDb
+        Cloud manages compaction for you automatically)
+
+        This can be run after making several small appends to optimize the table
+        for faster reads.
+
+        Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
+        For most cases, the default should be fine.
+        """
+
+
 class LanceTable(Table):
     """

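A hedged sketch of the maintenance routine these two abstract methods describe, run against a local OSS table where both are implemented (table name is illustrative):

```python
from datetime import timedelta

import lancedb

db = lancedb.connect("./.lancedb")
tbl = db.open_table("users")

# Merge the small fragments left behind by many small appends.
tbl.compact_files()

# Then reclaim disk from versions older than a week; delete_unverified
# stays False so files possibly tied to in-flight transactions survive.
stats = tbl.cleanup_old_versions(older_than=timedelta(days=7))
print(stats)  # CleanupStats, including bytes freed
```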
@@ -1265,7 +1315,20 @@ class LanceTable(Table):
             with_row_id=query.with_row_id,
         )
 
-    def _do_merge(self, merge: LanceMergeInsertBuilder, new_data: DATA, *, schema=None):
+    def _do_merge(
+        self,
+        merge: LanceMergeInsertBuilder,
+        new_data: DATA,
+        on_bad_vectors: str,
+        fill_value: float,
+    ):
+        new_data = _sanitize_data(
+            new_data,
+            self.schema,
+            metadata=self.schema.metadata,
+            on_bad_vectors=on_bad_vectors,
+            fill_value=fill_value,
+        )
         ds = self.to_lance()
         builder = ds.merge_insert(merge._on)
         if merge._when_matched_update_all:

@@ -1275,7 +1338,7 @@ class LanceTable(Table):
         if merge._when_not_matched_by_source_delete:
             cond = merge._when_not_matched_by_source_condition
             builder.when_not_matched_by_source_delete(cond)
-        builder.execute(new_data, schema=schema)
+        builder.execute(new_data)
 
     def cleanup_old_versions(
         self,

@@ -1314,8 +1377,9 @@ class LanceTable(Table):
         This can be run after making several small appends to optimize the table
         for faster reads.
 
-        Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
-        For most cases, the default should be fine.
+        Arguments are passed onto `lance.dataset.DatasetOptimizer.compact_files`.
+        (see Lance documentation for more details) For most cases, the default
+        should be fine.
         """
         return self.to_lance().optimize.compact_files(*args, **kwargs)

@@ -1,9 +1,9 @@
 [project]
 name = "lancedb"
-version = "0.5.1"
+version = "0.5.2"
 dependencies = [
     "deprecation",
-    "pylance==0.9.11",
+    "pylance==0.9.12",
     "ratelimiter~=1.0",
     "retry>=0.9.2",
     "tqdm>=4.27.0",

@@ -29,6 +29,9 @@ class FakeLanceDBClient:
     def post(self, path: str):
         pass
 
+    def mount_retry_adapter_for_table(self, table_name: str):
+        pass
+
 
 def test_remote_db():
     conn = lancedb.connect("db://client-will-be-injected", api_key="fake")

@@ -260,6 +260,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
     cx.export_function("tableCountRows", JsTable::js_count_rows)?;
     cx.export_function("tableDelete", JsTable::js_delete)?;
     cx.export_function("tableUpdate", JsTable::js_update)?;
+    cx.export_function("tableMergeInsert", JsTable::js_merge_insert)?;
    cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
     cx.export_function("tableCompactFiles", JsTable::js_compact)?;
     cx.export_function("tableListIndices", JsTable::js_list_indices)?;

@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::ops::Deref;
+
 use arrow_array::{RecordBatch, RecordBatchIterator};
 use lance::dataset::optimize::CompactionOptions;
 use lance::dataset::{WriteMode, WriteParams};

@@ -166,6 +168,53 @@ impl JsTable {
         Ok(promise)
     }
 
+    pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
+        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let rt = runtime(&mut cx)?;
+        let (deferred, promise) = cx.promise();
+        let channel = cx.channel();
+        let table = js_table.table.clone();
+
+        let key = cx.argument::<JsString>(0)?.value(&mut cx);
+        let mut builder = table.merge_insert(&[&key]);
+        if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
+            builder.when_matched_update_all();
+        }
+        if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
+            builder.when_not_matched_insert_all();
+        }
+        if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
+            if let Some(filter) = cx.argument_opt(4) {
+                if filter.is_a::<JsNull, _>(&mut cx) {
+                    builder.when_not_matched_by_source_delete(None);
+                } else {
+                    let filter = filter
+                        .downcast_or_throw::<JsString, _>(&mut cx)?
+                        .deref()
+                        .value(&mut cx);
+                    builder.when_not_matched_by_source_delete(Some(filter));
+                }
+            } else {
+                builder.when_not_matched_by_source_delete(None);
+            }
+        }
+
+        let buffer = cx.argument::<JsBuffer>(5)?;
+        let (batches, schema) =
+            arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
+
+        rt.spawn(async move {
+            let new_data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
+            let merge_insert_result = builder.execute(Box::new(new_data)).await;
+
+            deferred.settle_with(&channel, move |mut cx| {
+                merge_insert_result.or_throw(&mut cx)?;
+                Ok(cx.boxed(JsTable::from(table)))
+            })
+        });
+        Ok(promise)
+    }
+
     pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
         let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
         let table = js_table.table.clone();

@@ -19,6 +19,7 @@ use std::sync::{Arc, Mutex};
 
 use arrow_array::RecordBatchReader;
 use arrow_schema::{Schema, SchemaRef};
+use async_trait::async_trait;
 use chrono::Duration;
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::cleanup::RemovalStats;

@@ -27,6 +28,7 @@ use lance::dataset::optimize::{
 };
 pub use lance::dataset::ReadParams;
 use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
+use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::io::WrappingObjectStore;
 use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
 use log::info;

@@ -38,6 +40,10 @@ use crate::query::Query;
 use crate::utils::{PatchReadParam, PatchWriteParam};
 use crate::WriteMode;
 
+use self::merge::{MergeInsert, MergeInsertBuilder};
+
+pub mod merge;
+
 /// Optimize the dataset.
 ///
 /// Similar to `VACUUM` in PostgreSQL, it offers different options to

@@ -170,6 +176,71 @@ pub trait Table: std::fmt::Display + Send + Sync {
     /// ```
     fn create_index(&self, column: &[&str]) -> IndexBuilder;
 
+    /// Create a builder for a merge insert operation
+    ///
+    /// This operation can add rows, update rows, and remove rows all in a single
+    /// transaction. It is a very generic tool that can be used to create
+    /// behaviors like "insert if not exists", "update or insert (i.e. upsert)",
+    /// or even replace a portion of existing data with new data (e.g. replace
+    /// all data where month="january")
+    ///
+    /// The merge insert operation works by combining new data from a
+    /// **source table** with existing data in a **target table** by using a
+    /// join. There are three categories of records.
+    ///
+    /// "Matched" records are records that exist in both the source table and
+    /// the target table. "Not matched" records exist only in the source table
+    /// (e.g. these are new data). "Not matched by source" records exist only
+    /// in the target table (this is old data).
+    ///
+    /// The builder returned by this method can be used to customize what
+    /// should happen for each category of data.
+    ///
+    /// Please note that the data may appear to be reordered as part of this
+    /// operation. This is because updated rows will be deleted from the
+    /// dataset and then reinserted at the end with the new values.
+    ///
+    /// # Arguments
+    ///
+    /// * `on` One or more columns to join on. This is how records from the
+    ///   source table and target table are matched. Typically this is some
+    ///   kind of key or id column.
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// # use std::sync::Arc;
+    /// # use vectordb::connection::{Database, Connection};
+    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
+    /// #   RecordBatchIterator, Int32Array};
+    /// # use arrow_schema::{Schema, Field, DataType};
+    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
+    /// let tmpdir = tempfile::tempdir().unwrap();
+    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
+    /// # let tbl = db.open_table("idx_test").await.unwrap();
+    /// # let schema = Arc::new(Schema::new(vec![
+    /// #   Field::new("id", DataType::Int32, false),
+    /// #   Field::new("vector", DataType::FixedSizeList(
+    /// #     Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
+    /// # ]));
+    /// let new_data = RecordBatchIterator::new(vec![
+    ///     RecordBatch::try_new(schema.clone(),
+    ///         vec![
+    ///             Arc::new(Int32Array::from_iter_values(0..10)),
+    ///             Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+    ///                 (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
+    ///         ]).unwrap()
+    ///     ].into_iter().map(Ok),
+    ///     schema.clone());
+    /// // Perform an upsert operation
+    /// let mut merge_insert = tbl.merge_insert(&["id"]);
+    /// merge_insert.when_matched_update_all()
+    ///     .when_not_matched_insert_all();
+    /// merge_insert.execute(Box::new(new_data)).await.unwrap();
+    /// # });
+    /// ```
+    fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder;
+
     /// Search the table with a given query vector.
     ///
     /// This is a convenience method for preparing an ANN query.

@@ -593,6 +664,42 @@ impl NativeTable {
         }
     }
 
+#[async_trait]
+impl MergeInsert for NativeTable {
+    async fn do_merge_insert(
+        &self,
+        params: MergeInsertBuilder,
+        new_data: Box<dyn RecordBatchReader + Send>,
+    ) -> Result<()> {
+        let dataset = Arc::new(self.clone_inner_dataset());
+        let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
+        if params.when_matched_update_all {
+            builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
+        } else {
+            builder.when_matched(lance::dataset::WhenMatched::DoNothing);
+        }
+        if params.when_not_matched_insert_all {
+            builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
+        } else {
+            builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
+        }
+        if params.when_not_matched_by_source_delete {
+            let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
+                WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
+            } else {
+                WhenNotMatchedBySource::Delete
+            };
+            builder.when_not_matched_by_source(behavior);
+        } else {
+            builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
+        }
+        let job = builder.try_build()?;
+        let new_dataset = job.execute_reader(new_data).await?;
+        self.reset_dataset((*new_dataset).clone());
+        Ok(())
+    }
+}
+
 #[async_trait::async_trait]
 impl Table for NativeTable {
     fn as_any(&self) -> &dyn std::any::Any {

@@ -637,6 +744,11 @@ impl Table for NativeTable {
         Ok(())
     }
 
+    fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder {
+        let on = Vec::from_iter(on.iter().map(|key| key.to_string()));
+        MergeInsertBuilder::new(Arc::new(self.clone()), on)
+    }
+
     fn create_index(&self, columns: &[&str]) -> IndexBuilder {
         IndexBuilder::new(Arc::new(self.clone()), columns)
     }

@@ -802,6 +914,38 @@ mod tests {
         assert_eq!(table.name, "test");
     }
 
+    #[tokio::test]
+    async fn test_merge_insert() {
+        let tmp_dir = tempdir().unwrap();
+        let uri = tmp_dir.path().to_str().unwrap();
+
+        // Create a dataset with i=0..10
+        let batches = make_test_batches_with_offset(0);
+        let table = NativeTable::create(&uri, "test", batches, None, None)
+            .await
+            .unwrap();
+        assert_eq!(table.count_rows().await.unwrap(), 10);
+
+        // Create new data with i=5..15
+        let new_batches = Box::new(make_test_batches_with_offset(5));
+
+        // Perform an "insert if not exists"
+        let mut merge_insert_builder = table.merge_insert(&["i"]);
+        merge_insert_builder.when_not_matched_insert_all();
+        merge_insert_builder.execute(new_batches).await.unwrap();
+        // Only 5 rows should actually be inserted
+        assert_eq!(table.count_rows().await.unwrap(), 15);
+
+        // Create new data with i=15..25 (no id matches)
+        let new_batches = Box::new(make_test_batches_with_offset(15));
+        // Perform a "bulk update" (should not affect anything)
+        let mut merge_insert_builder = table.merge_insert(&["i"]);
+        merge_insert_builder.when_matched_update_all();
+        merge_insert_builder.execute(new_batches).await.unwrap();
+        // No new rows should have been inserted
+        assert_eq!(table.count_rows().await.unwrap(), 15);
+    }
+
     #[tokio::test]
     async fn test_add_overwrite() {
         let tmp_dir = tempdir().unwrap();

@@ -1148,17 +1292,25 @@ mod tests {
         assert!(wrapper.called());
     }
 
-    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
+    fn make_test_batches_with_offset(
+        offset: i32,
+    ) -> impl RecordBatchReader + Send + Sync + 'static {
         let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
         RecordBatchIterator::new(
             vec![RecordBatch::try_new(
                 schema.clone(),
-                vec![Arc::new(Int32Array::from_iter_values(0..10))],
+                vec![Arc::new(Int32Array::from_iter_values(
+                    offset..(offset + 10),
+                ))],
             )],
             schema,
         )
     }
 
+    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
+        make_test_batches_with_offset(0)
+    }
+
     #[tokio::test]
     async fn test_create_index() {
         use arrow_array::RecordBatch;

rust/vectordb/src/table/merge.rs (new file, 95 lines)
@@ -0,0 +1,95 @@
+// Copyright 2024 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow_array::RecordBatchReader;
+use async_trait::async_trait;
+
+use crate::Result;
+
+#[async_trait]
+pub(super) trait MergeInsert: Send + Sync {
+    async fn do_merge_insert(
+        &self,
+        params: MergeInsertBuilder,
+        new_data: Box<dyn RecordBatchReader + Send>,
+    ) -> Result<()>;
+}
+
+/// A builder used to create and run a merge insert operation
+///
+/// See [`super::Table::merge_insert`] for more context
+pub struct MergeInsertBuilder {
+    table: Arc<dyn MergeInsert>,
+    pub(super) on: Vec<String>,
+    pub(super) when_matched_update_all: bool,
+    pub(super) when_not_matched_insert_all: bool,
+    pub(super) when_not_matched_by_source_delete: bool,
+    pub(super) when_not_matched_by_source_delete_filt: Option<String>,
+}
+
+impl MergeInsertBuilder {
+    pub(super) fn new(table: Arc<dyn MergeInsert>, on: Vec<String>) -> Self {
+        Self {
+            table,
+            on,
+            when_matched_update_all: false,
+            when_not_matched_insert_all: false,
+            when_not_matched_by_source_delete: false,
+            when_not_matched_by_source_delete_filt: None,
+        }
+    }
+
+    /// Rows that exist in both the source table (new data) and
+    /// the target table (old data) will be updated, replacing
+    /// the old row with the corresponding matching row.
+    ///
+    /// If there are multiple matches then the behavior is undefined.
+    /// Currently this causes multiple copies of the row to be created
+    /// but that behavior is subject to change.
+    pub fn when_matched_update_all(&mut self) -> &mut Self {
+        self.when_matched_update_all = true;
+        self
+    }
+
+    /// Rows that exist only in the source table (new data) should
+    /// be inserted into the target table.
+    pub fn when_not_matched_insert_all(&mut self) -> &mut Self {
+        self.when_not_matched_insert_all = true;
+        self
+    }
+
+    /// Rows that exist only in the target table (old data) will be
+    /// deleted. An optional condition can be provided to limit what
+    /// data is deleted.
+    ///
+    /// # Arguments
+    ///
+    /// * `condition` - If None then all such rows will be deleted.
+    ///   Otherwise the condition will be used as an SQL filter to
+    ///   limit what rows are deleted.
+    pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
+        self.when_not_matched_by_source_delete = true;
+        self.when_not_matched_by_source_delete_filt = filter;
+        self
+    }
+
+    /// Executes the merge insert operation
+    ///
+    /// Nothing is returned but the [`super::Table`] is updated
+    pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
+        self.table.clone().do_merge_insert(self, new_data).await
+    }
+}