Compare commits

...

3 Commits

Author SHA1 Message Date
albertlockett
2bbc56b9f9 more code 2024-01-31 18:48:12 -05:00
albertlockett
65c9c0ba9b it compiles 2024-01-31 18:28:51 -05:00
albertlockett
e2e45dd5a6 merge insert 2024-01-31 18:09:10 -05:00
8 changed files with 217 additions and 5 deletions

View File

@@ -11,10 +11,11 @@ license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"
[workspace.dependencies]
lance = { "version" = "=0.9.10", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.10" }
lance-linalg = { "version" = "=0.9.10" }
lance-testing = { "version" = "=0.9.10" }
lance = { "version" = "=0.9.11", "features" = ["dynamodb"] }
lance-datafusion = { "version" = "=0.9.11" }
lance-index = { "version" = "=0.9.11" }
lance-linalg = { "version" = "=0.9.11" }
lance-testing = { "version" = "=0.9.11" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"

View File

@@ -451,6 +451,11 @@ export interface Table<T = number[]> {
indexStats: (indexUuid: string) => Promise<IndexStats>
filter(value: string): Query<T>
/**
* TODO comment
*/
mergeInsert: () => MergeInsertBuilder
schema: Promise<Schema>
}
@@ -900,6 +905,15 @@ export class LocalTable<T = number[]> implements Table<T> {
return false
}
}
mergeInsert: () => MergeInsertBuilder = () => {
return new MergeInsertBuilder(async (args: {
params: MergeInsertParams
data: Array<Record<string, unknown>> | ArrowTable
}) => {
throw new Error('Not implemented')
})
}
}
export interface CleanupStats {
@@ -1076,3 +1090,56 @@ export enum MetricType {
*/
Dot = 'dot',
}
export interface MergeInsertParams {
whenMatchedUpdateAll: boolean
whenNotMatchedInsertAll: boolean
whenNotMatchedBySourceDelete: boolean
whenNotMatchedBySourceCondition: boolean
}
type MergeInsertCallback = (args: {
params: MergeInsertParams
data: Array<Record<string, unknown>> | ArrowTable
}) => Promise<void>
export class MergeInsertBuilder {
readonly #callback: MergeInsertCallback
readonly #params: MergeInsertParams
constructor (callback: MergeInsertCallback) {
this.#callback = callback
this.#params = {
whenMatchedUpdateAll: false,
whenNotMatchedInsertAll: false,
whenNotMatchedBySourceDelete: false,
whenNotMatchedBySourceCondition: false
}
}
whenMatchedUpdateAll (): MergeInsertBuilder {
this.#params.whenMatchedUpdateAll = true
return this
}
whenNotMatchedInsertAll (): MergeInsertBuilder {
this.#params.whenNotMatchedInsertAll = true
return this
}
whenNotMatchedBySourceDelete (): MergeInsertBuilder {
this.#params.whenNotMatchedBySourceDelete = true
return this
}
whenNotMatchedBySourceCondition (): MergeInsertBuilder {
this.#params.whenNotMatchedBySourceCondition = true
return this
}
async execute ({ data }: {
data: Array<Record<string, unknown>> | ArrowTable
}): Promise<void> {
await this.#callback({ params: this.#params, data })
}
}

View File

@@ -120,7 +120,7 @@ export class HttpLancedbClient {
public async post (
path: string,
data?: any,
params?: Record<string, string | number>,
params?: Record<string, string | number | boolean>,
content?: string | undefined
): Promise<AxiosResponse> {
const response = await axios.post(

View File

@@ -24,6 +24,8 @@ import {
type IndexStats,
type UpdateArgs,
type UpdateSqlArgs,
type MergeInsertParams,
MergeInsertBuilder,
makeArrowTable
} from '../index'
import { Query } from '../query'
@@ -424,4 +426,36 @@ export class RemoteTable<T = number[]> implements Table<T> {
numUnindexedRows: results.data.num_unindexed_rows
}
}
mergeInsert: () => MergeInsertBuilder = () => {
return new MergeInsertBuilder(async ({ data, params }: {
params: MergeInsertParams
data: Array<Record<string, unknown>> | ArrowTable
}) => {
// TODO -- uncomment this this
// let tbl: ArrowTable
// if (data instanceof ArrowTable) {
// tbl = data
// } else {
// tbl = makeArrowTable(data, await this.schema)
// }
const tbl = data as ArrowTable
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
console.log({ buffer })
await this._client.post(
`/v1/table/${this._name}/merge_insert/`,
buffer,
{
when_matched_update_all: params.whenMatchedUpdateAll,
when_not_matched_insert_all: params.whenNotMatchedInsertAll,
when_not_matched_by_source_delete: params.whenNotMatchedBySourceDelete,
when_not_matched_by_source_condition: params.whenNotMatchedBySourceCondition
},
'application/vnd.apache.arrow.stream'
)
})
}
}

View File

@@ -22,6 +22,7 @@ object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lance = { workspace = true }
lance-datafusion = { workspace = true }
lance-index = { workspace = true }
lance-linalg = { workspace = true }
lance-testing = { workspace = true }

View File

@@ -175,6 +175,7 @@ pub mod error;
pub mod index;
pub mod io;
pub mod ipc;
pub mod merge_insert;
pub mod query;
pub mod table;
pub mod utils;

View File

@@ -0,0 +1,101 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::pin::Pin;
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use lance::dataset::{self, WhenMatched, WhenNotMatched, WhenNotMatchedBySource};
use lance_datafusion::utils::reader_to_stream;
use crate::TableRef;
use crate::error::{Error, Result};
pub struct MergeInsertBuilder {
table: TableRef,
when_matched_update_all: bool,
when_not_matched_insert_all: bool,
when_not_matched_by_source_delete: bool,
when_not_matched_by_source_condition: bool,
}
impl MergeInsertBuilder {
pub(crate) fn new(table: TableRef) -> Self {
Self {
table,
when_matched_update_all: false,
when_not_matched_insert_all: false,
when_not_matched_by_source_delete: false,
when_not_matched_by_source_condition: false,
}
}
pub fn when_matched_update_all(mut self) -> Self {
self.when_matched_update_all = true;
self
}
pub fn when_not_matched_insert_all(mut self) -> Self {
self.when_not_matched_insert_all = true;
self
}
pub fn when_not_matched_by_source_delete(mut self) -> Self {
self.when_not_matched_by_source_delete = true;
self
}
pub fn when_not_matched_by_source_condition(mut self) -> Self {
self.when_not_matched_by_source_condition = true;
self
}
pub async fn execute(
mut self,
batches: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let native_table = self.table.as_native().unwrap(); // TODO no unwrap
let ds = native_table.clone_inner_dataset();
let mut builder = dataset::MergeInsertBuilder::try_new(
Arc::new(ds),
vec!["vectors".to_string()],
)
.unwrap(); // TODO no unwrap
if self.when_matched_update_all {
builder.when_matched(WhenMatched::UpdateAll);
}
if self.when_not_matched_insert_all {
builder.when_not_matched(WhenNotMatched::InsertAll);
}
if self.when_not_matched_by_source_delete {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Delete);
}
// TODO
// if self.when_not_matched_by_source_condition {
// builder.when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(()));
// }
let job = builder.try_build().unwrap(); // TODO no unwrap
let batches = reader_to_stream(batches).await.unwrap().0; // TODO no unwrap
let ds2 = job.execute(batches).await.unwrap(); // TODO no unwrap
native_table.reset_dataset(ds2.as_ref().clone());
Ok(())
}
}

View File

@@ -34,6 +34,7 @@ use log::info;
use crate::error::{Error, Result};
use crate::index::vector::{VectorIndex, VectorIndexStatistics};
use crate::index::IndexBuilder;
use crate::merge_insert::MergeInsertBuilder;
use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;
@@ -241,6 +242,8 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// Modeled after ``VACCUM`` in PostgreSQL.
/// Not all implementations support explicit optimization.
async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats>;
fn merge_insert(&self) -> MergeInsertBuilder;
}
/// Reference to a Table pointer.
@@ -698,6 +701,10 @@ impl Table for NativeTable {
}
Ok(stats)
}
fn merge_insert(&self) -> MergeInsertBuilder {
MergeInsertBuilder::new(Arc::new(self.clone()))
}
}
#[cfg(test)]