From 306ada5cb8b70c023f5329b5364bd3a3cbc5373a Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 30 May 2023 21:32:17 -0700 Subject: [PATCH] Support S3 and GCS from typescript SDK (#106) --- Cargo.lock | 7 +- node/src/index.ts | 7 +- node/src/test/io.ts | 52 ++++++++++++ rust/ffi/node/Cargo.toml | 2 +- rust/ffi/node/src/index/vector.rs | 2 +- rust/ffi/node/src/lib.rs | 49 +++++++++--- rust/vectordb/Cargo.toml | 4 +- rust/vectordb/src/database.rs | 59 +++++++++----- rust/vectordb/src/error.rs | 12 +++ rust/vectordb/src/table.rs | 128 +++++++++++++++--------------- 10 files changed, 216 insertions(+), 106 deletions(-) create mode 100644 node/src/test/io.ts diff --git a/Cargo.lock b/Cargo.lock index 9e198522..3e4ed902 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1052,6 +1052,7 @@ dependencies = [ "paste", "petgraph", "rand", + "regex", "uuid", ] @@ -1645,9 +1646,9 @@ dependencies = [ [[package]] name = "lance" -version = "0.4.12" +version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc96cf89139af6f439a0e28ccd04ddf81be795b79fda3105b7a8952fadeb778e" +checksum = "86dda8185bd1ffae7b910c1f68035af23be9b717c52e9cc4de176cd30b47f772" dependencies = [ "accelerate-src", "arrow", @@ -1684,6 +1685,7 @@ dependencies = [ "rand", "reqwest", "shellexpand", + "snafu", "sqlparser-lance", "tokio", "url", @@ -3362,6 +3364,7 @@ dependencies = [ "arrow-data", "arrow-schema", "lance", + "object_store", "rand", "tempfile", "tokio", diff --git a/node/src/index.ts b/node/src/index.ts index 3462ee88..32889e4f 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -28,7 +28,8 @@ const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSe * @param uri The uri of the database. */ export async function connect (uri: string): Promise { - return new Connection(uri) + const db = await databaseNew(uri) + return new Connection(db, uri) } /** @@ -38,9 +39,9 @@ export class Connection { private readonly _uri: string private readonly _db: any - constructor (uri: string) { + constructor (db: any, uri: string) { this._uri = uri - this._db = databaseNew(uri) + this._db = db } get uri (): string { diff --git a/node/src/test/io.ts b/node/src/test/io.ts new file mode 100644 index 00000000..fb667ba9 --- /dev/null +++ b/node/src/test/io.ts @@ -0,0 +1,52 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IO tests + +import { describe } from 'mocha' +import { assert } from 'chai' + +import * as lancedb from '../index' + +describe('LanceDB S3 client', function () { + if (process.env.TEST_S3_BASE_URL != null) { + const baseUri = process.env.TEST_S3_BASE_URL + it('should have a valid url', async function () { + const uri = `${baseUri}/valid_url` + const table = await createTestDB(uri, 2, 20) + const con = await lancedb.connect(uri) + assert.equal(con.uri, uri) + + const results = await table.search([0.1, 0.3]).limit(5).execute() + assert.equal(results.length, 5) + }) + } else { + describe.skip('Skip S3 test', function () {}) + } +}) + +async function createTestDB (uri: string, numDimensions: number = 2, numRows: number = 2): Promise { + const con = await lancedb.connect(uri) + + const data = [] + for (let i = 0; i < numRows; i++) { + const vector = [] + for (let j = 0; j < numDimensions; j++) { + vector.push(i + (j * 0.1)) + } + data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector }) + } + + return await con.createTable('vectors', data) +} diff --git a/rust/ffi/node/Cargo.toml b/rust/ffi/node/Cargo.toml index 72e326f5..e1edae86 100644 --- a/rust/ffi/node/Cargo.toml +++ b/rust/ffi/node/Cargo.toml @@ -15,7 +15,7 @@ arrow-ipc = "37.0" arrow-schema = "37.0" once_cell = "1" futures = "0.3" -lance = "0.4.3" +lance = "0.4.17" vectordb = { path = "../../vectordb" } tokio = { version = "1.23", features = ["rt-multi-thread"] } neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] } diff --git a/rust/ffi/node/src/index/vector.rs b/rust/ffi/node/src/index/vector.rs index 41dd8dcd..ee55412f 100644 --- a/rust/ffi/node/src/index/vector.rs +++ b/rust/ffi/node/src/index/vector.rs @@ -39,7 +39,7 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult>(cx: &mut C) -> NeonResult<&'static Runtime> { RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string()))) } -fn database_new(mut cx: FunctionContext) -> JsResult> { +fn database_new(mut cx: FunctionContext) -> JsResult { let path = cx.argument::(0)?.value(&mut cx); - let db = JsDatabase { - database: Arc::new(Database::connect(path).or_else(|err| cx.throw_error(err.to_string()))?), - }; - Ok(cx.boxed(db)) + + let rt = runtime(&mut cx)?; + let channel = cx.channel(); + let (deferred, promise) = cx.promise(); + + rt.spawn(async move { + let database = Database::connect(&path).await; + + deferred.settle_with(&channel, move |mut cx| { + let db = JsDatabase { + database: Arc::new(database.or_else(|err| cx.throw_error(err.to_string()))?), + }; + Ok(cx.boxed(db)) + }); + }); + Ok(promise) } -fn database_table_names(mut cx: FunctionContext) -> JsResult { +fn database_table_names(mut cx: FunctionContext) -> JsResult { let db = cx .this() .downcast_or_throw::, _>(&mut cx)?; - let tables = db - .database - .table_names() - .or_else(|err| cx.throw_error(err.to_string()))?; - convert::vec_str_to_array(&tables, &mut cx) + + let rt = runtime(&mut cx)?; + let (deferred, promise) = cx.promise(); + let channel = cx.channel(); + let database = db.database.clone(); + + rt.spawn(async move { + let tables_rst = database.table_names().await; + + deferred.settle_with(&channel, move |mut cx| { + let tables = tables_rst.or_else(|err| cx.throw_error(err.to_string()))?; + let table_names = convert::vec_str_to_array(&tables, &mut cx); + table_names + }); + }); + Ok(promise) } fn database_open_table(mut cx: FunctionContext) -> JsResult { @@ -87,7 +110,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult { let (deferred, promise) = cx.promise(); rt.spawn(async move { - let table_rst = database.open_table(table_name).await; + let table_rst = database.open_table(&table_name).await; deferred.settle_with(&channel, move |mut cx| { let table = Arc::new(Mutex::new( @@ -186,7 +209,7 @@ fn table_create(mut cx: FunctionContext) -> JsResult { rt.block_on(async move { let batch_reader: Box = Box::new(RecordBatchBuffer::new(batches)); - let table_rst = database.create_table(table_name, batch_reader).await; + let table_rst = database.create_table(&table_name, batch_reader).await; deferred.settle_with(&channel, move |mut cx| { let table = Arc::new(Mutex::new( diff --git a/rust/vectordb/Cargo.toml b/rust/vectordb/Cargo.toml index bc066608..60182093 100644 --- a/rust/vectordb/Cargo.toml +++ b/rust/vectordb/Cargo.toml @@ -12,7 +12,9 @@ repository = "https://github.com/lancedb/lancedb" arrow-array = "37.0" arrow-data = "37.0" arrow-schema = "37.0" -lance = "0.4.3" +object_store = "0.5.6" + +lance = "0.4.17" tokio = { version = "1.23", features = ["rt-multi-thread"] } [dev-dependencies] diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/database.rs index de65a991..5d6e2f32 100644 --- a/rust/vectordb/src/database.rs +++ b/rust/vectordb/src/database.rs @@ -12,16 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow_array::RecordBatchReader; use std::fs::create_dir_all; -use std::path::{Path, PathBuf}; -use std::sync::Arc; +use std::path::Path; + +use arrow_array::RecordBatchReader; +use lance::io::object_store::ObjectStore; use crate::error::Result; use crate::table::Table; pub struct Database { - pub(crate) path: Arc, + object_store: ObjectStore, + + pub(crate) uri: String, } const LANCE_EXTENSION: &str = "lance"; @@ -37,12 +40,17 @@ impl Database { /// # Returns /// /// * A [Database] object. - pub fn connect>(path: P) -> Result { - if !path.as_ref().try_exists()? { - create_dir_all(&path)?; + pub async fn connect(uri: &str) -> Result { + let object_store = ObjectStore::new(uri).await?; + if object_store.is_local() { + let path = Path::new(uri); + if !path.try_exists()? { + create_dir_all(&path)?; + } } Ok(Database { - path: Arc::new(path.as_ref().to_path_buf()), + uri: uri.to_string(), + object_store, }) } @@ -51,12 +59,13 @@ impl Database { /// # Returns /// /// * A [Vec] with all table names. - pub fn table_names(&self) -> Result> { + pub async fn table_names(&self) -> Result> { let f = self - .path - .read_dir()? - .flatten() - .map(|dir_entry| dir_entry.path()) + .object_store + .read_dir("/") + .await? + .iter() + .map(|fname| Path::new(fname)) .filter(|path| { let is_lance = path .extension() @@ -76,10 +85,10 @@ impl Database { pub async fn create_table( &self, - name: String, + name: &str, batches: Box, ) -> Result { - Table::create(self.path.clone(), name, batches).await + Table::create(&self.uri, name, batches).await } /// Open a table in the database. @@ -90,8 +99,8 @@ impl Database { /// # Returns /// /// * A [Table] object. - pub async fn open_table(&self, name: String) -> Result
{ - Table::open(self.path.clone(), name).await + pub async fn open_table(&self, name: &str) -> Result
{ + Table::open(&self.uri, name).await } } @@ -105,10 +114,10 @@ mod tests { #[tokio::test] async fn test_connect() { let tmp_dir = tempdir().unwrap(); - let path_buf = tmp_dir.into_path(); - let db = Database::connect(&path_buf); + let uri = tmp_dir.path().to_str().unwrap(); + let db = Database::connect(uri).await.unwrap(); - assert_eq!(db.unwrap().path.as_path(), path_buf.as_path()) + assert_eq!(db.uri, uri); } #[tokio::test] @@ -118,10 +127,16 @@ mod tests { create_dir_all(tmp_dir.path().join("table2.lance")).unwrap(); create_dir_all(tmp_dir.path().join("invalidlance")).unwrap(); - let db = Database::connect(&tmp_dir.into_path()).unwrap(); - let tables = db.table_names().unwrap(); + let uri = tmp_dir.path().to_str().unwrap(); + let db = Database::connect(uri).await.unwrap(); + let tables = db.table_names().await.unwrap(); assert_eq!(tables.len(), 2); assert!(tables.contains(&String::from("table1"))); assert!(tables.contains(&String::from("table2"))); } + + #[tokio::test] + async fn test_connect_s3() { + // let db = Database::connect("s3://bucket/path/to/database").await.unwrap(); + } } diff --git a/rust/vectordb/src/error.rs b/rust/vectordb/src/error.rs index a3480b5d..3ea71146 100644 --- a/rust/vectordb/src/error.rs +++ b/rust/vectordb/src/error.rs @@ -41,3 +41,15 @@ impl From for Error { Self::Lance(e.to_string()) } } + +impl From for Error { + fn from(e: object_store::Error) -> Self { + Self::IO(e.to_string()) + } +} + +impl From for Error { + fn from(e: object_store::path::Error) -> Self { + Self::IO(e.to_string()) + } +} diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs index d980b804..498a5af2 100644 --- a/rust/vectordb/src/table.rs +++ b/rust/vectordb/src/table.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use arrow_array::{Float32Array, RecordBatchReader}; @@ -24,16 +24,21 @@ use crate::index::vector::VectorIndexBuilder; use crate::query::Query; pub const VECTOR_COLUMN_NAME: &str = "vector"; - pub const LANCE_FILE_EXTENSION: &str = "lance"; /// A table in a LanceDB database. pub struct Table { name: String, - path: String, + uri: String, dataset: Arc, } +impl std::fmt::Display for Table { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Table({})", self.name) + } +} + impl Table { /// Opens an existing Table /// @@ -45,18 +50,21 @@ impl Table { /// # Returns /// /// * A [Table] object. - pub async fn open(base_path: Arc, name: String) -> Result { - let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); - let ds_uri = ds_path + pub async fn open(base_uri: &str, name: &str) -> Result { + let path = Path::new(base_uri); + + let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); + let uri = table_uri + .as_path() .to_str() - .ok_or(Error::IO(format!("Unable to find table {}", name)))?; - let dataset = Dataset::open(ds_uri).await?; - let table = Table { - name, - path: ds_uri.to_string(), + .ok_or(Error::IO(format!("Invalid table name: {}", name)))?; + + let dataset = Dataset::open(&uri).await?; + Ok(Table { + name: name.to_string(), + uri: uri.to_string(), dataset: Arc::new(dataset), - }; - Ok(table) + }) } /// Creates a new Table @@ -71,25 +79,28 @@ impl Table { /// /// * A [Table] object. pub async fn create( - base_path: Arc, - name: String, + base_uri: &str, + name: &str, mut batches: Box, ) -> Result { - let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); - let path = ds_path + let base_path = Path::new(base_uri); + let table_uri = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); + let uri = table_uri + .as_path() .to_str() - .ok_or(Error::IO(format!("Unable to find table {}", name)))?; - + .ok_or(Error::IO(format!("Invalid table name: {}", name)))? + .to_string(); let dataset = - Arc::new(Dataset::write(&mut batches, path, Some(WriteParams::default())).await?); + Arc::new(Dataset::write(&mut batches, &uri, Some(WriteParams::default())).await?); Ok(Table { - name, - path: path.to_string(), + name: name.to_string(), + uri, dataset, }) } - pub async fn create_idx(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> { + /// Create index on the table. + pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> { use lance::index::DatasetIndexExt; let dataset = self @@ -125,8 +136,7 @@ impl Table { let mut params = WriteParams::default(); params.mode = write_mode.unwrap_or(WriteMode::Append); - self.dataset = - Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?); + self.dataset = Arc::new(Dataset::write(&mut batches, &self.uri, Some(params)).await?); Ok(batches.count()) } @@ -151,6 +161,8 @@ impl Table { #[cfg(test)] mod tests { + use std::sync::Arc; + use arrow_array::{ Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchReader, }; @@ -161,53 +173,52 @@ mod tests { use lance::index::vector::ivf::IvfBuildParams; use lance::index::vector::pq::PQBuildParams; use rand::Rng; - use std::sync::Arc; use tempfile::tempdir; - use crate::error::Result; + use super::*; use crate::index::vector::IvfPQIndexBuilder; - use crate::table::Table; #[tokio::test] async fn test_new_table_not_exists() { let tmp_dir = tempdir().unwrap(); - let path_buf = tmp_dir.into_path(); + let uri = tmp_dir.path().to_str().unwrap(); - let table = Table::open(Arc::new(path_buf), "test".to_string()).await; + let table = Table::open(&uri, "test").await; assert!(table.is_err()); } #[tokio::test] async fn test_open() { let tmp_dir = tempdir().unwrap(); - let path_buf = tmp_dir.into_path(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = tmp_dir.path().to_str().unwrap(); let mut batches: Box = Box::new(make_test_batches()); - Dataset::write( - &mut batches, - path_buf.join("test.lance").to_str().unwrap(), - None, - ) - .await - .unwrap(); - - let table = Table::open(Arc::new(path_buf), "test".to_string()) + Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None) .await .unwrap(); + let table = Table::open(uri, "test").await.unwrap(); + assert_eq!(table.name, "test") } + #[test] + fn test_object_store_path() { + use std::path::Path as StdPath; + let p = StdPath::new("s3://bucket/path/to/file"); + let c = p.join("subfile"); + assert_eq!(c.to_str().unwrap(), "s3://bucket/path/to/file/subfile"); + } + #[tokio::test] async fn test_add() { let tmp_dir = tempdir().unwrap(); - let path_buf = tmp_dir.into_path(); + let uri = tmp_dir.path().to_str().unwrap(); let batches: Box = Box::new(make_test_batches()); let schema = batches.schema().clone(); - let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches) - .await - .unwrap(); + let mut table = Table::create(&uri, "test", batches).await.unwrap(); assert_eq!(table.count_rows().await.unwrap(), 10); let new_batches: Box = @@ -225,13 +236,11 @@ mod tests { #[tokio::test] async fn test_add_overwrite() { let tmp_dir = tempdir().unwrap(); - let path_buf = tmp_dir.into_path(); + let uri = tmp_dir.path().to_str().unwrap(); let batches: Box = Box::new(make_test_batches()); let schema = batches.schema().clone(); - let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches) - .await - .unwrap(); + let mut table = Table::create(uri, "test", batches).await.unwrap(); assert_eq!(table.count_rows().await.unwrap(), 10); let new_batches: Box = @@ -252,21 +261,16 @@ mod tests { #[tokio::test] async fn test_search() { let tmp_dir = tempdir().unwrap(); - let path_buf = tmp_dir.into_path(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = tmp_dir.path().to_str().unwrap(); let mut batches: Box = Box::new(make_test_batches()); - Dataset::write( - &mut batches, - path_buf.join("test.lance").to_str().unwrap(), - None, - ) - .await - .unwrap(); - - let table = Table::open(Arc::new(path_buf), "test".to_string()) + Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None) .await .unwrap(); + let table = Table::open(uri, "test").await.unwrap(); + let vector = Float32Array::from_iter_values([0.1, 0.2]); let query = table.search(vector.clone()); assert_eq!(vector, query.query_vector); @@ -291,7 +295,7 @@ mod tests { use arrow_array::Float32Array; let tmp_dir = tempdir().unwrap(); - let path_buf = tmp_dir.into_path(); + let uri = tmp_dir.path().to_str().unwrap(); let dimension = 16; let schema = Arc::new(ArrowSchema::new(vec![Field::new( @@ -318,9 +322,7 @@ mod tests { .unwrap()]); let reader: Box = Box::new(batches); - let mut table = Table::create(Arc::new(path_buf), "test".to_string(), reader) - .await - .unwrap(); + let mut table = Table::create(uri, "test", reader).await.unwrap(); let mut i = IvfPQIndexBuilder::new(); @@ -330,7 +332,7 @@ mod tests { .ivf_params(IvfBuildParams::new(256)) .pq_params(PQBuildParams::default()); - table.create_idx(index_builder).await.unwrap(); + table.create_index(index_builder).await.unwrap(); assert_eq!(table.dataset.load_indices().await.unwrap().len(), 1); assert_eq!(table.count_rows().await.unwrap(), 512);