mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 02:42:57 +00:00
Update in Node & Rust (#696)
Co-authored-by: Will Jones <willjones127@gmail.com>
This commit is contained in:
@@ -5,10 +5,10 @@ exclude = ["python"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.8.17", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.8.17" }
|
||||
lance-linalg = { "version" = "=0.8.17" }
|
||||
lance-testing = { "version" = "=0.8.17" }
|
||||
lance = { "version" = "=0.8.20", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.8.20" }
|
||||
lance-linalg = { "version" = "=0.8.20" }
|
||||
lance-testing = { "version" = "=0.8.20" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "47.0.0", optional = false }
|
||||
arrow-array = "47.0"
|
||||
|
||||
@@ -21,9 +21,10 @@ import type { EmbeddingFunction } from './embedding/embedding_function'
|
||||
import { RemoteConnection } from './remote'
|
||||
import { Query } from './query'
|
||||
import { isEmbeddingFunction } from './embedding/embedding_function'
|
||||
import { type Literal, toSQL } from './util'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableUpdate, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
|
||||
|
||||
export { Query }
|
||||
export type { EmbeddingFunction }
|
||||
@@ -261,6 +262,39 @@ export interface Table<T = number[]> {
|
||||
*/
|
||||
delete: (filter: string) => Promise<void>
|
||||
|
||||
/**
|
||||
* Update rows in this table.
|
||||
*
|
||||
* This can be used to update a single row, many rows, all rows, or
|
||||
* sometimes no rows (if your predicate matches nothing).
|
||||
*
|
||||
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
|
||||
*
|
||||
* @examples
|
||||
*
|
||||
* ```ts
|
||||
* const con = await lancedb.connect("./.lancedb")
|
||||
* const data = [
|
||||
* {id: 1, vector: [3, 3], name: 'Ye'},
|
||||
* {id: 2, vector: [4, 4], name: 'Mike'},
|
||||
* ];
|
||||
* const tbl = await con.createTable("my_table", data)
|
||||
*
|
||||
* await tbl.update({
|
||||
* filter: "id = 2",
|
||||
* updates: { vector: [2, 2], name: "Michael" },
|
||||
* })
|
||||
*
|
||||
* let results = await tbl.search([1, 1]).execute();
|
||||
* // Returns [
|
||||
* // {id: 2, vector: [2, 2], name: 'Michael'}
|
||||
* // {id: 1, vector: [3, 3], name: 'Ye'}
|
||||
* // ]
|
||||
* ```
|
||||
*
|
||||
*/
|
||||
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
|
||||
|
||||
/**
|
||||
* List the indicies on this table.
|
||||
*/
|
||||
@@ -272,6 +306,34 @@ export interface Table<T = number[]> {
|
||||
indexStats: (indexUuid: string) => Promise<IndexStats>
|
||||
}
|
||||
|
||||
export interface UpdateArgs {
|
||||
/**
|
||||
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
|
||||
* in which case all rows will be updated.
|
||||
*/
|
||||
where?: string
|
||||
|
||||
/**
|
||||
* A key-value map of updates. The keys are the column names, and the values are the
|
||||
* new values to set
|
||||
*/
|
||||
values: Record<string, Literal>
|
||||
}
|
||||
|
||||
export interface UpdateSqlArgs {
|
||||
/**
|
||||
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
|
||||
* in which case all rows will be updated.
|
||||
*/
|
||||
where?: string
|
||||
|
||||
/**
|
||||
* A key-value map of updates. The keys are the column names, and the values are the
|
||||
* new values to set as SQL expressions.
|
||||
*/
|
||||
valuesSql: Record<string, string>
|
||||
}
|
||||
|
||||
export interface VectorIndex {
|
||||
columns: string[]
|
||||
name: string
|
||||
@@ -481,6 +543,31 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable })
|
||||
}
|
||||
|
||||
/**
|
||||
* Update rows in this table.
|
||||
*
|
||||
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
|
||||
*
|
||||
* @returns
|
||||
*/
|
||||
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
|
||||
let filter: string | null
|
||||
let updates: Record<string, string>
|
||||
|
||||
if ('valuesSql' in args) {
|
||||
filter = args.where ?? null
|
||||
updates = args.valuesSql
|
||||
} else {
|
||||
filter = args.where ?? null
|
||||
updates = {}
|
||||
for (const [key, value] of Object.entries(args.values)) {
|
||||
updates[key] = toSQL(value)
|
||||
}
|
||||
}
|
||||
|
||||
return tableUpdate.call(this._tbl, filter, updates).then((newTable: any) => { this._tbl = newTable })
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up old versions of the table, freeing disk space.
|
||||
*
|
||||
|
||||
@@ -16,7 +16,8 @@ import {
|
||||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
||||
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
|
||||
type WriteOptions,
|
||||
type IndexStats
|
||||
type IndexStats,
|
||||
type UpdateArgs, type UpdateSqlArgs
|
||||
} from '../index'
|
||||
import { Query } from '../query'
|
||||
|
||||
@@ -246,6 +247,10 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
|
||||
}
|
||||
|
||||
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
|
||||
throw new Error('Not implemented')
|
||||
}
|
||||
|
||||
async listIndices (): Promise<VectorIndex[]> {
|
||||
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
|
||||
return results.data.indexes?.map((index: any) => ({
|
||||
|
||||
@@ -260,6 +260,46 @@ describe('LanceDB client', function () {
|
||||
assert.equal(await table.countRows(), 2)
|
||||
})
|
||||
|
||||
it('can update records in the table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ where: 'price = 10', valuesSql: { price: '100' } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 11)
|
||||
})
|
||||
|
||||
it('can update the records using a literal value', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ where: 'price = 10', values: { price: 100 } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 11)
|
||||
})
|
||||
|
||||
it('can update every record in the table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
const table = await con.openTable('vectors')
|
||||
assert.equal(await table.countRows(), 2)
|
||||
|
||||
await table.update({ valuesSql: { price: '100' } })
|
||||
const results = await table.search([0.1, 0.2]).execute()
|
||||
|
||||
assert.equal(results[0].price, 100)
|
||||
assert.equal(results[1].price, 100)
|
||||
})
|
||||
|
||||
it('can delete records from a table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
@@ -542,7 +582,7 @@ describe('Compact and cleanup', function () {
|
||||
|
||||
// should have no effect, but this validates the arguments are parsed.
|
||||
await table.compactFiles({
|
||||
targetRowsPerFragment: 1024 * 10,
|
||||
targetRowsPerFragment: 102410,
|
||||
maxRowsPerGroup: 1024,
|
||||
materializeDeletions: true,
|
||||
materializeDeletionsThreshold: 0.5,
|
||||
|
||||
45
node/src/test/util.ts
Normal file
45
node/src/test/util.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { toSQL } from '../util'
|
||||
import * as chai from 'chai'
|
||||
|
||||
const expect = chai.expect
|
||||
|
||||
describe('toSQL', function () {
|
||||
it('should turn string to SQL expression', function () {
|
||||
expect(toSQL('foo')).to.equal("'foo'")
|
||||
})
|
||||
|
||||
it('should turn number to SQL expression', function () {
|
||||
expect(toSQL(123)).to.equal('123')
|
||||
})
|
||||
|
||||
it('should turn boolean to SQL expression', function () {
|
||||
expect(toSQL(true)).to.equal('TRUE')
|
||||
})
|
||||
|
||||
it('should turn null to SQL expression', function () {
|
||||
expect(toSQL(null)).to.equal('NULL')
|
||||
})
|
||||
|
||||
it('should turn Date to SQL expression', function () {
|
||||
const date = new Date('05 October 2011 14:48 UTC')
|
||||
expect(toSQL(date)).to.equal("'2011-10-05T14:48:00.000Z'")
|
||||
})
|
||||
|
||||
it('should turn array to SQL expression', function () {
|
||||
expect(toSQL(['foo', 'bar', true, 1])).to.equal("['foo', 'bar', TRUE, 1]")
|
||||
})
|
||||
})
|
||||
44
node/src/util.ts
Normal file
44
node/src/util.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright 2023 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
export type Literal = string | number | boolean | null | Date | Literal[]
|
||||
|
||||
export function toSQL (value: Literal): string {
|
||||
if (typeof value === 'string') {
|
||||
return `'${value}'`
|
||||
}
|
||||
|
||||
if (typeof value === 'number') {
|
||||
return value.toString()
|
||||
}
|
||||
|
||||
if (typeof value === 'boolean') {
|
||||
return value ? 'TRUE' : 'FALSE'
|
||||
}
|
||||
|
||||
if (value === null) {
|
||||
return 'NULL'
|
||||
}
|
||||
|
||||
if (value instanceof Date) {
|
||||
return `'${value.toISOString()}'`
|
||||
}
|
||||
|
||||
if (Array.isArray(value)) {
|
||||
return `[${value.map(toSQL).join(', ')}]`
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
throw new Error(`Unsupported value type: ${typeof value} value: (${value})`)
|
||||
}
|
||||
@@ -237,6 +237,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
|
||||
cx.export_function("tableAdd", JsTable::js_add)?;
|
||||
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
|
||||
cx.export_function("tableDelete", JsTable::js_delete)?;
|
||||
cx.export_function("tableUpdate", JsTable::js_update)?;
|
||||
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
|
||||
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
|
||||
cx.export_function("tableListIndices", JsTable::js_list_indices)?;
|
||||
|
||||
@@ -48,7 +48,9 @@ impl JsQuery {
|
||||
.map(|s| s.value(&mut cx))
|
||||
.map(|s| MetricType::try_from(s.as_str()).unwrap());
|
||||
|
||||
let prefilter = query_obj.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?.value(&mut cx);
|
||||
let prefilter = query_obj
|
||||
.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
|
||||
.value(&mut cx);
|
||||
|
||||
let is_electron = cx
|
||||
.argument::<JsBoolean>(1)
|
||||
|
||||
@@ -165,6 +165,69 @@ impl JsTable {
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let mut table = js_table.table.clone();
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
|
||||
// create a vector of updates from the passed map
|
||||
let updates_arg = cx.argument::<JsObject>(1)?;
|
||||
let properties = updates_arg.get_own_property_names(&mut cx)?;
|
||||
let mut updates: Vec<(String, String)> =
|
||||
Vec::with_capacity(properties.len(&mut cx) as usize);
|
||||
|
||||
let len_properties = properties.len(&mut cx);
|
||||
for i in 0..len_properties {
|
||||
let property = properties
|
||||
.get_value(&mut cx, i)?
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?;
|
||||
|
||||
let value = updates_arg
|
||||
.get_value(&mut cx, property.clone())?
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?;
|
||||
|
||||
let property = property.value(&mut cx);
|
||||
let value = value.value(&mut cx);
|
||||
updates.push((property, value));
|
||||
}
|
||||
|
||||
// get the filter/predicate if the user passed one
|
||||
let predicate = cx.argument_opt(0);
|
||||
let predicate = predicate.unwrap().downcast::<JsString, _>(&mut cx);
|
||||
let predicate = match predicate {
|
||||
Ok(_) => {
|
||||
let val = predicate.map(|s| s.value(&mut cx)).unwrap();
|
||||
Some(val)
|
||||
}
|
||||
Err(_) => {
|
||||
// if the predicate is not string, check it's null otherwise an invalid
|
||||
// type was passed
|
||||
cx.argument::<JsNull>(0)?;
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
rt.spawn(async move {
|
||||
let updates_arg = updates
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str(), v.as_str()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let predicate = predicate.as_ref().map(|s| s.as_str());
|
||||
|
||||
let update_result = table.update(predicate, updates_arg).await;
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
update_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
})
|
||||
});
|
||||
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
|
||||
@@ -23,7 +23,7 @@ use lance::dataset::cleanup::RemovalStats;
|
||||
use lance::dataset::optimize::{
|
||||
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
|
||||
};
|
||||
use lance::dataset::{Dataset, WriteParams};
|
||||
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
|
||||
use lance::index::DatasetIndexExt;
|
||||
use lance::io::object_store::WrappingObjectStore;
|
||||
use std::path::Path;
|
||||
@@ -338,6 +338,27 @@ impl Table {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update(
|
||||
&mut self,
|
||||
predicate: Option<&str>,
|
||||
updates: Vec<(&str, &str)>,
|
||||
) -> Result<()> {
|
||||
let mut builder = UpdateBuilder::new(self.dataset.clone());
|
||||
if let Some(predicate) = predicate {
|
||||
builder = builder.update_where(predicate)?;
|
||||
}
|
||||
|
||||
for (column, value) in updates {
|
||||
builder = builder.set(column, value)?;
|
||||
}
|
||||
|
||||
let operation = builder.build()?;
|
||||
let new_ds = operation.execute().await?;
|
||||
self.dataset = new_ds;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove old versions of the dataset from disk.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -413,11 +434,14 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{
|
||||
Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader,
|
||||
Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array,
|
||||
Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
|
||||
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
|
||||
UInt32Array,
|
||||
};
|
||||
use arrow_data::ArrayDataBuilder;
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use arrow_schema::{DataType, Field, Schema, TimeUnit};
|
||||
use futures::TryStreamExt;
|
||||
use lance::dataset::{Dataset, WriteMode};
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::io::object_store::{ObjectStoreParams, WrappingObjectStore};
|
||||
@@ -540,6 +564,272 @@ mod tests {
|
||||
assert_eq!(table.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_with_predicate() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Int32, false),
|
||||
Field::new("name", DataType::Utf8, false),
|
||||
]));
|
||||
|
||||
let record_batch_iter = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..10)),
|
||||
Arc::new(StringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
);
|
||||
|
||||
Dataset::write(record_batch_iter, uri, None).await.unwrap();
|
||||
let mut table = Table::open(uri).await.unwrap();
|
||||
|
||||
table
|
||||
.update(Some("id > 5"), vec![("name", "'foo'")])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let ds_after = Dataset::open(uri).await.unwrap();
|
||||
let mut batches = ds_after
|
||||
.scan()
|
||||
.project(&["id", "name"])
|
||||
.unwrap()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
while let Some(batch) = batches.pop() {
|
||||
let ids = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
let names = batch
|
||||
.column(1)
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for (i, name) in names.iter().enumerate() {
|
||||
let id = ids[i].unwrap();
|
||||
let name = name.unwrap();
|
||||
if id > 5 {
|
||||
assert_eq!(name, "foo");
|
||||
} else {
|
||||
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_all_types() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let dataset_path = tmp_dir.path().join("test.lance");
|
||||
let uri = dataset_path.to_str().unwrap();
|
||||
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("int32", DataType::Int32, false),
|
||||
Field::new("int64", DataType::Int64, false),
|
||||
Field::new("uint32", DataType::UInt32, false),
|
||||
Field::new("string", DataType::Utf8, false),
|
||||
Field::new("large_string", DataType::LargeUtf8, false),
|
||||
Field::new("float32", DataType::Float32, false),
|
||||
Field::new("float64", DataType::Float64, false),
|
||||
Field::new("bool", DataType::Boolean, false),
|
||||
Field::new("date32", DataType::Date32, false),
|
||||
Field::new(
|
||||
"timestamp_ns",
|
||||
DataType::Timestamp(TimeUnit::Nanosecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"timestamp_ms",
|
||||
DataType::Timestamp(TimeUnit::Millisecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec_f32",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"vec_f64",
|
||||
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
|
||||
false,
|
||||
),
|
||||
]));
|
||||
|
||||
let record_batch_iter = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(Int32Array::from_iter_values(0..10)),
|
||||
Arc::new(Int64Array::from_iter_values(0..10)),
|
||||
Arc::new(UInt32Array::from_iter_values(0..10)),
|
||||
Arc::new(StringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
Arc::new(LargeStringArray::from_iter_values(vec![
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
|
||||
])),
|
||||
Arc::new(Float32Array::from_iter_values(
|
||||
(0..10).into_iter().map(|i| i as f32),
|
||||
)),
|
||||
Arc::new(Float64Array::from_iter_values(
|
||||
(0..10).into_iter().map(|i| i as f64),
|
||||
)),
|
||||
Arc::new(Into::<BooleanArray>::into(vec![
|
||||
true, false, true, false, true, false, true, false, true, false,
|
||||
])),
|
||||
Arc::new(Date32Array::from_iter_values(0..10)),
|
||||
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
|
||||
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
|
||||
Arc::new(
|
||||
create_fixed_size_list(
|
||||
Float32Array::from_iter_values((0..20).into_iter().map(|i| i as f32)),
|
||||
2,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
Arc::new(
|
||||
create_fixed_size_list(
|
||||
Float64Array::from_iter_values((0..20).into_iter().map(|i| i as f64)),
|
||||
2,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
],
|
||||
)
|
||||
.unwrap()]
|
||||
.into_iter()
|
||||
.map(Ok),
|
||||
schema.clone(),
|
||||
);
|
||||
|
||||
Dataset::write(record_batch_iter, uri, None).await.unwrap();
|
||||
let mut table = Table::open(uri).await.unwrap();
|
||||
|
||||
// check it can do update for each type
|
||||
let updates: Vec<(&str, &str)> = vec![
|
||||
("string", "'foo'"),
|
||||
("large_string", "'large_foo'"),
|
||||
("int32", "1"),
|
||||
("int64", "1"),
|
||||
("uint32", "1"),
|
||||
("float32", "1.0"),
|
||||
("float64", "1.0"),
|
||||
("bool", "true"),
|
||||
("date32", "1"),
|
||||
("timestamp_ns", "1"),
|
||||
("timestamp_ms", "1"),
|
||||
("vec_f32", "[1.0, 1.0]"),
|
||||
("vec_f64", "[1.0, 1.0]"),
|
||||
];
|
||||
|
||||
// for (column, value) in test_cases {
|
||||
table.update(None, updates).await.unwrap();
|
||||
|
||||
let ds_after = Dataset::open(uri).await.unwrap();
|
||||
let mut batches = ds_after
|
||||
.scan()
|
||||
.project(&[
|
||||
"string",
|
||||
"large_string",
|
||||
"int32",
|
||||
"int64",
|
||||
"uint32",
|
||||
"float32",
|
||||
"float64",
|
||||
"bool",
|
||||
"date32",
|
||||
"timestamp_ns",
|
||||
"timestamp_ms",
|
||||
"vec_f32",
|
||||
"vec_f64",
|
||||
])
|
||||
.unwrap()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = batches.pop().unwrap();
|
||||
|
||||
macro_rules! assert_column {
|
||||
($column:expr, $array_type:ty, $expected:expr) => {
|
||||
let array = $column
|
||||
.as_any()
|
||||
.downcast_ref::<$array_type>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
assert_eq!(v, Some($expected));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
assert_column!(batch.column(0), StringArray, "foo");
|
||||
assert_column!(batch.column(1), LargeStringArray, "large_foo");
|
||||
assert_column!(batch.column(2), Int32Array, 1);
|
||||
assert_column!(batch.column(3), Int64Array, 1);
|
||||
assert_column!(batch.column(4), UInt32Array, 1);
|
||||
assert_column!(batch.column(5), Float32Array, 1.0);
|
||||
assert_column!(batch.column(6), Float64Array, 1.0);
|
||||
assert_column!(batch.column(7), BooleanArray, true);
|
||||
assert_column!(batch.column(8), Date32Array, 1);
|
||||
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
|
||||
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
|
||||
|
||||
let array = batch
|
||||
.column(11)
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
let v = v.unwrap();
|
||||
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
|
||||
for v in f32array {
|
||||
assert_eq!(v, Some(1.0));
|
||||
}
|
||||
}
|
||||
|
||||
let array = batch
|
||||
.column(12)
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.collect::<Vec<_>>();
|
||||
for v in array {
|
||||
let v = v.unwrap();
|
||||
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
|
||||
for v in f64array {
|
||||
assert_eq!(v, Some(1.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
|
||||
Reference in New Issue
Block a user