Compare commits

...

3 Commits

Author SHA1 Message Date
Chang She
2b26775ed1 python v0.1.4 2023-05-31 20:11:25 -07:00
Lei Xu
306ada5cb8 Support S3 and GCS from typescript SDK (#106) 2023-05-30 21:32:17 -07:00
gsilvestrin
d3aa8bfbc5 add embedding functions to the nodejs client (#95) 2023-05-26 18:09:20 -07:00
13 changed files with 359 additions and 137 deletions

7
Cargo.lock generated
View File

@@ -1052,6 +1052,7 @@ dependencies = [
"paste", "paste",
"petgraph", "petgraph",
"rand", "rand",
"regex",
"uuid", "uuid",
] ]
@@ -1645,9 +1646,9 @@ dependencies = [
[[package]] [[package]]
name = "lance" name = "lance"
version = "0.4.12" version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc96cf89139af6f439a0e28ccd04ddf81be795b79fda3105b7a8952fadeb778e" checksum = "86dda8185bd1ffae7b910c1f68035af23be9b717c52e9cc4de176cd30b47f772"
dependencies = [ dependencies = [
"accelerate-src", "accelerate-src",
"arrow", "arrow",
@@ -1684,6 +1685,7 @@ dependencies = [
"rand", "rand",
"reqwest", "reqwest",
"shellexpand", "shellexpand",
"snafu",
"sqlparser-lance", "sqlparser-lance",
"tokio", "tokio",
"url", "url",
@@ -3362,6 +3364,7 @@ dependencies = [
"arrow-data", "arrow-data",
"arrow-schema", "arrow-schema",
"lance", "lance",
"object_store",
"rand", "rand",
"tempfile", "tempfile",
"tokio", "tokio",

View File

@@ -15,15 +15,16 @@
import { import {
Field, Field,
Float32, Float32,
List, List, type ListBuilder,
makeBuilder, makeBuilder,
RecordBatchFileWriter, RecordBatchFileWriter,
Table, Utf8, Table, Utf8,
type Vector, type Vector,
vectorFromArray vectorFromArray
} from 'apache-arrow' } from 'apache-arrow'
import { type EmbeddingFunction } from './index'
export function convertToTable (data: Array<Record<string, unknown>>): Table { export function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Table {
if (data.length === 0) { if (data.length === 0) {
throw new Error('At least one record needs to be provided') throw new Error('At least one record needs to be provided')
} }
@@ -33,11 +34,7 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
for (const columnsKey of columns) { for (const columnsKey of columns) {
if (columnsKey === 'vector') { if (columnsKey === 'vector') {
const children = new Field<Float32>('item', new Float32()) const listBuilder = newVectorListBuilder()
const list = new List(children)
const listBuilder = makeBuilder({
type: list
})
const vectorSize = (data[0].vector as any[]).length const vectorSize = (data[0].vector as any[]).length
for (const datum of data) { for (const datum of data) {
if ((datum[columnsKey] as any[]).length !== vectorSize) { if ((datum[columnsKey] as any[]).length !== vectorSize) {
@@ -52,6 +49,14 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
for (const datum of data) { for (const datum of data) {
values.push(datum[columnsKey]) values.push(datum[columnsKey])
} }
if (columnsKey === embeddings?.sourceColumn) {
const vectors = embeddings.embed(values as T[])
const listBuilder = newVectorListBuilder()
vectors.map(v => listBuilder.append(v))
records.vector = listBuilder.finish().toVector()
}
if (typeof values[0] === 'string') { if (typeof values[0] === 'string') {
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column // `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
records[columnsKey] = vectorFromArray(values, new Utf8()) records[columnsKey] = vectorFromArray(values, new Utf8())
@@ -64,8 +69,17 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
return new Table(records) return new Table(records)
} }
export async function fromRecordsToBuffer (data: Array<Record<string, unknown>>): Promise<Buffer> { // Creates a new Arrow ListBuilder that stores a Vector column
const table = convertToTable(data) function newVectorListBuilder (): ListBuilder<Float32, any> {
const children = new Field<Float32>('item', new Float32())
const list = new List(children)
return makeBuilder({
type: list
})
}
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = convertToTable(data, embeddings)
const writer = RecordBatchFileWriter.writeAll(table) const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
} }

View File

@@ -28,7 +28,8 @@ const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSe
* @param uri The uri of the database. * @param uri The uri of the database.
*/ */
export async function connect (uri: string): Promise<Connection> { export async function connect (uri: string): Promise<Connection> {
return new Connection(uri) const db = await databaseNew(uri)
return new Connection(db, uri)
} }
/** /**
@@ -38,9 +39,9 @@ export class Connection {
private readonly _uri: string private readonly _uri: string
private readonly _db: any private readonly _db: any
constructor (uri: string) { constructor (db: any, uri: string) {
this._uri = uri this._uri = uri
this._db = databaseNew(uri) this._db = db
} }
get uri (): string { get uri (): string {
@@ -56,16 +57,49 @@ export class Connection {
/** /**
* Open a table in the database. * Open a table in the database.
*
* @param name The name of the table. * @param name The name of the table.
*/ */
async openTable (name: string): Promise<Table> { async openTable (name: string): Promise<Table>
/**
* Open a table in the database.
*
* @param name The name of the table.
* @param embeddings An embedding function to use on this Table
*/
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await databaseOpenTable.call(this._db, name) const tbl = await databaseOpenTable.call(this._db, name)
if (embeddings !== undefined) {
return new Table(tbl, name, embeddings)
} else {
return new Table(tbl, name) return new Table(tbl, name)
} }
}
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table> { /**
await tableCreate.call(this._db, name, await fromRecordsToBuffer(data)) * Creates a new Table and initialize it with new data.
return await this.openTable(name) *
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
*/
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table>
/**
* Creates a new Table and initialize it with new data.
*
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
* @param embeddings An embedding function to use on this Table
*/
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings))
if (embeddings !== undefined) {
return new Table(tbl, name, embeddings)
} else {
return new Table(tbl, name)
}
} }
async createTableArrow (name: string, table: ArrowTable): Promise<Table> { async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
@@ -75,16 +109,22 @@ export class Connection {
} }
} }
/** export class Table<T = number[]> {
* A table in a LanceDB database.
*/
export class Table {
private readonly _tbl: any private readonly _tbl: any
private readonly _name: string private readonly _name: string
private readonly _embeddings?: EmbeddingFunction<T>
constructor (tbl: any, name: string) { constructor (tbl: any, name: string)
/**
* @param tbl
* @param name
* @param embeddings An embedding function to use when interacting with this table
*/
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl this._tbl = tbl
this._name = name this._name = name
this._embeddings = embeddings
} }
get name (): string { get name (): string {
@@ -92,10 +132,16 @@ export class Table {
} }
/** /**
* Create a search query to find the nearest neighbors of the given query vector. * Creates a search query to find the nearest neighbors of the given search term
* @param queryVector The query vector. * @param query The query search term
*/ */
search (queryVector: number[]): Query { search (query: T): Query {
let queryVector: number[]
if (this._embeddings !== undefined) {
queryVector = this._embeddings.embed([query])[0]
} else {
queryVector = query as number[]
}
return new Query(this._tbl, queryVector) return new Query(this._tbl, queryVector)
} }
@@ -106,7 +152,7 @@ export class Table {
* @return The number of rows added to the table * @return The number of rows added to the table
*/ */
async add (data: Array<Record<string, unknown>>): Promise<number> { async add (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString()) return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
} }
/** /**
@@ -116,9 +162,14 @@ export class Table {
* @return The number of rows added to the table * @return The number of rows added to the table
*/ */
async overwrite (data: Array<Record<string, unknown>>): Promise<number> { async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString()) return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
} }
/**
* Create an ANN index on this Table vector index.
*
* @param indexParams The parameters of this Index, @see VectorIndexParams.
*/
async create_index (indexParams: VectorIndexParams): Promise<any> { async create_index (indexParams: VectorIndexParams): Promise<any> {
return tableCreateVectorIndex.call(this._tbl, indexParams) return tableCreateVectorIndex.call(this._tbl, indexParams)
} }
@@ -268,6 +319,21 @@ export enum WriteMode {
Append = 'append' Append = 'append'
} }
/**
* An embedding function that automatically creates vector representation for a given column.
*/
export interface EmbeddingFunction<T> {
/**
* The name of the column that will be used as input for the Embedding Function.
*/
sourceColumn: string
/**
* Creates a vector representation for the given values.
*/
embed: (data: T[]) => number[][]
}
/** /**
* Distance metrics type. * Distance metrics type.
*/ */

52
node/src/test/io.ts Normal file
View File

@@ -0,0 +1,52 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IO tests
import { describe } from 'mocha'
import { assert } from 'chai'
import * as lancedb from '../index'
describe('LanceDB S3 client', function () {
if (process.env.TEST_S3_BASE_URL != null) {
const baseUri = process.env.TEST_S3_BASE_URL
it('should have a valid url', async function () {
const uri = `${baseUri}/valid_url`
const table = await createTestDB(uri, 2, 20)
const con = await lancedb.connect(uri)
assert.equal(con.uri, uri)
const results = await table.search([0.1, 0.3]).limit(5).execute()
assert.equal(results.length, 5)
})
} else {
describe.skip('Skip S3 test', function () {})
}
})
async function createTestDB (uri: string, numDimensions: number = 2, numRows: number = 2): Promise<lancedb.Table> {
const con = await lancedb.connect(uri)
const data = []
for (let i = 0; i < numRows; i++) {
const vector = []
for (let j = 0; j < numDimensions; j++) {
vector.push(i + (j * 0.1))
}
data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector })
}
return await con.createTable('vectors', data)
}

View File

@@ -17,7 +17,7 @@ import { assert } from 'chai'
import { track } from 'temp' import { track } from 'temp'
import * as lancedb from '../index' import * as lancedb from '../index'
import { MetricType, Query } from '../index' import { type EmbeddingFunction, MetricType, Query } from '../index'
describe('LanceDB client', function () { describe('LanceDB client', function () {
describe('when creating a connection to lancedb', function () { describe('when creating a connection to lancedb', function () {
@@ -140,6 +140,39 @@ describe('LanceDB client', function () {
await table.create_index({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2 }) await table.create_index({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2 })
}).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow }).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow
}) })
describe('when using a custom embedding function', function () {
class TextEmbedding implements EmbeddingFunction<string> {
sourceColumn: string
constructor (targetColumn: string) {
this.sourceColumn = targetColumn
}
_embedding_map = new Map<string, number[]>([
['foo', [2.1, 2.2]],
['bar', [3.1, 3.2]]
])
embed (data: string[]): number[][] {
return data.map(datum => this._embedding_map.get(datum) ?? [0.0, 0.0])
}
}
it('should encode the original data into embeddings', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const embeddings = new TextEmbedding('name')
const data = [
{ price: 10, name: 'foo' },
{ price: 50, name: 'bar' }
]
const table = await con.createTable('vectors', data, embeddings)
const results = await table.search('foo').execute()
assert.equal(results.length, 2)
})
})
}) })
describe('Query object', function () { describe('Query object', function () {

View File

@@ -1,7 +1,7 @@
[project] [project]
name = "lancedb" name = "lancedb"
version = "0.1.3" version = "0.1.4"
dependencies = ["pylance>=0.4.15", "ratelimiter", "retry", "tqdm"] dependencies = ["pylance>=0.4.17", "ratelimiter", "retry", "tqdm"]
description = "lancedb" description = "lancedb"
authors = [ authors = [
{ name = "LanceDB Devs", email = "dev@lancedb.com" }, { name = "LanceDB Devs", email = "dev@lancedb.com" },

View File

@@ -15,7 +15,7 @@ arrow-ipc = "37.0"
arrow-schema = "37.0" arrow-schema = "37.0"
once_cell = "1" once_cell = "1"
futures = "0.3" futures = "0.3"
lance = "0.4.3" lance = "0.4.17"
vectordb = { path = "../../vectordb" } vectordb = { path = "../../vectordb" }
tokio = { version = "1.23", features = ["rt-multi-thread"] } tokio = { version = "1.23", features = ["rt-multi-thread"] }
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] } neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }

View File

@@ -39,7 +39,7 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsP
let add_result = table let add_result = table
.lock() .lock()
.unwrap() .unwrap()
.create_idx(&index_params_builder) .create_index(&index_params_builder)
.await; .await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {

View File

@@ -56,23 +56,46 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string()))) RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string())))
} }
fn database_new(mut cx: FunctionContext) -> JsResult<JsBox<JsDatabase>> { fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
let path = cx.argument::<JsString>(0)?.value(&mut cx); let path = cx.argument::<JsString>(0)?.value(&mut cx);
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
rt.spawn(async move {
let database = Database::connect(&path).await;
deferred.settle_with(&channel, move |mut cx| {
let db = JsDatabase { let db = JsDatabase {
database: Arc::new(Database::connect(path).or_else(|err| cx.throw_error(err.to_string()))?), database: Arc::new(database.or_else(|err| cx.throw_error(err.to_string()))?),
}; };
Ok(cx.boxed(db)) Ok(cx.boxed(db))
});
});
Ok(promise)
} }
fn database_table_names(mut cx: FunctionContext) -> JsResult<JsArray> { fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
let db = cx let db = cx
.this() .this()
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?; .downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let tables = db
.database let rt = runtime(&mut cx)?;
.table_names() let (deferred, promise) = cx.promise();
.or_else(|err| cx.throw_error(err.to_string()))?; let channel = cx.channel();
convert::vec_str_to_array(&tables, &mut cx) let database = db.database.clone();
rt.spawn(async move {
let tables_rst = database.table_names().await;
deferred.settle_with(&channel, move |mut cx| {
let tables = tables_rst.or_else(|err| cx.throw_error(err.to_string()))?;
let table_names = convert::vec_str_to_array(&tables, &mut cx);
table_names
});
});
Ok(promise)
} }
fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> { fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
@@ -87,7 +110,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
rt.spawn(async move { rt.spawn(async move {
let table_rst = database.open_table(table_name).await; let table_rst = database.open_table(&table_name).await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new( let table = Arc::new(Mutex::new(
@@ -186,7 +209,7 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
rt.block_on(async move { rt.block_on(async move {
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(batches)); let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(batches));
let table_rst = database.create_table(table_name, batch_reader).await; let table_rst = database.create_table(&table_name, batch_reader).await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new( let table = Arc::new(Mutex::new(

View File

@@ -12,7 +12,9 @@ repository = "https://github.com/lancedb/lancedb"
arrow-array = "37.0" arrow-array = "37.0"
arrow-data = "37.0" arrow-data = "37.0"
arrow-schema = "37.0" arrow-schema = "37.0"
lance = "0.4.3" object_store = "0.5.6"
lance = "0.4.17"
tokio = { version = "1.23", features = ["rt-multi-thread"] } tokio = { version = "1.23", features = ["rt-multi-thread"] }
[dev-dependencies] [dev-dependencies]

View File

@@ -12,16 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use arrow_array::RecordBatchReader;
use std::fs::create_dir_all; use std::fs::create_dir_all;
use std::path::{Path, PathBuf}; use std::path::Path;
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use lance::io::object_store::ObjectStore;
use crate::error::Result; use crate::error::Result;
use crate::table::Table; use crate::table::Table;
pub struct Database { pub struct Database {
pub(crate) path: Arc<PathBuf>, object_store: ObjectStore,
pub(crate) uri: String,
} }
const LANCE_EXTENSION: &str = "lance"; const LANCE_EXTENSION: &str = "lance";
@@ -37,12 +40,17 @@ impl Database {
/// # Returns /// # Returns
/// ///
/// * A [Database] object. /// * A [Database] object.
pub fn connect<P: AsRef<Path>>(path: P) -> Result<Database> { pub async fn connect(uri: &str) -> Result<Database> {
if !path.as_ref().try_exists()? { let object_store = ObjectStore::new(uri).await?;
if object_store.is_local() {
let path = Path::new(uri);
if !path.try_exists()? {
create_dir_all(&path)?; create_dir_all(&path)?;
} }
}
Ok(Database { Ok(Database {
path: Arc::new(path.as_ref().to_path_buf()), uri: uri.to_string(),
object_store,
}) })
} }
@@ -51,12 +59,13 @@ impl Database {
/// # Returns /// # Returns
/// ///
/// * A [Vec<String>] with all table names. /// * A [Vec<String>] with all table names.
pub fn table_names(&self) -> Result<Vec<String>> { pub async fn table_names(&self) -> Result<Vec<String>> {
let f = self let f = self
.path .object_store
.read_dir()? .read_dir("/")
.flatten() .await?
.map(|dir_entry| dir_entry.path()) .iter()
.map(|fname| Path::new(fname))
.filter(|path| { .filter(|path| {
let is_lance = path let is_lance = path
.extension() .extension()
@@ -76,10 +85,10 @@ impl Database {
pub async fn create_table( pub async fn create_table(
&self, &self,
name: String, name: &str,
batches: Box<dyn RecordBatchReader>, batches: Box<dyn RecordBatchReader>,
) -> Result<Table> { ) -> Result<Table> {
Table::create(self.path.clone(), name, batches).await Table::create(&self.uri, name, batches).await
} }
/// Open a table in the database. /// Open a table in the database.
@@ -90,8 +99,8 @@ impl Database {
/// # Returns /// # Returns
/// ///
/// * A [Table] object. /// * A [Table] object.
pub async fn open_table(&self, name: String) -> Result<Table> { pub async fn open_table(&self, name: &str) -> Result<Table> {
Table::open(self.path.clone(), name).await Table::open(&self.uri, name).await
} }
} }
@@ -105,10 +114,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_connect() { async fn test_connect() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path(); let uri = tmp_dir.path().to_str().unwrap();
let db = Database::connect(&path_buf); let db = Database::connect(uri).await.unwrap();
assert_eq!(db.unwrap().path.as_path(), path_buf.as_path()) assert_eq!(db.uri, uri);
} }
#[tokio::test] #[tokio::test]
@@ -118,10 +127,16 @@ mod tests {
create_dir_all(tmp_dir.path().join("table2.lance")).unwrap(); create_dir_all(tmp_dir.path().join("table2.lance")).unwrap();
create_dir_all(tmp_dir.path().join("invalidlance")).unwrap(); create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();
let db = Database::connect(&tmp_dir.into_path()).unwrap(); let uri = tmp_dir.path().to_str().unwrap();
let tables = db.table_names().unwrap(); let db = Database::connect(uri).await.unwrap();
let tables = db.table_names().await.unwrap();
assert_eq!(tables.len(), 2); assert_eq!(tables.len(), 2);
assert!(tables.contains(&String::from("table1"))); assert!(tables.contains(&String::from("table1")));
assert!(tables.contains(&String::from("table2"))); assert!(tables.contains(&String::from("table2")));
} }
#[tokio::test]
async fn test_connect_s3() {
// let db = Database::connect("s3://bucket/path/to/database").await.unwrap();
}
} }

View File

@@ -41,3 +41,15 @@ impl From<lance::Error> for Error {
Self::Lance(e.to_string()) Self::Lance(e.to_string())
} }
} }
impl From<object_store::Error> for Error {
fn from(e: object_store::Error) -> Self {
Self::IO(e.to_string())
}
}
impl From<object_store::path::Error> for Error {
fn from(e: object_store::path::Error) -> Self {
Self::IO(e.to_string())
}
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::path::PathBuf; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use arrow_array::{Float32Array, RecordBatchReader}; use arrow_array::{Float32Array, RecordBatchReader};
@@ -24,16 +24,21 @@ use crate::index::vector::VectorIndexBuilder;
use crate::query::Query; use crate::query::Query;
pub const VECTOR_COLUMN_NAME: &str = "vector"; pub const VECTOR_COLUMN_NAME: &str = "vector";
pub const LANCE_FILE_EXTENSION: &str = "lance"; pub const LANCE_FILE_EXTENSION: &str = "lance";
/// A table in a LanceDB database. /// A table in a LanceDB database.
pub struct Table { pub struct Table {
name: String, name: String,
path: String, uri: String,
dataset: Arc<Dataset>, dataset: Arc<Dataset>,
} }
impl std::fmt::Display for Table {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Table({})", self.name)
}
}
impl Table { impl Table {
/// Opens an existing Table /// Opens an existing Table
/// ///
@@ -45,18 +50,21 @@ impl Table {
/// # Returns /// # Returns
/// ///
/// * A [Table] object. /// * A [Table] object.
pub async fn open(base_path: Arc<PathBuf>, name: String) -> Result<Self> { pub async fn open(base_uri: &str, name: &str) -> Result<Self> {
let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); let path = Path::new(base_uri);
let ds_uri = ds_path
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let uri = table_uri
.as_path()
.to_str() .to_str()
.ok_or(Error::IO(format!("Unable to find table {}", name)))?; .ok_or(Error::IO(format!("Invalid table name: {}", name)))?;
let dataset = Dataset::open(ds_uri).await?;
let table = Table { let dataset = Dataset::open(&uri).await?;
name, Ok(Table {
path: ds_uri.to_string(), name: name.to_string(),
uri: uri.to_string(),
dataset: Arc::new(dataset), dataset: Arc::new(dataset),
}; })
Ok(table)
} }
/// Creates a new Table /// Creates a new Table
@@ -71,25 +79,28 @@ impl Table {
/// ///
/// * A [Table] object. /// * A [Table] object.
pub async fn create( pub async fn create(
base_path: Arc<PathBuf>, base_uri: &str,
name: String, name: &str,
mut batches: Box<dyn RecordBatchReader>, mut batches: Box<dyn RecordBatchReader>,
) -> Result<Self> { ) -> Result<Self> {
let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); let base_path = Path::new(base_uri);
let path = ds_path let table_uri = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let uri = table_uri
.as_path()
.to_str() .to_str()
.ok_or(Error::IO(format!("Unable to find table {}", name)))?; .ok_or(Error::IO(format!("Invalid table name: {}", name)))?
.to_string();
let dataset = let dataset =
Arc::new(Dataset::write(&mut batches, path, Some(WriteParams::default())).await?); Arc::new(Dataset::write(&mut batches, &uri, Some(WriteParams::default())).await?);
Ok(Table { Ok(Table {
name, name: name.to_string(),
path: path.to_string(), uri,
dataset, dataset,
}) })
} }
pub async fn create_idx(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> { /// Create index on the table.
pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> {
use lance::index::DatasetIndexExt; use lance::index::DatasetIndexExt;
let dataset = self let dataset = self
@@ -125,8 +136,7 @@ impl Table {
let mut params = WriteParams::default(); let mut params = WriteParams::default();
params.mode = write_mode.unwrap_or(WriteMode::Append); params.mode = write_mode.unwrap_or(WriteMode::Append);
self.dataset = self.dataset = Arc::new(Dataset::write(&mut batches, &self.uri, Some(params)).await?);
Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
Ok(batches.count()) Ok(batches.count())
} }
@@ -151,6 +161,8 @@ impl Table {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc;
use arrow_array::{ use arrow_array::{
Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchReader, Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchReader,
}; };
@@ -161,53 +173,52 @@ mod tests {
use lance::index::vector::ivf::IvfBuildParams; use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams; use lance::index::vector::pq::PQBuildParams;
use rand::Rng; use rand::Rng;
use std::sync::Arc;
use tempfile::tempdir; use tempfile::tempdir;
use crate::error::Result; use super::*;
use crate::index::vector::IvfPQIndexBuilder; use crate::index::vector::IvfPQIndexBuilder;
use crate::table::Table;
#[tokio::test] #[tokio::test]
async fn test_new_table_not_exists() { async fn test_new_table_not_exists() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path(); let uri = tmp_dir.path().to_str().unwrap();
let table = Table::open(Arc::new(path_buf), "test".to_string()).await; let table = Table::open(&uri, "test").await;
assert!(table.is_err()); assert!(table.is_err());
} }
#[tokio::test] #[tokio::test]
async fn test_open() { async fn test_open() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path(); let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches()); let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
Dataset::write( Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
&mut batches,
path_buf.join("test.lance").to_str().unwrap(),
None,
)
.await .await
.unwrap(); .unwrap();
let table = Table::open(Arc::new(path_buf), "test".to_string()) let table = Table::open(uri, "test").await.unwrap();
.await
.unwrap();
assert_eq!(table.name, "test") assert_eq!(table.name, "test")
} }
#[test]
fn test_object_store_path() {
use std::path::Path as StdPath;
let p = StdPath::new("s3://bucket/path/to/file");
let c = p.join("subfile");
assert_eq!(c.to_str().unwrap(), "s3://bucket/path/to/file/subfile");
}
#[tokio::test] #[tokio::test]
async fn test_add() { async fn test_add() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path(); let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches()); let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone(); let schema = batches.schema().clone();
let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches) let mut table = Table::create(&uri, "test", batches).await.unwrap();
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10); assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> = let new_batches: Box<dyn RecordBatchReader> =
@@ -225,13 +236,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_add_overwrite() { async fn test_add_overwrite() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path(); let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches()); let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone(); let schema = batches.schema().clone();
let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches) let mut table = Table::create(uri, "test", batches).await.unwrap();
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10); assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> = let new_batches: Box<dyn RecordBatchReader> =
@@ -252,20 +261,15 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_search() { async fn test_search() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path(); let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches()); let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
Dataset::write( Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
&mut batches,
path_buf.join("test.lance").to_str().unwrap(),
None,
)
.await .await
.unwrap(); .unwrap();
let table = Table::open(Arc::new(path_buf), "test".to_string()) let table = Table::open(uri, "test").await.unwrap();
.await
.unwrap();
let vector = Float32Array::from_iter_values([0.1, 0.2]); let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = table.search(vector.clone()); let query = table.search(vector.clone());
@@ -291,7 +295,7 @@ mod tests {
use arrow_array::Float32Array; use arrow_array::Float32Array;
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
let path_buf = tmp_dir.into_path(); let uri = tmp_dir.path().to_str().unwrap();
let dimension = 16; let dimension = 16;
let schema = Arc::new(ArrowSchema::new(vec![Field::new( let schema = Arc::new(ArrowSchema::new(vec![Field::new(
@@ -318,9 +322,7 @@ mod tests {
.unwrap()]); .unwrap()]);
let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches); let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches);
let mut table = Table::create(Arc::new(path_buf), "test".to_string(), reader) let mut table = Table::create(uri, "test", reader).await.unwrap();
.await
.unwrap();
let mut i = IvfPQIndexBuilder::new(); let mut i = IvfPQIndexBuilder::new();
@@ -330,7 +332,7 @@ mod tests {
.ivf_params(IvfBuildParams::new(256)) .ivf_params(IvfBuildParams::new(256))
.pq_params(PQBuildParams::default()); .pq_params(PQBuildParams::default());
table.create_idx(index_builder).await.unwrap(); table.create_index(index_builder).await.unwrap();
assert_eq!(table.dataset.load_indices().await.unwrap().len(), 1); assert_eq!(table.dataset.load_indices().await.unwrap().len(), 1);
assert_eq!(table.count_rows().await.unwrap(), 512); assert_eq!(table.count_rows().await.unwrap(), 512);