mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-03 04:10:41 +00:00
JavaScript / Node.js library for LanceDB
- Core rust library - ffi bridge that exposes rust functionality to javascript - npm package that provides a TypeScript / JavaScript library - limitations: it only supports reading for now
This commit is contained in:
19
rust/ffi/node/Cargo.toml
Normal file
19
rust/ffi/node/Cargo.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.1.0"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
exclude = ["index.node"]
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
arrow-array = "37.0"
|
||||
arrow-ipc = "37.0"
|
||||
once_cell = "1"
|
||||
futures = "0.3"
|
||||
vectordb = { path = "../../vectordb" }
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
|
||||
3
rust/ffi/node/README.md
Normal file
3
rust/ffi/node/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
The LanceDB node bridge (vectordb-node) allows javascript applications to access LanceDB datasets.
|
||||
|
||||
It is build using [Neon](https://neon-bindings.com). See the node project for an example of how it is used / tests
|
||||
36
rust/ffi/node/src/convert.rs
Normal file
36
rust/ffi/node/src/convert.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use neon::prelude::*;
|
||||
|
||||
pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
|
||||
vec: &Vec<String>,
|
||||
cx: &mut C,
|
||||
) -> JsResult<'a, JsArray> {
|
||||
let a = JsArray::new(cx, vec.len() as u32);
|
||||
for (i, s) in vec.iter().enumerate() {
|
||||
let v = cx.string(s);
|
||||
a.set(cx, i as u32, v)?;
|
||||
}
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
|
||||
let mut query_vec: Vec<f32> = Vec::new();
|
||||
for i in 0..array.len(cx) {
|
||||
let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
|
||||
query_vec.push(entry.value(cx) as f32);
|
||||
}
|
||||
query_vec
|
||||
}
|
||||
145
rust/ffi/node/src/lib.rs
Normal file
145
rust/ffi/node/src/lib.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod convert;
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use arrow_ipc::writer::FileWriter;
|
||||
use futures::TryStreamExt;
|
||||
use neon::prelude::*;
|
||||
use once_cell::sync::OnceCell;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use vectordb::database::Database;
|
||||
use vectordb::table::Table;
|
||||
|
||||
struct JsDatabase {
|
||||
database: Arc<Database>,
|
||||
}
|
||||
|
||||
struct JsTable {
|
||||
table: Arc<Table>,
|
||||
}
|
||||
|
||||
impl Finalize for JsDatabase {}
|
||||
|
||||
impl Finalize for JsTable {}
|
||||
|
||||
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
|
||||
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
|
||||
|
||||
RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string())))
|
||||
}
|
||||
|
||||
fn database_new(mut cx: FunctionContext) -> JsResult<JsBox<JsDatabase>> {
|
||||
let path = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let db = JsDatabase {
|
||||
database: Arc::new(Database::connect(path).or_else(|err| cx.throw_error(err.to_string()))?),
|
||||
};
|
||||
Ok(cx.boxed(db))
|
||||
}
|
||||
|
||||
fn database_table_names(mut cx: FunctionContext) -> JsResult<JsArray> {
|
||||
let db = cx
|
||||
.this()
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let tables = db
|
||||
.database
|
||||
.table_names()
|
||||
.or_else(|err| cx.throw_error(err.to_string()))?;
|
||||
convert::vec_str_to_array(&tables, &mut cx)
|
||||
}
|
||||
|
||||
fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let db = cx
|
||||
.this()
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
let database = db.database.clone();
|
||||
|
||||
let (deferred, promise) = cx.promise();
|
||||
rt.spawn(async move {
|
||||
let table_rst = database.open_table(table_name).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let table = Arc::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?);
|
||||
Ok(cx.boxed(JsTable { table }))
|
||||
});
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let query_vector = cx.argument::<JsArray>(0)?; //. .as_value(&mut cx);
|
||||
let limit = cx.argument::<JsNumber>(1)?.value(&mut cx);
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
|
||||
let (deferred, promise) = cx.promise();
|
||||
let table = js_table.table.clone();
|
||||
let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);
|
||||
|
||||
rt.spawn(async move {
|
||||
let builder = table
|
||||
.search(Float32Array::from(query))
|
||||
.limit(limit as usize);
|
||||
let results = builder
|
||||
.execute()
|
||||
.await
|
||||
.unwrap() // FIXME unwrap
|
||||
.try_collect::<Vec<_>>()
|
||||
.await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let results = results.or_else(|err| cx.throw_error(err.to_string()))?;
|
||||
let vector: Vec<u8> = Vec::new();
|
||||
|
||||
if results.is_empty() {
|
||||
return cx.buffer(0);
|
||||
}
|
||||
|
||||
let schema = results.get(0).unwrap().schema();
|
||||
let mut fr = FileWriter::try_new(vector, schema.deref())
|
||||
.or_else(|err| cx.throw_error(err.to_string()))?;
|
||||
|
||||
for batch in results.iter() {
|
||||
fr.write(batch)
|
||||
.or_else(|err| cx.throw_error(err.to_string()))?;
|
||||
}
|
||||
fr.finish().or_else(|err| cx.throw_error(err.to_string()))?;
|
||||
let buf = fr
|
||||
.into_inner()
|
||||
.or_else(|err| cx.throw_error(err.to_string()))?;
|
||||
Ok(JsBuffer::external(&mut cx, buf))
|
||||
});
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
#[neon::main]
|
||||
fn main(mut cx: ModuleContext) -> NeonResult<()> {
|
||||
cx.export_function("databaseNew", database_new)?;
|
||||
cx.export_function("databaseTableNames", database_table_names)?;
|
||||
cx.export_function("databaseOpenTable", database_open_table)?;
|
||||
cx.export_function("tableSearch", table_search)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
pub fn add(left: usize, right: usize) -> usize {
|
||||
left + right
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn it_works() {
|
||||
let result = add(2, 2);
|
||||
assert_eq!(result, 4);
|
||||
}
|
||||
}
|
||||
@@ -9,4 +9,10 @@ repository = "https://github.com/lancedb/lancedb"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
arrow-array = "37.0"
|
||||
arrow-schema = "37.0"
|
||||
lance = "0.4.3"
|
||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.5.0"
|
||||
118
rust/vectordb/src/database.rs
Normal file
118
rust/vectordb/src/database.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fs::create_dir_all;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::table::Table;
|
||||
|
||||
pub struct Database {
|
||||
pub(crate) path: Arc<PathBuf>,
|
||||
}
|
||||
|
||||
const LANCE_EXTENSION: &str = "lance";
|
||||
|
||||
/// A connection to LanceDB
|
||||
impl Database {
|
||||
/// Connects to LanceDB
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - URI where the database is located, can be a local file or a supported remote cloud storage
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Database] object.
|
||||
pub fn connect<P: AsRef<Path>>(path: P) -> Result<Database> {
|
||||
if !path.as_ref().try_exists()? {
|
||||
create_dir_all(&path)?;
|
||||
}
|
||||
Ok(Database {
|
||||
path: Arc::new(path.as_ref().to_path_buf()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the names of all tables in the database.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Vec<String>] with all table names.
|
||||
pub fn table_names(&self) -> Result<Vec<String>> {
|
||||
let f = self
|
||||
.path
|
||||
.read_dir()?
|
||||
.flatten()
|
||||
.map(|dir_entry| dir_entry.path())
|
||||
.filter(|path| {
|
||||
let is_lance = path
|
||||
.extension()
|
||||
.map(|e| e.to_str().map(|e| e == LANCE_EXTENSION))
|
||||
.flatten();
|
||||
is_lance.unwrap_or(false)
|
||||
})
|
||||
.map(|p| {
|
||||
p.file_stem()
|
||||
.map(|s| s.to_str().map(|s| String::from(s)))
|
||||
.flatten()
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
Ok(f)
|
||||
}
|
||||
|
||||
/// Open a table in the database.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `name` - The name of the table.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Table] object.
|
||||
pub async fn open_table(&self, name: String) -> Result<Table> {
|
||||
Table::new(self.path.clone(), name).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs::create_dir_all;
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::database::Database;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connect() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let path_buf = tmp_dir.into_path();
|
||||
let db = Database::connect(&path_buf);
|
||||
|
||||
assert_eq!(db.unwrap().path.as_path(), path_buf.as_path())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_table_names() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
|
||||
create_dir_all(tmp_dir.path().join("table2.lance")).unwrap();
|
||||
create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();
|
||||
|
||||
let db = Database::connect(&tmp_dir.into_path()).unwrap();
|
||||
let tables = db.table_names().unwrap();
|
||||
assert_eq!(tables.len(), 2);
|
||||
assert!(tables.contains(&String::from("table1")));
|
||||
assert!(tables.contains(&String::from("table2")));
|
||||
}
|
||||
}
|
||||
43
rust/vectordb/src/error.rs
Normal file
43
rust/vectordb/src/error.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
IO(String),
|
||||
Lance(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let (catalog, message) = match self {
|
||||
Self::IO(s) => ("I/O", s.as_str()),
|
||||
Self::Lance(s) => ("Lance", s.as_str()),
|
||||
};
|
||||
write!(f, "LanceDBError({catalog}): {message}")
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(e: std::io::Error) -> Self {
|
||||
Self::IO(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<lance::Error> for Error {
|
||||
fn from(e: lance::Error) -> Self {
|
||||
Self::Lance(e.to_string())
|
||||
}
|
||||
}
|
||||
18
rust/vectordb/src/lib.rs
Normal file
18
rust/vectordb/src/lib.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod database;
|
||||
pub mod error;
|
||||
pub mod query;
|
||||
pub mod table;
|
||||
210
rust/vectordb/src/query.rs
Normal file
210
rust/vectordb/src/query.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::Dataset;
|
||||
use lance::index::vector::MetricType;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
/// A builder for nearest neighbor queries for LanceDB.
|
||||
pub struct Query {
|
||||
pub dataset: Arc<Dataset>,
|
||||
pub query_vector: Float32Array,
|
||||
pub limit: usize,
|
||||
pub nprobes: usize,
|
||||
pub refine_factor: Option<u32>,
|
||||
pub metric_type: MetricType,
|
||||
pub use_index: bool,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
/// Creates a new Query object
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataset` - The table / dataset the query will be run against.
|
||||
/// * `vector` The vector used for this query.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Query] object.
|
||||
pub(crate) fn new(dataset: Arc<Dataset>, vector: Float32Array) -> Self {
|
||||
Query {
|
||||
dataset,
|
||||
query_vector: vector,
|
||||
limit: 10,
|
||||
nprobes: 20,
|
||||
refine_factor: None,
|
||||
metric_type: MetricType::L2,
|
||||
use_index: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the queries and return its results.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [DatasetRecordBatchStream] with the query's results.
|
||||
pub async fn execute(&self) -> Result<DatasetRecordBatchStream> {
|
||||
let mut scanner: Scanner = self.dataset.scan();
|
||||
|
||||
scanner.nearest(
|
||||
crate::table::VECTOR_COLUMN_NAME,
|
||||
&self.query_vector,
|
||||
self.limit,
|
||||
)?;
|
||||
scanner.nprobs(self.nprobes);
|
||||
scanner.distance_metric(self.metric_type);
|
||||
scanner.use_index(self.use_index);
|
||||
self.refine_factor.map(|rf| scanner.refine(rf));
|
||||
Ok(scanner.try_into_stream().await?)
|
||||
}
|
||||
|
||||
/// Set the maximum number of results to return.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `limit` - The maximum number of results to return.
|
||||
pub fn limit(mut self, limit: usize) -> Query {
|
||||
self.limit = limit;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the vector used for this query.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vector` - The vector that will be used for search.
|
||||
pub fn query_vector(mut self, query_vector: Float32Array) -> Query {
|
||||
self.query_vector = query_vector;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the number of probes to use.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `nprobes` - The number of probes to use.
|
||||
pub fn nprobes(mut self, nprobes: usize) -> Query {
|
||||
self.nprobes = nprobes;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the refine factor to use.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `refine_factor` - The refine factor to use.
|
||||
pub fn refine_factor(mut self, refine_factor: Option<u32>) -> Query {
|
||||
self.refine_factor = refine_factor;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the distance metric to use.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric_type` - The distance metric to use. By default [MetricType::L2] is used.
|
||||
pub fn metric_type(mut self, metric_type: MetricType) -> Query {
|
||||
self.metric_type = metric_type;
|
||||
self
|
||||
}
|
||||
|
||||
/// Whether to use an ANN index if available
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `use_index` - Sets Whether to use an ANN index if available
|
||||
pub fn use_index(mut self, use_index: bool) -> Query {
|
||||
self.use_index = use_index;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{Float32Array, RecordBatch, RecordBatchReader};
|
||||
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
|
||||
use lance::arrow::RecordBatchBuffer;
|
||||
use lance::dataset::Dataset;
|
||||
use lance::index::vector::MetricType;
|
||||
|
||||
use crate::query::Query;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_setters_getters() {
|
||||
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
|
||||
let ds = Dataset::write(&mut batches, ":memory:", None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1, 0.2]);
|
||||
let query = Query::new(Arc::new(ds), vector.clone());
|
||||
assert_eq!(query.query_vector, vector);
|
||||
|
||||
let new_vector = Float32Array::from_iter_values([9.8, 8.7]);
|
||||
|
||||
let query = query
|
||||
.query_vector(new_vector.clone())
|
||||
.limit(100)
|
||||
.nprobes(1000)
|
||||
.use_index(true)
|
||||
.metric_type(MetricType::Cosine)
|
||||
.refine_factor(Some(999));
|
||||
|
||||
assert_eq!(query.query_vector, new_vector);
|
||||
assert_eq!(query.limit, 100);
|
||||
assert_eq!(query.nprobes, 1000);
|
||||
assert_eq!(query.use_index, true);
|
||||
assert_eq!(query.metric_type, MetricType::Cosine);
|
||||
assert_eq!(query.refine_factor, Some(999));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute() {
|
||||
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
|
||||
let ds = Dataset::write(&mut batches, ":memory:", None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1; 128]);
|
||||
let query = Query::new(Arc::new(ds), vector.clone());
|
||||
let result = query.execute().await;
|
||||
assert_eq!(result.is_ok(), true);
|
||||
}
|
||||
|
||||
fn make_test_batches() -> RecordBatchBuffer {
|
||||
let dim: usize = 128;
|
||||
let schema = Arc::new(ArrowSchema::new(vec![
|
||||
ArrowField::new("key", DataType::Int32, false),
|
||||
ArrowField::new(
|
||||
"vector",
|
||||
DataType::FixedSizeList(
|
||||
Arc::new(ArrowField::new("item", DataType::Float32, true)),
|
||||
dim as i32,
|
||||
),
|
||||
true,
|
||||
),
|
||||
ArrowField::new("uri", DataType::Utf8, true),
|
||||
]));
|
||||
|
||||
RecordBatchBuffer::new(vec![RecordBatch::new_empty(schema.clone())])
|
||||
}
|
||||
}
|
||||
144
rust/vectordb/src/table.rs
Normal file
144
rust/vectordb/src/table.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright 2023 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::Float32Array;
|
||||
use lance::dataset::Dataset;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::query::Query;
|
||||
|
||||
pub const VECTOR_COLUMN_NAME: &str = "vector";
|
||||
|
||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
|
||||
/// A table in a LanceDB database.
|
||||
pub struct Table {
|
||||
name: String,
|
||||
dataset: Arc<Dataset>,
|
||||
}
|
||||
|
||||
impl Table {
|
||||
/// Creates a new Table object
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `base_path` - The base path where the table is located
|
||||
/// * `name` The Table name
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Table] object.
|
||||
pub async fn new(base_path: Arc<PathBuf>, name: String) -> Result<Self> {
|
||||
let ds_path = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
|
||||
let ds_uri = ds_path
|
||||
.to_str()
|
||||
.ok_or(Error::IO(format!("Unable to find table {}", name)))?;
|
||||
let dataset = Dataset::open(ds_uri).await?;
|
||||
let table = Table {
|
||||
name,
|
||||
dataset: Arc::new(dataset),
|
||||
};
|
||||
Ok(table)
|
||||
}
|
||||
|
||||
/// Creates a new Query object that can be executed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vector` The vector used for this query.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Query] object.
|
||||
pub fn search(&self, query_vector: Float32Array) -> Query {
|
||||
Query::new(self.dataset.clone(), query_vector)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow_array::{Float32Array, Int32Array, RecordBatch, RecordBatchReader};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use lance::arrow::RecordBatchBuffer;
|
||||
use lance::dataset::Dataset;
|
||||
use std::sync::Arc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::table::Table;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_table_not_exists() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let path_buf = tmp_dir.into_path();
|
||||
|
||||
let table = Table::new(Arc::new(path_buf), "test".to_string()).await;
|
||||
assert!(table.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let path_buf = tmp_dir.into_path();
|
||||
|
||||
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
|
||||
Dataset::write(
|
||||
&mut batches,
|
||||
path_buf.join("test.lance").to_str().unwrap(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let table = Table::new(Arc::new(path_buf), "test".to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(table.name, "test")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_search() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let path_buf = tmp_dir.into_path();
|
||||
|
||||
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
|
||||
Dataset::write(
|
||||
&mut batches,
|
||||
path_buf.join("test.lance").to_str().unwrap(),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let table = Table::new(Arc::new(path_buf), "test".to_string())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let vector = Float32Array::from_iter_values([0.1, 0.2]);
|
||||
let query = table.search(vector.clone());
|
||||
assert_eq!(vector, query.query_vector);
|
||||
}
|
||||
|
||||
fn make_test_batches() -> RecordBatchBuffer {
|
||||
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
|
||||
RecordBatchBuffer::new(vec![RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![Arc::new(Int32Array::from_iter_values(0..20))],
|
||||
)
|
||||
.unwrap()])
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user