Compare commits

...

5 Commits

Author SHA1 Message Date
Lance Release
677b7c1fcc [python] Bump version: 0.5.6 → 0.5.7 2024-02-22 20:07:12 +00:00
Lei Xu
8303a7197b chore: bump pylance to 0.9.18 (#1011) 2024-02-22 11:47:36 -08:00
Raghav Dixit
5fa9bfc4a8 python(feat): Imagebind embedding fn support (#1003)
Added imagebind fn support; steps to install are mentioned in the docstring.
pytest slow checks done locally.
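
A minimal usage sketch of the new embedding function, assuming the registry API exercised by the test added in this change (the database path and table name are illustrative):

    import lancedb
    from lancedb.embeddings import get_registry
    from lancedb.pydantic import LanceModel, Vector

    # "imagebind" is the registry key added by this change.
    func = get_registry().get("imagebind").create()

    class Images(LanceModel):
        label: str
        image_uri: str = func.SourceField()               # image or audio file paths
        vector: Vector(func.ndims()) = func.VectorField()

    db = lancedb.connect("/tmp/imagebind-demo")           # illustrative path
    table = db.create_table("images", schema=Images)
    # table.add(...) with labels and file paths, then e.g.
    # table.search("man's best friend", vector_column_name="vector").limit(1)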

---------

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
2024-02-22 11:47:08 +05:30
Ayush Chaurasia
bf2e9d0088 Docs: add meta tags (#1006) 2024-02-21 23:22:47 +05:30
Weston Pace
f04590ddad refactor: rust vectordb API stabilization of the Connection trait (#993)
This is the start of a more comprehensive refactor and stabilization of
the Rust API. The `Connection` trait is cleaned up to not require
`lance` and to match the `Connection` trait in other APIs. In addition,
the concrete implementation `Database` is hidden.

BREAKING CHANGE: The struct `crate::connection::Database` is now gone.
Several examples opened a connection using `Database::connect` or
`Database::connect_with_params`. Users should now use
`vectordb::connect`.

BREAKING CHANGE: The `connect`, `create_table`, and `open_table` methods
now all return a builder object. This means that a call like
`conn.open_table(..., opt1, opt2)` will now become
`conn.open_table(...).opt1(opt1).opt2(opt2).execute()`. In addition, the
structure of options has changed slightly, but no options capability has
been removed.
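
A minimal before/after sketch of the migration, using only calls that appear in the diffs below (the URI and table name are placeholders):

    // Old API (removed):
    // let db = Database::connect("data/sample-lancedb").await?;
    // let tbl = db.open_table_with_params("my_table", ReadParams::default()).await?;

    // New builder-style API:
    let db = vectordb::connect("data/sample-lancedb").execute().await?;
    let tbl = db
        .open_table("my_table")
        .lance_read_params(ReadParams::default())
        .execute()
        .await?;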

---------

Co-authored-by: Will Jones <willjones127@gmail.com>
2024-02-20 18:35:52 -08:00
20 changed files with 923 additions and 314 deletions

View File

@@ -57,6 +57,16 @@ plugins:
- https://arrow.apache.org/docs/objects.inv
- https://pandas.pydata.org/docs/objects.inv
- mkdocs-jupyter
- ultralytics:
verbose: True
enabled: True
default_image: "assets/lancedb_and_lance.png" # Default image for all pages
add_image: True # Automatically add meta image
add_keywords: True # Add page keywords in the header tag
add_share_buttons: True # Add social share buttons
add_authors: False # Display page authors
add_desc: False
add_dates: False
markdown_extensions:
- admonition
@@ -206,7 +216,6 @@ extra_css:
extra_javascript:
- "extra_js/init_ask_ai_widget.js"
- "extra_js/meta_tag.js"
extra:
analytics:

View File

@@ -2,4 +2,5 @@ mkdocs==1.5.3
mkdocs-jupyter==0.24.1
mkdocs-material==9.5.3
mkdocstrings[python]==0.20.0
pydantic
pydantic
mkdocs-ultralytics-plugin==0.0.44

View File

@@ -1,6 +0,0 @@
window.addEventListener('load', function() {
var meta = document.createElement('meta');
meta.setAttribute('property', 'og:image');
meta.setAttribute('content', '/assets/lancedb_and_lance.png');
document.head.appendChild(meta);
});

View File

@@ -12,18 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use napi::bindgen_prelude::*;
use napi_derive::*;
use crate::table::Table;
use vectordb::connection::{Connection as LanceDBConnection, Database};
use vectordb::connection::Connection as LanceDBConnection;
use vectordb::ipc::ipc_file_to_batches;
#[napi]
pub struct Connection {
conn: Arc<dyn LanceDBConnection>,
conn: LanceDBConnection,
}
#[napi]
@@ -32,9 +30,9 @@ impl Connection {
#[napi(factory)]
pub async fn new(uri: String) -> napi::Result<Self> {
Ok(Self {
conn: Arc::new(Database::connect(&uri).await.map_err(|e| {
conn: vectordb::connect(&uri).execute().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to connect to database: {}", e))
})?),
})?,
})
}
@@ -59,7 +57,8 @@ impl Connection {
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let tbl = self
.conn
.create_table(&name, Box::new(batches), None)
.create_table(&name, Box::new(batches))
.execute()
.await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
Ok(Table::new(tbl))
@@ -70,6 +69,7 @@ impl Connection {
let tbl = self
.conn
.open_table(&name)
.execute()
.await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
Ok(Table::new(tbl))

View File

@@ -15,6 +15,7 @@
use arrow_ipc::writer::FileWriter;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use vectordb::table::AddDataOptions;
use vectordb::{ipc::ipc_file_to_batches, table::TableRef};
use crate::index::IndexBuilder;
@@ -48,12 +49,15 @@ impl Table {
pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
self.table.add(Box::new(batches), None).await.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to add batches to table {}: {}",
self.table, e
))
})
self.table
.add(Box::new(batches), AddDataOptions::default())
.await
.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to add batches to table {}: {}",
self.table, e
))
})
}
#[napi]

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.5.6
current_version = 0.5.7
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -0,0 +1,172 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import List, Union
import numpy as np
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import AUDIO, IMAGES, TEXT
@register("imagebind")
class ImageBindEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the ImageBind API
for generating multi-modal embeddings across six different
modalities: images, text, audio, depth, thermal, and IMU data.
To download the package, run:
`pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
"""
name: str = "imagebind_huge"
device: str = "cpu"
normalize: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._ndims = 1024
self._audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".aac")
self._image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")
@cached_property
def embedding_model(self):
"""
Get the embedding model. This is cached so that the model is only loaded
once per process.
"""
return self.get_embedding_model()
@cached_property
def _data(self):
"""
Get the data module from imagebind
"""
data = attempt_import_or_raise("imagebind.data", "imagebind")
return data
@cached_property
def _ModalityType(self):
"""
Get the ModalityType from imagebind
"""
imagebind = attempt_import_or_raise("imagebind", "imagebind")
return imagebind.imagebind_model.ModalityType
def ndims(self):
return self._ndims
def compute_query_embeddings(
self, query: Union[str], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str]
The query to embed. A query can be either text, image paths or audio paths.
"""
query = self.sanitize_input(query)
if query[0].endswith(self._audio_extensions):
return [self.generate_audio_embeddings(query)]
elif query[0].endswith(self._image_extensions):
return [self.generate_image_embeddings(query)]
else:
return [self.generate_text_embeddings(query)]
def generate_image_embeddings(self, image: IMAGES) -> np.ndarray:
torch = attempt_import_or_raise("torch")
inputs = {
self._ModalityType.VISION: self._data.load_and_transform_vision_data(
image, self.device
)
}
with torch.no_grad():
image_features = self.embedding_model(inputs)[self._ModalityType.VISION]
if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze()
def generate_audio_embeddings(self, audio: AUDIO) -> np.ndarray:
torch = attempt_import_or_raise("torch")
inputs = {
self._ModalityType.AUDIO: self._data.load_and_transform_audio_data(
audio, self.device
)
}
with torch.no_grad():
audio_features = self.embedding_model(inputs)[self._ModalityType.AUDIO]
if self.normalize:
audio_features /= audio_features.norm(dim=-1, keepdim=True)
return audio_features.cpu().numpy().squeeze()
def generate_text_embeddings(self, text: TEXT) -> np.ndarray:
torch = attempt_import_or_raise("torch")
inputs = {
self._ModalityType.TEXT: self._data.load_and_transform_text(
text, self.device
)
}
with torch.no_grad():
text_features = self.embedding_model(inputs)[self._ModalityType.TEXT]
if self.normalize:
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().squeeze()
def compute_source_embeddings(
self, source: Union[IMAGES, AUDIO], *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given source field column in the pydantic model.
"""
source = self.sanitize_input(source)
embeddings = []
if source[0].endswith(self._audio_extensions):
embeddings.extend(self.generate_audio_embeddings(source))
return embeddings
elif source[0].endswith(self._image_extensions):
embeddings.extend(self.generate_image_embeddings(source))
return embeddings
else:
embeddings.extend(self.generate_text_embeddings(source))
return embeddings
def sanitize_input(
self, input: Union[IMAGES, AUDIO]
) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(input, (str, bytes)):
input = [input]
elif isinstance(input, pa.Array):
input = input.to_pylist()
elif isinstance(input, pa.ChunkedArray):
input = input.combine_chunks().to_pylist()
return input
def get_embedding_model(self):
"""
fetches the imagebind embedding model
"""
imagebind = attempt_import_or_raise("imagebind", "imagebind")
model = imagebind.imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(self.device)
return model

View File

@@ -36,6 +36,7 @@ TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
@deprecated

View File

@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.5.6"
version = "0.5.7"
dependencies = [
"deprecation",
"pylance==0.9.16",
"pylance==0.9.18",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",
@@ -50,7 +50,7 @@ repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "mkdocs-ultralytics-plugin==0.0.44"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]

View File

@@ -28,6 +28,23 @@ from lancedb.pydantic import LanceModel, Vector
# or connection to external api
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
else:
_mlx = None
except Exception:
_mlx = None
try:
if importlib.util.find_spec("imagebind") is not None:
_imagebind = True
else:
_imagebind = None
except Exception:
_imagebind = None
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_basic_text_embeddings(alias, tmp_path):
@@ -158,6 +175,89 @@ def test_openclip(tmp_path):
)
@pytest.mark.skipif(
_imagebind is None,
reason="skip if imagebind not installed.",
)
@pytest.mark.slow
def test_imagebind(tmp_path):
import os
import shutil
import tempfile
import pandas as pd
import requests
import lancedb.embeddings.imagebind
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
with tempfile.TemporaryDirectory() as temp_dir:
print(f"Created temporary directory {temp_dir}")
def download_images(image_uris):
downloaded_image_paths = []
for uri in image_uris:
try:
response = requests.get(uri, stream=True)
if response.status_code == 200:
# Extract image name from URI
image_name = os.path.basename(uri)
image_path = os.path.join(temp_dir, image_name)
with open(image_path, "wb") as out_file:
shutil.copyfileobj(response.raw, out_file)
downloaded_image_paths.append(image_path)
except Exception as e: # noqa: PERF203
print(f"Failed to download {uri}. Error: {e}")
return temp_dir, downloaded_image_paths
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("imagebind").create(max_retries=0)
class Images(LanceModel):
label: str
image_uri: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
temp_dir, downloaded_images = download_images(uris)
table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))
# text search
actual = (
table.search("man's best friend", vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
# image search
query_image_uri = [
"https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
]
temp_dir, downloaded_images = download_images(query_image_uri)
query_image_uri = downloaded_images[0]
actual = (
table.search(query_image_uri, vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)
print(f"Deleted temporary directory {temp_dir}")
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
@@ -217,13 +317,6 @@ def test_gemini_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
@pytest.mark.skipif(
_mlx is None,
reason="mlx tests only required for apple users.",

View File

@@ -22,9 +22,9 @@ use object_store::CredentialProvider;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;
use vectordb::connection::Database;
use vectordb::connect;
use vectordb::connection::Connection;
use vectordb::table::ReadParams;
use vectordb::{ConnectOptions, Connection};
use crate::error::ResultExt;
use crate::query::JsQuery;
@@ -39,7 +39,7 @@ mod query;
mod table;
struct JsDatabase {
database: Arc<dyn Connection + 'static>,
database: Connection,
}
impl Finalize for JsDatabase {}
@@ -89,23 +89,23 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let mut conn_options = ConnectOptions::new(&path);
let mut conn_builder = connect(&path);
if let Some(region) = region {
conn_options = conn_options.region(&region);
conn_builder = conn_builder.region(&region);
}
if let Some(aws_creds) = aws_creds {
conn_options = conn_options.aws_creds(AwsCredential {
conn_builder = conn_builder.aws_creds(AwsCredential {
key_id: aws_creds.key_id,
secret_key: aws_creds.secret_key,
token: aws_creds.token,
});
}
rt.spawn(async move {
let database = Database::connect_with_options(&conn_options).await;
let database = conn_builder.execute().await;
deferred.settle_with(&channel, move |mut cx| {
let db = JsDatabase {
database: Arc::new(database.or_throw(&mut cx)?),
database: database.or_throw(&mut cx)?,
};
Ok(cx.boxed(db))
});
@@ -217,7 +217,11 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let (deferred, promise) = cx.promise();
rt.spawn(async move {
let table_rst = database.open_table_with_params(&table_name, params).await;
let table_rst = database
.open_table(&table_name)
.lance_read_params(params)
.execute()
.await;
deferred.settle_with(&channel, move |mut cx| {
let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);

View File

@@ -18,7 +18,7 @@ use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::optimize::CompactionOptions;
use lance::dataset::{WriteMode, WriteParams};
use lance::io::ObjectStoreParams;
use vectordb::table::OptimizeAction;
use vectordb::table::{AddDataOptions, OptimizeAction, WriteOptions};
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
use neon::prelude::*;
@@ -80,7 +80,11 @@ impl JsTable {
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let table_rst = database
.create_table(&table_name, Box::new(batch_reader), Some(params))
.create_table(&table_name, Box::new(batch_reader))
.write_options(WriteOptions {
lance_write_params: Some(params),
})
.execute()
.await;
deferred.settle_with(&channel, move |mut cx| {
@@ -121,7 +125,13 @@ impl JsTable {
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let add_result = table.add(Box::new(batch_reader), Some(params)).await;
let opts = AddDataOptions {
write_options: WriteOptions {
lance_write_params: Some(params),
},
..Default::default()
};
let add_result = table.add(Box::new(batch_reader), opts).await;
deferred.settle_with(&channel, move |mut cx| {
add_result.or_throw(&mut cx)?;

View File

@@ -19,7 +19,8 @@ use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterat
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use vectordb::Connection;
use vectordb::connection::Connection;
use vectordb::table::AddDataOptions;
use vectordb::{connect, Result, Table, TableRef};
#[tokio::main]
@@ -29,18 +30,18 @@ async fn main() -> Result<()> {
}
// --8<-- [start:connect]
let uri = "data/sample-lancedb";
let db = connect(uri).await?;
let db = connect(uri).execute().await?;
// --8<-- [end:connect]
// --8<-- [start:list_names]
println!("{:?}", db.table_names().await?);
// --8<-- [end:list_names]
let tbl = create_table(db.clone()).await?;
let tbl = create_table(&db).await?;
create_index(tbl.as_ref()).await?;
let batches = search(tbl.as_ref()).await?;
println!("{:?}", batches);
create_empty_table(db.clone()).await.unwrap();
create_empty_table(&db).await.unwrap();
// --8<-- [start:delete]
tbl.delete("id > 24").await.unwrap();
@@ -55,17 +56,14 @@ async fn main() -> Result<()> {
#[allow(dead_code)]
async fn open_with_existing_tbl() -> Result<()> {
let uri = "data/sample-lancedb";
let db = connect(uri).await?;
let db = connect(uri).execute().await?;
// --8<-- [start:open_with_existing_file]
let _ = db
.open_table_with_params("my_table", Default::default())
.await
.unwrap();
let _ = db.open_table("my_table").execute().await.unwrap();
// --8<-- [end:open_with_existing_file]
Ok(())
}
async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
async fn create_table(db: &Connection) -> Result<TableRef> {
// --8<-- [start:create_table]
const TOTAL: usize = 1000;
const DIM: usize = 128;
@@ -102,7 +100,8 @@ async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
schema.clone(),
);
let tbl = db
.create_table("my_table", Box::new(batches), None)
.create_table("my_table", Box::new(batches))
.execute()
.await
.unwrap();
// --8<-- [end:create_table]
@@ -126,21 +125,21 @@ async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
schema.clone(),
);
// --8<-- [start:add]
tbl.add(Box::new(new_batches), None).await.unwrap();
tbl.add(Box::new(new_batches), AddDataOptions::default())
.await
.unwrap();
// --8<-- [end:add]
Ok(tbl)
}
async fn create_empty_table(db: Arc<dyn Connection>) -> Result<TableRef> {
async fn create_empty_table(db: &Connection) -> Result<TableRef> {
// --8<-- [start:create_empty_table]
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("item", DataType::Utf8, true),
]));
let batches = RecordBatchIterator::new(vec![], schema.clone());
db.create_table("empty_table", Box::new(batches), None)
.await
db.create_empty_table("empty_table", schema).execute().await
// --8<-- [end:create_empty_table]
}

View File

@@ -13,14 +13,14 @@
// limitations under the License.
//! LanceDB Database
//!
use std::fs::create_dir_all;
use std::path::Path;
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use lance::dataset::WriteParams;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::SchemaRef;
use lance::dataset::{ReadParams, WriteMode};
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
use object_store::{
aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
@@ -29,73 +29,283 @@ use snafu::prelude::*;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, ReadParams, TableRef};
use crate::table::{NativeTable, TableRef, WriteOptions};
pub const LANCE_FILE_EXTENSION: &str = "lance";
/// A connection to LanceDB
#[async_trait::async_trait]
pub trait Connection: Send + Sync {
/// Get the names of all tables in the database.
async fn table_names(&self) -> Result<Vec<String>>;
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
/// Create a new table in the database.
/// Describes what happens when creating a table and a table with
/// the same name already exists
pub enum CreateTableMode {
/// If the table already exists, an error is returned
Create,
/// If the table already exists, it is opened. Any provided data is
/// ignored. The function will be passed an OpenTableBuilder to customize
/// how the table is opened
ExistOk(TableBuilderCallback),
/// If the table already exists, it is overwritten
Overwrite,
}
impl CreateTableMode {
pub fn exist_ok(
callback: impl FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send + 'static,
) -> Self {
Self::ExistOk(Box::new(callback))
}
}
impl Default for CreateTableMode {
fn default() -> Self {
Self::Create
}
}
/// Describes what happens when a vector either contains NaN or
/// does not have enough values
#[derive(Clone, Debug, Default)]
enum BadVectorHandling {
/// An error is returned
#[default]
Error,
#[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
/// The offending row is dropped
Drop,
#[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
/// The invalid/missing items are replaced by fill_value
Fill(f32),
}
/// A builder for configuring a [`Connection::create_table`] operation
pub struct CreateTableBuilder<const HAS_DATA: bool> {
parent: Arc<dyn ConnectionInternal>,
name: String,
data: Option<Box<dyn RecordBatchReader + Send>>,
schema: Option<SchemaRef>,
mode: CreateTableMode,
write_options: WriteOptions,
}
// Builder methods that only apply when we have initial data
impl CreateTableBuilder<true> {
fn new(
parent: Arc<dyn ConnectionInternal>,
name: String,
data: Box<dyn RecordBatchReader + Send>,
) -> Self {
Self {
parent,
name,
data: Some(data),
schema: None,
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
}
}
/// Apply the given write options when writing the initial data
pub fn write_options(mut self, write_options: WriteOptions) -> Self {
self.write_options = write_options;
self
}
/// Execute the create table operation
pub async fn execute(self) -> Result<TableRef> {
self.parent.clone().do_create_table(self).await
}
}
// Builder methods that only apply when we do not have initial data
impl CreateTableBuilder<false> {
fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
Self {
parent,
name,
data: None,
schema: Some(schema),
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
}
}
/// Execute the create table operation
pub async fn execute(self) -> Result<TableRef> {
self.parent.clone().do_create_empty_table(self).await
}
}
impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
/// Set the mode for creating the table
///
/// This controls what happens if a table with the given name already exists
pub fn mode(mut self, mode: CreateTableMode) -> Self {
self.mode = mode;
self
}
}
#[derive(Clone, Debug)]
pub struct OpenTableBuilder {
parent: Arc<dyn ConnectionInternal>,
name: String,
index_cache_size: u32,
lance_read_params: Option<ReadParams>,
}
impl OpenTableBuilder {
fn new(parent: Arc<dyn ConnectionInternal>, name: String) -> Self {
Self {
parent,
name,
index_cache_size: 256,
lance_read_params: None,
}
}
/// Set the size of the index cache, specified as a number of entries
///
/// The default value is 256
///
/// The exact meaning of an "entry" will depend on the type of index:
/// * IVF - there is one entry for each IVF partition
/// * BTREE - there is one entry for the entire index
///
/// This cache applies to the entire opened table, across all indices.
/// Setting this value higher will increase performance on larger datasets
/// at the expense of more RAM
pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
self.index_cache_size = index_cache_size;
self
}
/// Advanced parameters that can be used to customize table reads
///
/// If set, these will take precedence over any overlapping `OpenTableOptions` options
pub fn lance_read_params(mut self, params: ReadParams) -> Self {
self.lance_read_params = Some(params);
self
}
/// Open the table
pub async fn execute(self) -> Result<TableRef> {
self.parent.clone().do_open_table(self).await
}
}
#[async_trait::async_trait]
trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
async fn table_names(&self) -> Result<Vec<String>>;
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
async fn drop_table(&self, name: &str) -> Result<()>;
async fn drop_db(&self) -> Result<()>;
async fn do_create_empty_table(&self, options: CreateTableBuilder<false>) -> Result<TableRef> {
let batches = RecordBatchIterator::new(vec![], options.schema.unwrap());
let opts = CreateTableBuilder::<true>::new(options.parent, options.name, Box::new(batches))
.mode(options.mode)
.write_options(options.write_options);
self.do_create_table(opts).await
}
}
/// A connection to LanceDB
#[derive(Clone)]
pub struct Connection {
uri: String,
internal: Arc<dyn ConnectionInternal>,
}
impl Connection {
/// Get the URI of the connection
pub fn uri(&self) -> &str {
self.uri.as_str()
}
/// Get the names of all tables in the database.
pub async fn table_names(&self) -> Result<Vec<String>> {
self.internal.table_names().await
}
/// Create a new table from data
///
/// # Parameters
///
/// * `name` - The name of the table.
/// * `batches` - The initial data to write to the table.
/// * `params` - Optional [`WriteParams`] to create the table.
///
/// # Returns
/// Created [`TableRef`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
async fn create_table(
/// * `name` - The name of the table
/// * `initial_data` - The initial data to write to the table
pub fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<TableRef>;
async fn open_table(&self, name: &str) -> Result<TableRef> {
self.open_table_with_params(name, ReadParams::default())
.await
name: impl Into<String>,
initial_data: Box<dyn RecordBatchReader + Send>,
) -> CreateTableBuilder<true> {
CreateTableBuilder::<true>::new(self.internal.clone(), name.into(), initial_data)
}
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef>;
/// Create an empty table with a given schema
///
/// # Parameters
///
/// * `name` - The name of the table
/// * `schema` - The schema of the table
pub fn create_empty_table(
&self,
name: impl Into<String>,
schema: SchemaRef,
) -> CreateTableBuilder<false> {
CreateTableBuilder::<false>::new(self.internal.clone(), name.into(), schema)
}
/// Open an existing table in the database
///
/// # Arguments
/// * `name` - The name of the table
///
/// # Returns
/// Created [`TableRef`], or [`Error::TableNotFound`] if the table does not exist.
pub fn open_table(&self, name: impl Into<String>) -> OpenTableBuilder {
OpenTableBuilder::new(self.internal.clone(), name.into())
}
/// Drop a table in the database.
///
/// # Arguments
/// * `name` - The name of the table.
async fn drop_table(&self, name: &str) -> Result<()>;
/// * `name` - The name of the table to drop
pub async fn drop_table(&self, name: impl AsRef<str>) -> Result<()> {
self.internal.drop_table(name.as_ref()).await
}
/// Drop the database
///
/// This is the same as dropping all of the tables
pub async fn drop_db(&self) -> Result<()> {
self.internal.drop_db().await
}
}
#[derive(Debug)]
pub struct ConnectOptions {
pub struct ConnectBuilder {
/// Database URI
///
/// # Accpeted URI formats
/// ### Accpeted URI formats
///
/// - `/path/to/database` - local database on file system.
/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
/// - `db://dbname` - Lance Cloud
pub uri: String,
/// - `db://dbname` - LanceDB Cloud
uri: String,
/// Lance Cloud API key
pub api_key: Option<String>,
/// Lance Cloud region
pub region: Option<String>,
/// Lance Cloud host override
pub host_override: Option<String>,
/// LanceDB Cloud API key, required if using Lance Cloud
api_key: Option<String>,
/// LanceDB Cloud region, required if using Lance Cloud
region: Option<String>,
/// LanceDB Cloud host override, only required if using an on-premises Lance Cloud instance
host_override: Option<String>,
/// User provided AWS credentials
pub aws_creds: Option<AwsCredential>,
/// The maximum number of indices to cache in memory. Defaults to 256.
pub index_cache_size: u32,
aws_creds: Option<AwsCredential>,
}
impl ConnectOptions {
impl ConnectBuilder {
/// Create a new [`ConnectOptions`] with the given database URI.
pub fn new(uri: &str) -> Self {
Self {
@@ -104,7 +314,6 @@ impl ConnectOptions {
region: None,
host_override: None,
aws_creds: None,
index_cache_size: 256,
}
}
@@ -124,15 +333,18 @@ impl ConnectOptions {
}
/// [`AwsCredential`] to use when connecting to S3.
///
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
self.aws_creds = Some(aws_creds);
self
}
pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
self.index_cache_size = index_cache_size;
self
/// Establishes a connection to the database
pub async fn execute(self) -> Result<Connection> {
let internal = Arc::new(Database::connect_with_options(&self).await?);
Ok(Connection {
internal,
uri: self.uri,
})
}
}
@@ -140,29 +352,14 @@ impl ConnectOptions {
///
/// # Arguments
///
/// - `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
///
/// ## Accepted URI formats
///
/// - `/path/to/database` - local database on file system.
/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
/// - `db://dbname` - Lance Cloud
///
pub async fn connect(uri: &str) -> Result<Arc<dyn Connection>> {
let options = ConnectOptions::new(uri);
connect_with_options(&options).await
/// * `uri` - URI where the database is located, can be a local directory, supported remote cloud storage,
/// or a LanceDB Cloud database. See [ConnectOptions::uri] for a list of accepted formats
pub fn connect(uri: &str) -> ConnectBuilder {
ConnectBuilder::new(uri)
}
/// Connect with [`ConnectOptions`].
///
/// # Arguments
/// - `options` - [`ConnectOptions`] to connect to the database.
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Arc<dyn Connection>> {
let db = Database::connect(&options.uri).await?;
Ok(Arc::new(db))
}
pub struct Database {
#[derive(Debug)]
struct Database {
object_store: ObjectStore,
query_string: Option<String>,
@@ -179,21 +376,7 @@ const MIRRORED_STORE: &str = "mirroredStore";
/// A connection to LanceDB
impl Database {
/// Connects to LanceDB
///
/// # Arguments
///
/// * `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
///
/// # Returns
///
/// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Self> {
let options = ConnectOptions::new(uri);
Self::connect_with_options(&options).await
}
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> {
async fn connect_with_options(options: &ConnectBuilder) -> Result<Self> {
let uri = &options.uri;
let parse_res = url::Url::parse(uri);
@@ -333,7 +516,7 @@ impl Database {
}
#[async_trait::async_trait]
impl Connection for Database {
impl ConnectionInternal for Database {
async fn table_names(&self) -> Result<Vec<String>> {
let mut f = self
.object_store
@@ -354,40 +537,47 @@ impl Connection for Database {
Ok(f)
}
async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<TableRef> {
let table_uri = self.table_uri(name)?;
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef> {
let table_uri = self.table_uri(&options.name)?;
Ok(Arc::new(
NativeTable::create(
&table_uri,
name,
batches,
self.store_wrapper.clone(),
params,
)
.await?,
))
let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
if matches!(&options.mode, CreateTableMode::Overwrite) {
write_params.mode = WriteMode::Overwrite;
}
match NativeTable::create(
&table_uri,
&options.name,
options.data.unwrap(),
self.store_wrapper.clone(),
Some(write_params),
)
.await
{
Ok(table) => Ok(Arc::new(table)),
Err(Error::TableAlreadyExists { name }) => match options.mode {
CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
CreateTableMode::ExistOk(callback) => {
let builder = OpenTableBuilder::new(options.parent, options.name);
let builder = (callback)(builder);
builder.execute().await
}
CreateTableMode::Overwrite => unreachable!(),
},
Err(err) => Err(err),
}
}
/// Open a table in the database.
///
/// # Arguments
/// * `name` - The name of the table.
/// * `params` - The parameters to open the table.
///
/// # Returns
///
/// * A [TableRef] object.
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef> {
let table_uri = self.table_uri(name)?;
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef> {
let table_uri = self.table_uri(&options.name)?;
Ok(Arc::new(
NativeTable::open_with_params(&table_uri, name, self.store_wrapper.clone(), params)
.await?,
NativeTable::open_with_params(
&table_uri,
&options.name,
self.store_wrapper.clone(),
options.lance_read_params,
)
.await?,
))
}
@@ -397,12 +587,17 @@ impl Connection for Database {
self.object_store.remove_dir_all(full_path).await?;
Ok(())
}
async fn drop_db(&self) -> Result<()> {
todo!()
}
}
#[cfg(test)]
mod tests {
use std::fs::create_dir_all;
use arrow_schema::{DataType, Field, Schema};
use tempfile::tempdir;
use super::*;
@@ -411,7 +606,7 @@ mod tests {
async fn test_connect() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = Database::connect(uri).await.unwrap();
let db = connect(uri).execute().await.unwrap();
assert_eq!(db.uri, uri);
}
@@ -429,7 +624,8 @@ mod tests {
let relative_root = std::path::PathBuf::from(relative_ancestors.join("/"));
let relative_uri = relative_root.join(&uri);
let db = Database::connect(relative_uri.to_str().unwrap())
let db = connect(relative_uri.to_str().unwrap())
.execute()
.await
.unwrap();
@@ -444,7 +640,7 @@ mod tests {
create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = Database::connect(uri).await.unwrap();
let db = connect(uri).execute().await.unwrap();
let tables = db.table_names().await.unwrap();
assert_eq!(tables.len(), 2);
assert!(tables[0].eq(&String::from("table1")));
@@ -462,10 +658,44 @@ mod tests {
create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = Database::connect(uri).await.unwrap();
let db = connect(uri).execute().await.unwrap();
db.drop_table("table1").await.unwrap();
let tables = db.table_names().await.unwrap();
assert_eq!(tables.len(), 0);
}
#[tokio::test]
async fn test_create_table_already_exists() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
db.create_empty_table("test", schema.clone())
.execute()
.await
.unwrap();
// TODO: None of the open table options are "inspectable" right now but once one is we
// should assert we are passing these options in correctly
db.create_empty_table("test", schema)
.mode(CreateTableMode::exist_ok(|builder| {
builder.index_cache_size(16)
}))
.execute()
.await
.unwrap();
let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)]));
assert!(db
.create_empty_table("test", other_schema.clone())
.execute()
.await
.is_err());
let overwritten = db
.create_empty_table("test", other_schema.clone())
.mode(CreateTableMode::Overwrite)
.execute()
.await
.unwrap();
assert_eq!(other_schema, overwritten.schema());
}
}

View File

@@ -174,7 +174,6 @@ fn coerce_schema_batch(
}
/// Coerce the reader (input data) to match the given [Schema].
///
pub fn coerce_schema(
reader: impl RecordBatchReader + Send + 'static,
schema: Arc<Schema>,

View File

@@ -342,7 +342,7 @@ mod test {
use object_store::local::LocalFileSystem;
use tempfile;
use crate::connection::{Connection, Database};
use crate::{connect, table::WriteOptions};
#[tokio::test]
async fn test_e2e() {
@@ -354,7 +354,7 @@ mod test {
secondary: Arc::new(secondary_store),
});
let db = Database::connect(dir1.to_str().unwrap()).await.unwrap();
let db = connect(dir1.to_str().unwrap()).execute().await.unwrap();
let mut param = WriteParams::default();
let store_params = ObjectStoreParams {
@@ -368,7 +368,11 @@ mod test {
datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));
let res = db
.create_table("test", Box::new(datagen.batch(100)), Some(param.clone()))
.create_table("test", Box::new(datagen.batch(100)))
.write_options(WriteOptions {
lance_write_params: Some(param),
})
.execute()
.await;
// leave this here for easy debugging

View File

@@ -43,10 +43,9 @@
//! #### Connect to a database.
//!
//! ```rust
//! use vectordb::connect;
//! # use arrow_schema::{Field, Schema};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = connect("data/sample-lancedb").await.unwrap();
//! let db = vectordb::connect("data/sample-lancedb").execute().await.unwrap();
//! # });
//! ```
//!
@@ -56,14 +55,20 @@
//! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
//! - `db://dbname` - Lance Cloud
//!
//! You can also use [`ConnectOptions`] to configure the connectoin to the database.
//! You can also use [`ConnectOptions`] to configure the connection to the database.
//!
//! ```rust
//! use vectordb::{connect_with_options, ConnectOptions};
//! use object_store::aws::AwsCredential;
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let options = ConnectOptions::new("data/sample-lancedb")
//! .index_cache_size(1024);
//! let db = connect_with_options(&options).await.unwrap();
//! let db = vectordb::connect("data/sample-lancedb")
//! .aws_creds(AwsCredential {
//! key_id: "some_key".to_string(),
//! secret_key: "some_secret".to_string(),
//! token: None,
//! })
//! .execute()
//! .await
//! .unwrap();
//! # });
//! ```
//!
@@ -79,31 +84,44 @@
//!
//! ```rust
//! # use std::sync::Arc;
//! use arrow_schema::{DataType, Schema, Field};
//! use arrow_array::{RecordBatch, RecordBatchIterator};
//! use arrow_schema::{DataType, Field, Schema};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::connection::{Database, Connection};
//! # use vectordb::connect;
//!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! let schema = Arc::new(Schema::new(vec![
//! Field::new("id", DataType::Int32, false),
//! Field::new("vector", DataType::FixedSizeList(
//! Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
//! Field::new("id", DataType::Int32, false),
//! Field::new(
//! "vector",
//! DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128),
//! true,
//! ),
//! ]));
//! // Create a RecordBatch stream.
//! let batches = RecordBatchIterator::new(vec![
//! RecordBatch::try_new(schema.clone(),
//! let batches = RecordBatchIterator::new(
//! vec![RecordBatch::try_new(
//! schema.clone(),
//! vec![
//! Arc::new(Int32Array::from_iter_values(0..1000)),
//! Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
//! (0..1000).map(|_| Some(vec![Some(1.0); 128])), 128)),
//! ]).unwrap()
//! ].into_iter().map(Ok),
//! schema.clone());
//! db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! Arc::new(Int32Array::from_iter_values(0..256)),
//! Arc::new(
//! FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
//! (0..256).map(|_| Some(vec![Some(1.0); 128])),
//! 128,
//! ),
//! ),
//! ],
//! )
//! .unwrap()]
//! .into_iter()
//! .map(Ok),
//! schema.clone(),
//! );
//! db.create_table("my_table", Box::new(batches))
//! .execute()
//! .await
//! .unwrap();
//! # });
//! ```
//!
@@ -111,14 +129,13 @@
//!
//! ```no_run
//! # use std::sync::Arc;
//! # use vectordb::connect;
//! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
//! # RecordBatchIterator, Int32Array};
//! # use arrow_schema::{Schema, Field, DataType};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let tbl = db.open_table("idx_test").await.unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! # let tbl = db.open_table("idx_test").execute().await.unwrap();
//! tbl.create_index(&["vector"])
//! .ivf_pq()
//! .num_partitions(256)
@@ -136,10 +153,9 @@
//! # use arrow_schema::{DataType, Schema, Field};
//! # use arrow_array::{RecordBatch, RecordBatchIterator};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::connection::{Database, Connection};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! # let schema = Arc::new(Schema::new(vec![
//! # Field::new("id", DataType::Int32, false),
//! # Field::new("vector", DataType::FixedSizeList(
@@ -154,8 +170,8 @@
//! # ]).unwrap()
//! # ].into_iter().map(Ok),
//! # schema.clone());
//! # db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! # let table = db.open_table("my_table").await.unwrap();
//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap();
//! # let table = db.open_table("my_table").execute().await.unwrap();
//! let results = table
//! .search(&[1.0; 128])
//! .execute_stream()
@@ -165,8 +181,6 @@
//! .await
//! .unwrap();
//! # });
//!
//!
//! ```
pub mod connection;
@@ -179,10 +193,8 @@ pub mod query;
pub mod table;
pub mod utils;
pub use connection::{Connection, Database};
pub use error::{Error, Result};
pub use table::{Table, TableRef};
/// Connect to a database
pub use connection::{connect, connect_with_options, ConnectOptions};
pub use lance::dataset::WriteMode;
pub use connection::connect;

View File

@@ -60,7 +60,6 @@ impl Query {
/// # Arguments
///
/// * `dataset` - Lance dataset.
///
pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
Self {
dataset,

View File

@@ -17,7 +17,7 @@
use std::path::Path;
use std::sync::{Arc, Mutex};
use arrow_array::RecordBatchReader;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use chrono::Duration;
@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteMode, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -38,7 +38,6 @@ use crate::index::vector::{VectorIndex, VectorIndexStatistics};
use crate::index::IndexBuilder;
use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;
use self::merge::{MergeInsert, MergeInsertBuilder};
@@ -85,6 +84,35 @@ pub struct OptimizeStats {
pub prune: Option<RemovalStats>,
}
/// Options to use when writing data
#[derive(Clone, Debug, Default)]
pub struct WriteOptions {
// Coming soon: https://github.com/lancedb/lancedb/issues/992
// /// What behavior to take if the data contains invalid vectors
// pub on_bad_vectors: BadVectorHandling,
/// Advanced parameters that can be used to customize table creation
///
/// If set, these will take precedence over any overlapping `OpenTableOptions` options
pub lance_write_params: Option<WriteParams>,
}
#[derive(Debug, Clone, Default)]
pub enum AddDataMode {
/// Rows will be appended to the table (the default)
#[default]
Append,
/// The existing table will be overwritten with the new data
Overwrite,
}
#[derive(Debug, Default, Clone)]
pub struct AddDataOptions {
/// Whether to add new rows (the default) or replace the existing data
pub mode: AddDataMode,
/// Options to use when writing the data
pub write_options: WriteOptions,
}
/// A Table is a collection of strong typed Rows.
///
/// The type of the each row is defined in Apache Arrow [Schema].
@@ -112,12 +140,12 @@ pub trait Table: std::fmt::Display + Send + Sync {
///
/// # Arguments
///
/// * `batches` RecordBatch to be saved in the Table
/// * `params` Append / Overwrite existing records. Default: Append
/// * `batches` data to be added to the Table
/// * `options` options to control how data is added
async fn add(
&self,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
options: AddDataOptions,
) -> Result<()>;
/// Delete the rows from table that match the predicate.
@@ -129,28 +157,43 @@ pub trait Table: std::fmt::Display + Send + Sync {
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
/// .execute()
/// .await
/// .unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let batches = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// let tbl = db.create_table("delete_test", Box::new(batches), None).await.unwrap();
/// let batches = RecordBatchIterator::new(
/// vec![RecordBatch::try_new(
/// schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(
/// FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])),
/// 128,
/// ),
/// ),
/// ],
/// )
/// .unwrap()]
/// .into_iter()
/// .map(Ok),
/// schema.clone(),
/// );
/// let tbl = db
/// .create_table("delete_test", Box::new(batches))
/// .execute()
/// .await
/// .unwrap();
/// tbl.delete("id > 5").await.unwrap();
/// # });
/// ```
@@ -162,14 +205,16 @@ pub trait Table: std::fmt::Display + Send + Sync {
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("idx_test").await.unwrap();
/// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
/// .execute()
/// .await
/// .unwrap();
/// # let tbl = db.open_table("idx_test").execute().await.unwrap();
/// tbl.create_index(&["vector"])
/// .ivf_pq()
/// .num_partitions(256)
@@ -214,32 +259,44 @@ pub trait Table: std::fmt::Display + Send + Sync {
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("idx_test").await.unwrap();
/// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
/// .execute()
/// .await
/// .unwrap();
/// # let tbl = db.open_table("idx_test").execute().await.unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let new_data = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// let new_data = RecordBatchIterator::new(
/// vec![RecordBatch::try_new(
/// schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(
/// FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])),
/// 128,
/// ),
/// ),
/// ],
/// )
/// .unwrap()]
/// .into_iter()
/// .map(Ok),
/// schema.clone(),
/// );
/// // Perform an upsert operation
/// let mut merge_insert = tbl.merge_insert(&["id"]);
/// merge_insert.when_matched_update_all(None)
/// .when_not_matched_insert_all();
/// merge_insert
/// .when_matched_update_all(None)
/// .when_not_matched_insert_all();
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
/// # });
/// ```
@@ -266,7 +323,9 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// # use futures::TryStreamExt;
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
/// let stream = tbl.query().nearest_to(&[1.0, 2.0, 3.0])
/// let stream = tbl
/// .query()
/// .nearest_to(&[1.0, 2.0, 3.0])
/// .refine_factor(5)
/// .nprobes(10)
/// .execute_stream()
@@ -299,11 +358,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// # use futures::TryStreamExt;
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
/// let stream = tbl
/// .query()
/// .execute_stream()
/// .await
/// .unwrap();
/// let stream = tbl.query().execute_stream().await.unwrap();
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
/// # });
/// ```
@@ -351,7 +406,7 @@ impl NativeTable {
/// * A [NativeTable] object.
pub async fn open(uri: &str) -> Result<Self> {
let name = Self::get_table_name(uri)?;
Self::open_with_params(uri, &name, None, ReadParams::default()).await
Self::open_with_params(uri, &name, None, None).await
}
/// Opens an existing Table
@@ -369,8 +424,9 @@ impl NativeTable {
uri: &str,
name: &str,
write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
params: ReadParams,
params: Option<ReadParams>,
) -> Result<Self> {
let params = params.unwrap_or_default();
// patch the params if we have a write store wrapper
let params = match write_store_wrapper.clone() {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
@@ -403,7 +459,6 @@ impl NativeTable {
}
/// Checkout a specific version of this [NativeTable]
///
pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
let name = Self::get_table_name(uri)?;
Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
@@ -489,13 +544,14 @@ impl NativeTable {
write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
params: Option<WriteParams>,
) -> Result<Self> {
let params = params.unwrap_or_default();
// patch the params if we have a write store wrapper
let params = match write_store_wrapper.clone() {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
None => params,
};
let dataset = Dataset::write(batches, uri, params)
let dataset = Dataset::write(batches, uri, Some(params))
.await
.map_err(|e| match e {
lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -513,6 +569,17 @@ impl NativeTable {
})
}
pub async fn create_empty(
uri: &str,
name: &str,
schema: SchemaRef,
write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
params: Option<WriteParams>,
) -> Result<Self> {
let batches = RecordBatchIterator::new(vec![], schema);
Self::create(uri, name, batches, write_store_wrapper, params).await
}
/// Version of this Table
pub fn version(&self) -> u64 {
self.dataset.lock().expect("lock poison").version().version
@@ -740,20 +807,26 @@ impl Table for NativeTable {
async fn add(
&self,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
params: AddDataOptions,
) -> Result<()> {
let params = Some(params.unwrap_or(WriteParams {
mode: WriteMode::Append,
..WriteParams::default()
}));
let lance_params = params
.write_options
.lance_write_params
.unwrap_or(WriteParams {
mode: match params.mode {
AddDataMode::Append => WriteMode::Append,
AddDataMode::Overwrite => WriteMode::Overwrite,
},
..Default::default()
});
// patch the params if we have a write store wrapper
let params = match self.store_wrapper.clone() {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
None => params,
let lance_params = match self.store_wrapper.clone() {
Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?,
None => lance_params,
};
self.reset_dataset(Dataset::write(batches, &self.uri, params).await?);
self.reset_dataset(Dataset::write(batches, &self.uri, Some(lance_params)).await?);
Ok(())
}
@@ -881,25 +954,6 @@ mod tests {
assert_eq!(c.to_str().unwrap(), "s3://bucket/path/to/file/subfile");
}
#[tokio::test]
async fn test_create_already_exists() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches = make_test_batches();
let _ = batches.schema().clone();
NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
let batches = make_test_batches();
let result = NativeTable::create(uri, "test", batches, None, None).await;
assert!(matches!(
result.unwrap_err(),
Error::TableAlreadyExists { .. }
));
}
#[tokio::test]
async fn test_count_rows() {
let tmp_dir = tempdir().unwrap();
@@ -940,7 +994,10 @@ mod tests {
schema.clone(),
);
table.add(Box::new(new_batches), None).await.unwrap();
table
.add(Box::new(new_batches), AddDataOptions::default())
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 20);
assert_eq!(table.name, "test");
}
@@ -1003,23 +1060,47 @@ mod tests {
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.unwrap()]
.into_iter()
.map(Ok),
let batches = vec![RecordBatch::try_new(
schema.clone(),
);
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.unwrap()]
.into_iter()
.map(Ok);
let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
// Can overwrite using AddDataOptions::mode
table
.add(
Box::new(new_batches),
AddDataOptions {
mode: AddDataMode::Overwrite,
..Default::default()
},
)
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(table.name, "test");
// Can overwrite using underlying WriteParams (which
// take precedence over AddDataOptions::mode)
let param: WriteParams = WriteParams {
mode: WriteMode::Overwrite,
..Default::default()
};
table.add(Box::new(new_batches), Some(param)).await.unwrap();
let opts = AddDataOptions {
write_options: WriteOptions {
lance_write_params: Some(param),
},
mode: AddDataMode::Append,
};
let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
table.add(Box::new(new_batches), opts).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(table.name, "test");
}
@@ -1329,7 +1410,7 @@ mod tests {
..Default::default()
};
assert!(!wrapper.called());
let _ = NativeTable::open_with_params(uri, "test", None, param)
let _ = NativeTable::open_with_params(uri, "test", None, Some(param))
.await
.unwrap();
assert!(wrapper.called());

View File

@@ -32,20 +32,17 @@ impl PatchStoreParam for Option<ObjectStoreParams> {
}
pub trait PatchWriteParam {
fn patch_with_store_wrapper(
self,
wrapper: Arc<dyn WrappingObjectStore>,
) -> Result<Option<WriteParams>>;
fn patch_with_store_wrapper(self, wrapper: Arc<dyn WrappingObjectStore>)
-> Result<WriteParams>;
}
impl PatchWriteParam for Option<WriteParams> {
impl PatchWriteParam for WriteParams {
fn patch_with_store_wrapper(
self,
mut self,
wrapper: Arc<dyn WrappingObjectStore>,
) -> Result<Option<WriteParams>> {
let mut params = self.unwrap_or_default();
params.store_params = params.store_params.patch_with_store_wrapper(wrapper)?;
Ok(Some(params))
) -> Result<WriteParams> {
self.store_params = self.store_params.patch_with_store_wrapper(wrapper)?;
Ok(self)
}
}