diff --git a/Cargo.lock b/Cargo.lock
index 59dce54a..be084e10 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1645,9 +1645,9 @@ dependencies = [
 [[package]]
 name = "lance"
-version = "0.4.6"
+version = "0.4.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c0be48532dda07b0cd5a5b087710bbbba6972a882f7c771a316940a4868d66a6"
+checksum = "fc96cf89139af6f439a0e28ccd04ddf81be795b79fda3105b7a8952fadeb778e"
 dependencies = [
  "accelerate-src",
  "arrow",
@@ -3371,7 +3371,9 @@ version = "0.1.0"
 dependencies = [
  "arrow-array",
  "arrow-ipc",
+ "arrow-schema",
  "futures",
+ "lance",
  "neon",
  "once_cell",
  "tokio",
diff --git a/node/README.md b/node/README.md
index 02ff3449..5e0a4838 100644
--- a/node/README.md
+++ b/node/README.md
@@ -1,13 +1,26 @@
+# LanceDB
+
 A JavaScript / Node.js library for [LanceDB](https://github.com/lancedb/lancedb).
 
-## Quick Start
+## Installation
 
 ```bash
-npm i vectordb
+npm install vectordb
 ```
 
-See the examples folder for usage.
+## Usage
+
+### Basic Example
+
+```javascript
+const lancedb = require('vectordb');
+const db = await lancedb.connect('data/sample-lancedb');
+const table = await db.openTable('my_table');
+const results = await table.search([0.1, 0.3]).limit(20).execute();
+console.log(results);
+```
+
+The [examples](./examples) folder contains complete examples.
 
 ## Development
diff --git a/node/examples/simple/index.js b/node/examples/simple/index.js
index 6051ada3..d323ef8d 100644
--- a/node/examples/simple/index.js
+++ b/node/examples/simple/index.js
@@ -16,12 +16,17 @@ async function example() {
   const lancedb = require('vectordb');
-  const db = lancedb.connect('../../sample-lancedb');
+  const db = await lancedb.connect('data/sample-lancedb');
 
-  console.log(db.tableNames());
+  const data = [
+    { id: 1, vector: [0.1, 0.2], price: 10 },
+    { id: 2, vector: [1.1, 1.2], price: 50 }
+  ]
 
-  const tbl = await db.openTable('my_table');
-  const query = tbl.search([0.1, 0.3]);
-  query.limit = 20;
+  const table = await db.createTable('vectors', data)
+  console.log(await db.tableNames());
+
+  const query = table.search([0.1, 0.3]).limit(20);
   const results = await query.execute();
   console.log(results);
diff --git a/node/examples/simple/package.json b/node/examples/simple/package.json
index 422f7b1d..d59dc0a6 100644
--- a/node/examples/simple/package.json
+++ b/node/examples/simple/package.json
@@ -9,6 +9,6 @@
   "author": "",
   "license": "Apache-2.0",
   "dependencies": {
-    "vectordb": "^0.0.6"
+    "vectordb": "^0.1.0"
   }
 }
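The README and example updates above track an API change made in this PR: `connect` is now async, and the `limit` property setter on `Query` has become a chainable builder method alongside the new `filter`. A minimal sketch of the resulting call pattern; the table name and the filter expression are illustrative, not taken from the diff:

```typescript
import * as lancedb from 'vectordb'

async function main (): Promise<void> {
  // connect() now returns Promise<Connection>, so it must be awaited
  const db = await lancedb.connect('data/sample-lancedb')
  const table = await db.openTable('vectors')

  // limit() and filter() each return the Query, so calls chain;
  // the SQL-like predicate string is passed through to the Rust side
  const results = await table
    .search([0.1, 0.3])
    .limit(10)
    .filter('price > 10')
    .execute()
  console.log(results)
}

main().catch(console.error)
```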
"https://registry.npmjs.org/@types/temp/-/temp-0.9.1.tgz", + "integrity": "sha512-yDQ8Y+oQi9V7VkexwE6NBSVyNuyNFeGI275yWXASc2DjmxNicMi9O50KxDpNlST1kBbV9jKYBHGXhgNYFMPqtA==", + "dev": true, + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "5.59.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz", @@ -3569,6 +3580,43 @@ "node": ">=8" } }, + "node_modules/temp": { + "version": "0.9.4", + "resolved": "https://registry.npmjs.org/temp/-/temp-0.9.4.tgz", + "integrity": "sha512-yYrrsWnrXMcdsnu/7YMYAofM1ktpL5By7vZhf15CrXijWWrEYZks5AXBudalfSWJLlnen/QUJUB5aoB0kqZUGA==", + "dev": true, + "dependencies": { + "mkdirp": "^0.5.1", + "rimraf": "~2.6.2" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/temp/node_modules/mkdirp": { + "version": "0.5.6", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-0.5.6.tgz", + "integrity": "sha512-FP+p8RB8OWpF3YZBCrP5gtADmtXApB5AMLn+vdyA+PyxCjrCs00mjyUozssO33cwDeT3wNGdLxJ5M//YqtHAJw==", + "dev": true, + "dependencies": { + "minimist": "^1.2.6" + }, + "bin": { + "mkdirp": "bin/cmd.js" + } + }, + "node_modules/temp/node_modules/rimraf": { + "version": "2.6.3", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.6.3.tgz", + "integrity": "sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA==", + "dev": true, + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + } + }, "node_modules/text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", @@ -4256,6 +4304,15 @@ "integrity": "sha512-7NQmHra/JILCd1QqpSzl8+mJRc8ZHz3uDm8YV1Ks9IhK0epEiTw8aIErbvH9PI+6XbqhyIQy3462nEsn7UVzjQ==", "dev": true }, + "@types/temp": { + "version": "0.9.1", + "resolved": "https://registry.npmjs.org/@types/temp/-/temp-0.9.1.tgz", + "integrity": "sha512-yDQ8Y+oQi9V7VkexwE6NBSVyNuyNFeGI275yWXASc2DjmxNicMi9O50KxDpNlST1kBbV9jKYBHGXhgNYFMPqtA==", + "dev": true, + "requires": { + "@types/node": "*" + } + }, "@typescript-eslint/eslint-plugin": { "version": "5.59.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.59.1.tgz", @@ -6552,6 +6609,36 @@ } } }, + "temp": { + "version": "0.9.4", + "resolved": "https://registry.npmjs.org/temp/-/temp-0.9.4.tgz", + "integrity": "sha512-yYrrsWnrXMcdsnu/7YMYAofM1ktpL5By7vZhf15CrXijWWrEYZks5AXBudalfSWJLlnen/QUJUB5aoB0kqZUGA==", + "dev": true, + "requires": { + "mkdirp": "^0.5.1", + "rimraf": "~2.6.2" + }, + "dependencies": { + "mkdirp": { + "version": "0.5.6", + "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-0.5.6.tgz", + "integrity": "sha512-FP+p8RB8OWpF3YZBCrP5gtADmtXApB5AMLn+vdyA+PyxCjrCs00mjyUozssO33cwDeT3wNGdLxJ5M//YqtHAJw==", + "dev": true, + "requires": { + "minimist": "^1.2.6" + } + }, + "rimraf": { + "version": "2.6.3", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.6.3.tgz", + "integrity": "sha512-mwqeW5XsA2qAejG46gYdENaxXjx9onRNCfn7L0duuP4hCuTIi/QO7PDK07KJfp1d+izWPrzEJDcSqBa0OZQriA==", + "dev": true, + "requires": { + "glob": "^7.1.3" + } + } + } + }, "text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", diff --git a/node/package.json b/node/package.json index bf542ceb..e63644d0 100644 --- a/node/package.json +++ b/node/package.json @@ -1,6 +1,6 @@ { "name": "vectordb", - "version": "0.0.6", + "version": "0.1.0", "description": " Serverless, low-latency vector database for AI 
applications", "main": "dist/index.js", "types": "dist/index.d.ts", @@ -27,6 +27,7 @@ "@types/chai": "^4.3.4", "@types/mocha": "^10.0.1", "@types/node": "^18.16.2", + "@types/temp": "^0.9.1", "@typescript-eslint/eslint-plugin": "^5.59.1", "cargo-cp-artifact": "^0.1", "chai": "^4.3.7", @@ -36,6 +37,7 @@ "eslint-plugin-n": "^15.7.0", "eslint-plugin-promise": "^6.1.1", "mocha": "^10.2.0", + "temp": "^0.9.4", "ts-node": "^10.9.1", "ts-node-dev": "^2.0.0", "typescript": "*" diff --git a/node/src/index.ts b/node/src/index.ts index b1644cb7..c1a172b0 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -12,16 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -import { tableFromIPC, Vector } from 'apache-arrow' +import { + Field, + Float32, + List, + makeBuilder, + RecordBatchFileWriter, + Table as ArrowTable, + tableFromIPC, + Vector, + vectorFromArray +} from 'apache-arrow' // eslint-disable-next-line @typescript-eslint/no-var-requires -const { databaseNew, databaseTableNames, databaseOpenTable, tableSearch } = require('../index.node') +const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch } = require('../index.node') /** * Connect to a LanceDB instance at the given URI * @param uri The uri of the database. */ -export function connect (uri: string): Connection { +export async function connect (uri: string): Promise { return new Connection(uri) } @@ -44,7 +54,7 @@ export class Connection { /** * Get the names of all tables in the database. */ - tableNames (): string[] { + async tableNames (): Promise { return databaseTableNames.call(this._db) } @@ -56,6 +66,50 @@ export class Connection { const tbl = await databaseOpenTable.call(this._db, name) return new Table(tbl, name) } + + async createTable (name: string, data: Array>): Promise { + if (data.length === 0) { + throw new Error('At least one record needs to be provided') + } + + const columns = Object.keys(data[0]) + const records: Record = {} + + for (const columnsKey of columns) { + if (columnsKey === 'vector') { + const children = new Field('item', new Float32()) + const list = new List(children) + const listBuilder = makeBuilder({ + type: list + }) + const vectorSize = (data[0].vector as any[]).length + for (const datum of data) { + if ((datum[columnsKey] as any[]).length !== vectorSize) { + throw new Error(`Invalid vector size, expected ${vectorSize}`) + } + + listBuilder.append(datum[columnsKey]) + } + records[columnsKey] = listBuilder.finish().toVector() + } else { + const values = [] + for (const datum of data) { + values.push(datum[columnsKey]) + } + records[columnsKey] = vectorFromArray(values) + } + } + + const table = new ArrowTable(records) + await this.createTableArrow(name, table) + return await this.openTable(name) + } + + async createTableArrow (name: string, table: ArrowTable): Promise
diff --git a/node/src/test/test.ts b/node/src/test/test.ts
index 80153aa6..c7e61e5d 100644
--- a/node/src/test/test.ts
+++ b/node/src/test/test.ts
@@ -14,67 +14,94 @@
 
 import { describe } from 'mocha'
 import { assert } from 'chai'
+import { track } from 'temp'
 
 import * as lancedb from '../index'
 
 describe('LanceDB client', function () {
-  describe('open a connection to lancedb', function () {
-    const con = lancedb.connect('.../../sample-lancedb')
-
-    it('should have a valid url', function () {
-      assert.equal(con.uri, '.../../sample-lancedb')
+  describe('when creating a connection to lancedb', function () {
+    it('should have a valid url', async function () {
+      const uri = await createTestDB()
+      const con = await lancedb.connect(uri)
+      assert.equal(con.uri, uri)
     })
 
-    it('should return the existing table names', function () {
-      assert.deepEqual(con.tableNames(), ['my_table'])
+    it('should return the existing table names', async function () {
+      const uri = await createTestDB()
+      const con = await lancedb.connect(uri)
+      assert.deepEqual(await con.tableNames(), ['vectors'])
+    })
+  })
+
+  describe('when querying an existing dataset', function () {
+    it('should open a table', async function () {
+      const uri = await createTestDB()
+      const con = await lancedb.connect(uri)
+      const table = await con.openTable('vectors')
+      assert.equal(table.name, 'vectors')
     })
 
-    describe('open a table from a connection', function () {
-      const tablePromise = con.openTable('my_table')
+    it('execute a query', async function () {
+      const uri = await createTestDB()
+      const con = await lancedb.connect(uri)
+      const table = await con.openTable('vectors')
+      const results = await table.search([0.1, 0.3]).execute()
 
-      it('should have a valid name', async function () {
-        const table = await tablePromise
-        assert.equal(table.name, 'my_table')
-      })
+      assert.equal(results.length, 2)
+      assert.equal(results[0].price, 10)
+      const vector = results[0].vector as Float32Array
+      assert.approximately(vector[0], 0.0, 0.2)
+      assert.approximately(vector[1], 0.1, 0.3)
+    })
 
-      class MyResult {
-        vector: Float32Array = new Float32Array(0)
-        price: number = 0
-        item: string = ''
-      }
+    it('limits # of results', async function () {
+      const uri = await createTestDB()
+      const con = await lancedb.connect(uri)
+      const table = await con.openTable('vectors')
+      const results = await table.search([0.1, 0.3]).limit(1).execute()
+      assert.equal(results.length, 1)
+      assert.equal(results[0].id, 1)
+    })
 
-      it('execute a query', async function () {
-        const table = await tablePromise
-        const builder = table.search([0.1, 0.3])
-        const results = await builder.execute() as MyResult[]
+    it('uses a filter', async function () {
+      const uri = await createTestDB()
+      const con = await lancedb.connect(uri)
+      const table = await con.openTable('vectors')
+      const results = await table.search([0.1, 0.3]).filter('id == 2').execute()
+      assert.equal(results.length, 1)
+      assert.equal(results[0].id, 2)
+    })
+  })
 
-        assert.equal(results.length, 2)
-        assert.equal(results[0].item, 'foo')
-        assert.equal(results[0].price, 10)
-        assert.approximately(results[0].vector[0], 3.1, 0.1)
-        assert.approximately(results[0].vector[1], 4.1, 0.1)
-      })
+  describe('when creating a new dataset', function () {
+    it('creates a new table from javascript objects', async function () {
+      const dir = await track().mkdir('lancejs')
+      const con = await lancedb.connect(dir)
 
-      it('execute a query and type cast the result', async function () {
-        const table = await tablePromise
+      const data = [
+        { id: 1, vector: [0.1, 0.2], price: 10 },
+        { id: 2, vector: [1.1, 1.2], price: 50 }
+      ]
 
-        const builder = table.search([0.1, 0.3])
-        const results = await builder.execute_cast<MyResult>()
-        assert.equal(results.length, 2)
-        assert.equal(results[0].item, 'foo')
-        assert.equal(results[0].price, 10)
-        assert.approximately(results[0].vector[0], 3.1, 0.1)
-        assert.approximately(results[0].vector[1], 4.1, 0.1)
-      })
+      const tableName = `vectors_${Math.floor(Math.random() * 100)}`
+      const table = await con.createTable(tableName, data)
+      assert.equal(table.name, tableName)
 
-      it('limits # of results', async function () {
-        const table = await tablePromise
-        const builder = table.search([0.1, 0.3])
-        builder.limit = 1
-        const results = await builder.execute() as MyResult[]
-
-        assert.equal(results.length, 1)
-      })
+      const results = await table.search([0.1, 0.3]).execute()
+      assert.equal(results.length, 2)
     })
   })
 })
+
+async function createTestDB (): Promise<string> {
+  const dir = await track().mkdir('lancejs')
+  const con = await lancedb.connect(dir)
+
+  const data = [
+    { id: 1, vector: [0.1, 0.2], name: 'foo', price: 10, is_active: true },
+    { id: 2, vector: [1.1, 1.2], name: 'bar', price: 50, is_active: false }
+  ]
+
+  await con.createTable('vectors', data)
+  return dir
+}
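The typed `execute_cast` test above was deleted because the generic now lives on `execute` itself. A minimal sketch of the equivalent typed access, with the result shape adapted from the removed test to the new fixture columns:

```typescript
import * as lancedb from 'vectordb'

interface MyResult {
  vector: Float32Array
  price: number
  id: number
}

async function typedSearch (): Promise<void> {
  const db = await lancedb.connect('data/sample-lancedb')
  const table = await db.openTable('vectors')
  // execute<T>() replaces the removed execute_cast<T>()
  const results = await table.search([0.1, 0.3]).limit(10).execute<MyResult>()
  console.log(results[0].price)
}
```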
diff --git a/rust/ffi/node/Cargo.toml b/rust/ffi/node/Cargo.toml
index cef23ae2..72e326f5 100644
--- a/rust/ffi/node/Cargo.toml
+++ b/rust/ffi/node/Cargo.toml
@@ -12,8 +12,10 @@
 crate-type = ["cdylib"]
 
 [dependencies]
 arrow-array = "37.0"
 arrow-ipc = "37.0"
+arrow-schema = "37.0"
 once_cell = "1"
 futures = "0.3"
+lance = "0.4.3"
 vectordb = { path = "../../vectordb" }
 tokio = { version = "1.23", features = ["rt-multi-thread"] }
 neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
diff --git a/rust/ffi/node/src/arrow.rs b/rust/ffi/node/src/arrow.rs
new file mode 100644
index 00000000..599a354f
--- /dev/null
+++ b/rust/ffi/node/src/arrow.rs
@@ -0,0 +1,47 @@
+// Copyright 2023 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::ops::Deref;
+use std::sync::Arc;
+
+use arrow_array::cast::as_list_array;
+use arrow_array::{Array, FixedSizeListArray, RecordBatch};
+use arrow_schema::{DataType, Field, Schema};
+use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
+
+pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
+    let column = record_batch
+        .column_by_name("vector")
+        .expect("vector column is missing");
+    let arr = as_list_array(column.deref());
+    let list_size = arr.values().len() / record_batch.num_rows();
+    let r = FixedSizeListArray::try_new(arr.values(), list_size as i32).unwrap();
+
+    let schema = Arc::new(Schema::new(vec![Field::new(
+        "vector",
+        DataType::FixedSizeList(
+            Arc::new(Field::new("item", DataType::Float32, true)),
+            list_size as i32,
+        ),
+        true,
+    )]));
+
+    let mut new_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(r)]).unwrap();
+
+    if record_batch.num_columns() > 1 {
+        let rb = record_batch.drop_column("vector").unwrap();
+        new_batch = new_batch.merge(&rb).unwrap();
+    }
+    new_batch
+}
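`convert_record_batch` bridges two layouts: the JS side ships the `vector` column as a variable-length `List<Float32>`, while Lance stores vectors as a `FixedSizeList`. Deriving the list width as values divided by row count is only safe because `createTable` rejects ragged vectors before serialization. A sketch of that contract from the caller's perspective, with illustrative table names:

```typescript
import * as lancedb from 'vectordb'

async function vectorShapes (): Promise<void> {
  const db = await lancedb.connect('data/sample-lancedb')

  // OK: both vectors have length 2, so the native layer can rebuild
  // the column as a FixedSizeList of width 2
  await db.createTable('ok_vectors', [
    { id: 1, vector: [0.1, 0.2] },
    { id: 2, vector: [1.1, 1.2] }
  ])

  try {
    await db.createTable('bad_vectors', [
      { id: 1, vector: [0.1, 0.2] },
      { id: 2, vector: [0.3] }
    ])
  } catch (err) {
    console.error(err) // Error: Invalid vector size, expected 2
  }
}
```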
.as_value(&mut cx); let limit = cx.argument::(1)?.value(&mut cx); + let filter = cx.argument_opt(2).map(|f| f.downcast_or_throw::(&mut cx).unwrap().value(&mut cx)); let rt = runtime(&mut cx)?; let channel = cx.channel(); @@ -101,12 +110,11 @@ fn table_search(mut cx: FunctionContext) -> JsResult { rt.spawn(async move { let builder = table .search(Float32Array::from(query)) - .limit(limit as usize); - let results = builder - .execute() - .await - .unwrap() // FIXME unwrap - .try_collect::>() + .limit(limit as usize) + .filter(filter); + let record_batch_stream = builder.execute(); + let results = record_batch_stream + .and_then(|stream| stream.try_collect::>().map_err(Error::from)) .await; deferred.settle_with(&channel, move |mut cx| { @@ -135,11 +143,46 @@ fn table_search(mut cx: FunctionContext) -> JsResult { Ok(promise) } +fn table_create(mut cx: FunctionContext) -> JsResult { + let db = cx + .this() + .downcast_or_throw::, _>(&mut cx)?; + let table_name = cx.argument::(0)?.value(&mut cx); + let buffer = cx.argument::(1)?; + let slice = buffer.as_slice(&mut cx); + + let mut batches: Vec = Vec::new(); + let fr = FileReader::try_new(Cursor::new(slice), None); + let file_reader = fr.unwrap(); + for b in file_reader { + let record_batch = convert_record_batch(b.unwrap()); + batches.push(record_batch); + } + + let rt = runtime(&mut cx)?; + let channel = cx.channel(); + + let (deferred, promise) = cx.promise(); + let database = db.database.clone(); + + rt.block_on(async move { + let batch_reader: Box = Box::new(RecordBatchBuffer::new(batches)); + let table_rst = database.create_table(table_name, batch_reader).await; + + deferred.settle_with(&channel, move |mut cx| { + let table = Arc::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?); + Ok(cx.boxed(JsTable { table })) + }); + }); + Ok(promise) +} + #[neon::main] fn main(mut cx: ModuleContext) -> NeonResult<()> { cx.export_function("databaseNew", database_new)?; cx.export_function("databaseTableNames", database_table_names)?; cx.export_function("databaseOpenTable", database_open_table)?; cx.export_function("tableSearch", table_search)?; + cx.export_function("tableCreate", table_create)?; Ok(()) } diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/database.rs index 854e10c6..4dd238fa 100644 --- a/rust/vectordb/src/database.rs +++ b/rust/vectordb/src/database.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use arrow_array::RecordBatchReader; use std::fs::create_dir_all; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -73,6 +74,14 @@ impl Database { Ok(f) } + pub async fn create_table( + &self, + name: String, + batches: Box, + ) -> Result
diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs
index 655fe03b..fcbda05f 100644
--- a/rust/vectordb/src/query.rs
+++ b/rust/vectordb/src/query.rs
@@ -26,6 +26,7 @@ pub struct Query {
     pub dataset: Arc<Dataset>,
     pub query_vector: Float32Array,
     pub limit: usize,
+    pub filter: Option<String>,
     pub nprobes: usize,
     pub refine_factor: Option<u32>,
     pub metric_type: MetricType,
@@ -52,6 +53,7 @@
             refine_factor: None,
             metric_type: MetricType::L2,
             use_index: false,
+            filter: None,
         }
     }
 
@@ -71,6 +73,7 @@
         scanner.nprobs(self.nprobes);
         scanner.distance_metric(self.metric_type);
         scanner.use_index(self.use_index);
+        self.filter.as_ref().map(|f| scanner.filter(f));
         self.refine_factor.map(|rf| scanner.refine(rf));
         Ok(scanner.try_into_stream().await?)
     }
@@ -134,6 +137,11 @@
         self.use_index = use_index;
         self
     }
+
+    pub fn filter(mut self, filter: Option<String>) -> Query {
+        self.filter = filter;
+        self
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs
index f8fd4ae0..a8a9d3e6 100644
--- a/rust/vectordb/src/table.rs
+++ b/rust/vectordb/src/table.rs
@@ -15,8 +15,8 @@
 use std::path::PathBuf;
 use std::sync::Arc;
 
-use arrow_array::Float32Array;
-use lance::dataset::Dataset;
+use arrow_array::{Float32Array, RecordBatchReader};
+use lance::dataset::{Dataset, WriteParams};
 
 use crate::error::{Error, Result};
 use crate::query::Query;
@@ -55,6 +55,21 @@
         Ok(table)
     }
 
+    pub async fn create(
+        base_path: Arc<PathBuf>,
+        name: String,
+        mut batches: Box<dyn RecordBatchReader>,
+    ) -> Result<Table> {
+        let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
+        let ds_uri = ds_path
+            .to_str()
+            .ok_or(Error::IO(format!("Unable to find table {}", name)))?;
+
+        let dataset =
+            Arc::new(Dataset::write(&mut batches, ds_uri, Some(WriteParams::default())).await?);
+        Ok(Table { name, dataset })
+    }
+
     /// Creates a new Query object that can be executed.
     ///
     /// # Arguments