Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 05:19:58 +00:00)

Compare commits: 5 commits, python-v0. ... python-v0.
| Author | SHA1 | Date |
|---|---|---|
| | 677b7c1fcc | |
| | 8303a7197b | |
| | 5fa9bfc4a8 | |
| | bf2e9d0088 | |
| | f04590ddad | |
@@ -57,6 +57,16 @@ plugins:
         - https://arrow.apache.org/docs/objects.inv
         - https://pandas.pydata.org/docs/objects.inv
   - mkdocs-jupyter
+  - ultralytics:
+      verbose: True
+      enabled: True
+      default_image: "assets/lancedb_and_lance.png" # Default image for all pages
+      add_image: True # Automatically add meta image
+      add_keywords: True # Add page keywords in the header tag
+      add_share_buttons: True # Add social share buttons
+      add_authors: False # Display page authors
+      add_desc: False
+      add_dates: False
 
 markdown_extensions:
   - admonition
@@ -206,7 +216,6 @@ extra_css:
 
 extra_javascript:
   - "extra_js/init_ask_ai_widget.js"
-  - "extra_js/meta_tag.js"
 
 extra:
   analytics:
@@ -2,4 +2,5 @@ mkdocs==1.5.3
 mkdocs-jupyter==0.24.1
 mkdocs-material==9.5.3
 mkdocstrings[python]==0.20.0
 pydantic
+mkdocs-ultralytics-plugin==0.0.44
@@ -1,6 +0,0 @@
-window.addEventListener('load', function() {
-  var meta = document.createElement('meta');
-  meta.setAttribute('property', 'og:image');
-  meta.setAttribute('content', '/assets/lancedb_and_lance.png');
-  document.head.appendChild(meta);
-});
@@ -12,18 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::sync::Arc;
-
 use napi::bindgen_prelude::*;
 use napi_derive::*;
 
 use crate::table::Table;
-use vectordb::connection::{Connection as LanceDBConnection, Database};
+use vectordb::connection::Connection as LanceDBConnection;
 use vectordb::ipc::ipc_file_to_batches;
 
 #[napi]
 pub struct Connection {
-    conn: Arc<dyn LanceDBConnection>,
+    conn: LanceDBConnection,
 }
 
 #[napi]
@@ -32,9 +30,9 @@ impl Connection {
     #[napi(factory)]
     pub async fn new(uri: String) -> napi::Result<Self> {
         Ok(Self {
-            conn: Arc::new(Database::connect(&uri).await.map_err(|e| {
+            conn: vectordb::connect(&uri).execute().await.map_err(|e| {
                 napi::Error::from_reason(format!("Failed to connect to database: {}", e))
-            })?),
+            })?,
         })
     }
 
@@ -59,7 +57,8 @@ impl Connection {
             .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
         let tbl = self
             .conn
-            .create_table(&name, Box::new(batches), None)
+            .create_table(&name, Box::new(batches))
+            .execute()
             .await
             .map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
         Ok(Table::new(tbl))
@@ -70,6 +69,7 @@ impl Connection {
         let tbl = self
             .conn
             .open_table(&name)
+            .execute()
             .await
             .map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
         Ok(Table::new(tbl))
@@ -15,6 +15,7 @@
 use arrow_ipc::writer::FileWriter;
 use napi::bindgen_prelude::*;
 use napi_derive::napi;
+use vectordb::table::AddDataOptions;
 use vectordb::{ipc::ipc_file_to_batches, table::TableRef};
 
 use crate::index::IndexBuilder;
@@ -48,12 +49,15 @@ impl Table {
     pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
         let batches = ipc_file_to_batches(buf.to_vec())
             .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
-        self.table.add(Box::new(batches), None).await.map_err(|e| {
-            napi::Error::from_reason(format!(
-                "Failed to add batches to table {}: {}",
-                self.table, e
-            ))
-        })
+        self.table
+            .add(Box::new(batches), AddDataOptions::default())
+            .await
+            .map_err(|e| {
+                napi::Error::from_reason(format!(
+                    "Failed to add batches to table {}: {}",
+                    self.table, e
+                ))
+            })
     }
 
     #[napi]
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.5.6
+current_version = 0.5.7
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True
172  python/lancedb/embeddings/imagebind.py  (new file)
@@ -0,0 +1,172 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cached_property
from typing import List, Union

import numpy as np
import pyarrow as pa

from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import AUDIO, IMAGES, TEXT


@register("imagebind")
class ImageBindEmbeddings(EmbeddingFunction):
    """
    An embedding function that uses the ImageBind API for generating
    multi-modal embeddings across six different modalities: images, text,
    audio, depth, thermal, and IMU data.

    To install the package, run:
    `pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
    """

    name: str = "imagebind_huge"
    device: str = "cpu"
    normalize: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ndims = 1024
        self._audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".aac")
        self._image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")

    @cached_property
    def embedding_model(self):
        """
        Get the embedding model. This is cached so that the model is only loaded
        once per process.
        """
        return self.get_embedding_model()

    @cached_property
    def _data(self):
        """
        Get the data module from imagebind.
        """
        data = attempt_import_or_raise("imagebind.data", "imagebind")
        return data

    @cached_property
    def _ModalityType(self):
        """
        Get the ModalityType from imagebind.
        """
        imagebind = attempt_import_or_raise("imagebind", "imagebind")
        return imagebind.imagebind_model.ModalityType

    def ndims(self):
        return self._ndims

    def compute_query_embeddings(
        self, query: Union[str], *args, **kwargs
    ) -> List[np.ndarray]:
        """
        Compute the embeddings for a given user query.

        Parameters
        ----------
        query : Union[str]
            The query to embed. A query can be either text, an image path, or
            an audio path.
        """
        query = self.sanitize_input(query)
        if query[0].endswith(self._audio_extensions):
            return [self.generate_audio_embeddings(query)]
        elif query[0].endswith(self._image_extensions):
            return [self.generate_image_embeddings(query)]
        else:
            return [self.generate_text_embeddings(query)]

    def generate_image_embeddings(self, image: IMAGES) -> np.ndarray:
        torch = attempt_import_or_raise("torch")
        inputs = {
            self._ModalityType.VISION: self._data.load_and_transform_vision_data(
                image, self.device
            )
        }
        with torch.no_grad():
            image_features = self.embedding_model(inputs)[self._ModalityType.VISION]
            if self.normalize:
                image_features /= image_features.norm(dim=-1, keepdim=True)
            return image_features.cpu().numpy().squeeze()

    def generate_audio_embeddings(self, audio: AUDIO) -> np.ndarray:
        torch = attempt_import_or_raise("torch")
        inputs = {
            self._ModalityType.AUDIO: self._data.load_and_transform_audio_data(
                audio, self.device
            )
        }
        with torch.no_grad():
            audio_features = self.embedding_model(inputs)[self._ModalityType.AUDIO]
            if self.normalize:
                audio_features /= audio_features.norm(dim=-1, keepdim=True)
            return audio_features.cpu().numpy().squeeze()

    def generate_text_embeddings(self, text: TEXT) -> np.ndarray:
        torch = attempt_import_or_raise("torch")
        inputs = {
            self._ModalityType.TEXT: self._data.load_and_transform_text(
                text, self.device
            )
        }
        with torch.no_grad():
            text_features = self.embedding_model(inputs)[self._ModalityType.TEXT]
            if self.normalize:
                text_features /= text_features.norm(dim=-1, keepdim=True)
            return text_features.cpu().numpy().squeeze()

    def compute_source_embeddings(
        self, source: Union[IMAGES, AUDIO], *args, **kwargs
    ) -> List[np.array]:
        """
        Get the embeddings for the given source field column in the pydantic model.
        """
        source = self.sanitize_input(source)
        embeddings = []
        if source[0].endswith(self._audio_extensions):
            embeddings.extend(self.generate_audio_embeddings(source))
            return embeddings
        elif source[0].endswith(self._image_extensions):
            embeddings.extend(self.generate_image_embeddings(source))
            return embeddings
        else:
            embeddings.extend(self.generate_text_embeddings(source))
            return embeddings

    def sanitize_input(
        self, input: Union[IMAGES, AUDIO]
    ) -> Union[List[bytes], np.ndarray]:
        """
        Sanitize the input to the embedding function.
        """
        if isinstance(input, (str, bytes)):
            input = [input]
        elif isinstance(input, pa.Array):
            input = input.to_pylist()
        elif isinstance(input, pa.ChunkedArray):
            input = input.combine_chunks().to_pylist()
        return input

    def get_embedding_model(self):
        """
        Fetches the imagebind embedding model.
        """
        imagebind = attempt_import_or_raise("imagebind", "imagebind")
        model = imagebind.imagebind_model.imagebind_huge(pretrained=True)
        model.eval()
        model.to(self.device)
        return model
@@ -36,6 +36,7 @@ TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
 IMAGES = Union[
     str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
 ]
+AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
 
 
 @deprecated
@@ -1,9 +1,9 @@
 [project]
 name = "lancedb"
-version = "0.5.6"
+version = "0.5.7"
 dependencies = [
     "deprecation",
-    "pylance==0.9.16",
+    "pylance==0.9.18",
     "ratelimiter~=1.0",
     "retry>=0.9.2",
     "tqdm>=4.27.0",
@@ -50,7 +50,7 @@ repository = "https://github.com/lancedb/lancedb"
 [project.optional-dependencies]
 tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
 dev = ["ruff", "pre-commit"]
-docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
+docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "mkdocs-ultralytics-plugin==0.0.44"]
 clip = ["torch", "pillow", "open-clip"]
 embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
     "InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]
@@ -28,6 +28,23 @@ from lancedb.pydantic import LanceModel, Vector
 # or connection to external api
 
 
+try:
+    if importlib.util.find_spec("mlx.core") is not None:
+        _mlx = True
+    else:
+        _mlx = None
+except Exception:
+    _mlx = None
+
+try:
+    if importlib.util.find_spec("imagebind") is not None:
+        _imagebind = True
+    else:
+        _imagebind = None
+except Exception:
+    _imagebind = None
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
 def test_basic_text_embeddings(alias, tmp_path):
@@ -158,6 +175,89 @@ def test_openclip(tmp_path):
     )
 
 
+@pytest.mark.skipif(
+    _imagebind is None,
+    reason="skip if imagebind not installed.",
+)
+@pytest.mark.slow
+def test_imagebind(tmp_path):
+    import os
+    import shutil
+    import tempfile
+
+    import pandas as pd
+    import requests
+
+    import lancedb.embeddings.imagebind
+    from lancedb.embeddings import get_registry
+    from lancedb.pydantic import LanceModel, Vector
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        print(f"Created temporary directory {temp_dir}")
+
+        def download_images(image_uris):
+            downloaded_image_paths = []
+            for uri in image_uris:
+                try:
+                    response = requests.get(uri, stream=True)
+                    if response.status_code == 200:
+                        # Extract image name from URI
+                        image_name = os.path.basename(uri)
+                        image_path = os.path.join(temp_dir, image_name)
+                        with open(image_path, "wb") as out_file:
+                            shutil.copyfileobj(response.raw, out_file)
+                        downloaded_image_paths.append(image_path)
+                except Exception as e:  # noqa: PERF203
+                    print(f"Failed to download {uri}. Error: {e}")
+            return temp_dir, downloaded_image_paths
+
+        db = lancedb.connect(tmp_path)
+        registry = get_registry()
+        func = registry.get("imagebind").create(max_retries=0)
+
+        class Images(LanceModel):
+            label: str
+            image_uri: str = func.SourceField()
+            vector: Vector(func.ndims()) = func.VectorField()
+
+        table = db.create_table("images", schema=Images)
+        labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
+        uris = [
+            "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
+            "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
+            "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
+            "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
+            "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
+            "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
+        ]
+        temp_dir, downloaded_images = download_images(uris)
+        table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))
+        # text search
+        actual = (
+            table.search("man's best friend", vector_column_name="vector")
+            .limit(1)
+            .to_pydantic(Images)[0]
+        )
+        assert actual.label == "dog"
+
+        # image search
+        query_image_uri = [
+            "https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
+        ]
+        temp_dir, downloaded_images = download_images(query_image_uri)
+        query_image_uri = downloaded_images[0]
+        actual = (
+            table.search(query_image_uri, vector_column_name="vector")
+            .limit(1)
+            .to_pydantic(Images)[0]
+        )
+        assert actual.label == "dog"
+
+        if os.path.isdir(temp_dir):
+            shutil.rmtree(temp_dir)
+            print(f"Deleted temporary directory {temp_dir}")
+
+
 @pytest.mark.slow
 @pytest.mark.skipif(
     os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
@@ -217,13 +317,6 @@ def test_gemini_embedding(tmp_path):
     assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
 
 
-try:
-    if importlib.util.find_spec("mlx.core") is not None:
-        _mlx = True
-except ImportError:
-    _mlx = None
-
-
 @pytest.mark.skipif(
     _mlx is None,
     reason="mlx tests only required for apple users.",
@@ -22,9 +22,9 @@ use object_store::CredentialProvider;
 use once_cell::sync::OnceCell;
 use tokio::runtime::Runtime;
 
-use vectordb::connection::Database;
+use vectordb::connect;
+use vectordb::connection::Connection;
 use vectordb::table::ReadParams;
-use vectordb::{ConnectOptions, Connection};
 
 use crate::error::ResultExt;
 use crate::query::JsQuery;
@@ -39,7 +39,7 @@ mod query;
 mod table;
 
 struct JsDatabase {
-    database: Arc<dyn Connection + 'static>,
+    database: Connection,
 }
 
 impl Finalize for JsDatabase {}
@@ -89,23 +89,23 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let channel = cx.channel();
     let (deferred, promise) = cx.promise();
 
-    let mut conn_options = ConnectOptions::new(&path);
+    let mut conn_builder = connect(&path);
     if let Some(region) = region {
-        conn_options = conn_options.region(&region);
+        conn_builder = conn_builder.region(&region);
     }
     if let Some(aws_creds) = aws_creds {
-        conn_options = conn_options.aws_creds(AwsCredential {
+        conn_builder = conn_builder.aws_creds(AwsCredential {
             key_id: aws_creds.key_id,
             secret_key: aws_creds.secret_key,
             token: aws_creds.token,
         });
     }
     rt.spawn(async move {
-        let database = Database::connect_with_options(&conn_options).await;
+        let database = conn_builder.execute().await;
 
         deferred.settle_with(&channel, move |mut cx| {
            let db = JsDatabase {
-                database: Arc::new(database.or_throw(&mut cx)?),
+                database: database.or_throw(&mut cx)?,
            };
            Ok(cx.boxed(db))
        });
@@ -217,7 +217,11 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
 
     let (deferred, promise) = cx.promise();
     rt.spawn(async move {
-        let table_rst = database.open_table_with_params(&table_name, params).await;
+        let table_rst = database
+            .open_table(&table_name)
+            .lance_read_params(params)
+            .execute()
+            .await;
 
         deferred.settle_with(&channel, move |mut cx| {
             let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);
@@ -18,7 +18,7 @@ use arrow_array::{RecordBatch, RecordBatchIterator};
 use lance::dataset::optimize::CompactionOptions;
 use lance::dataset::{WriteMode, WriteParams};
 use lance::io::ObjectStoreParams;
-use vectordb::table::OptimizeAction;
+use vectordb::table::{AddDataOptions, OptimizeAction, WriteOptions};
 
 use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
 use neon::prelude::*;
@@ -80,7 +80,11 @@ impl JsTable {
     rt.spawn(async move {
         let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
         let table_rst = database
-            .create_table(&table_name, Box::new(batch_reader), Some(params))
+            .create_table(&table_name, Box::new(batch_reader))
+            .write_options(WriteOptions {
+                lance_write_params: Some(params),
+            })
+            .execute()
             .await;
 
         deferred.settle_with(&channel, move |mut cx| {
@@ -121,7 +125,13 @@ impl JsTable {
 
     rt.spawn(async move {
         let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
-        let add_result = table.add(Box::new(batch_reader), Some(params)).await;
+        let opts = AddDataOptions {
+            write_options: WriteOptions {
+                lance_write_params: Some(params),
+            },
+            ..Default::default()
+        };
+        let add_result = table.add(Box::new(batch_reader), opts).await;
 
         deferred.settle_with(&channel, move |mut cx| {
             add_result.or_throw(&mut cx)?;
@@ -19,7 +19,8 @@ use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterat
 use arrow_schema::{DataType, Field, Schema};
 use futures::TryStreamExt;
 
-use vectordb::Connection;
+use vectordb::connection::Connection;
+use vectordb::table::AddDataOptions;
 use vectordb::{connect, Result, Table, TableRef};
 
 #[tokio::main]
@@ -29,18 +30,18 @@ async fn main() -> Result<()> {
     }
     // --8<-- [start:connect]
     let uri = "data/sample-lancedb";
-    let db = connect(uri).await?;
+    let db = connect(uri).execute().await?;
     // --8<-- [end:connect]
 
     // --8<-- [start:list_names]
     println!("{:?}", db.table_names().await?);
     // --8<-- [end:list_names]
-    let tbl = create_table(db.clone()).await?;
+    let tbl = create_table(&db).await?;
     create_index(tbl.as_ref()).await?;
     let batches = search(tbl.as_ref()).await?;
     println!("{:?}", batches);
 
-    create_empty_table(db.clone()).await.unwrap();
+    create_empty_table(&db).await.unwrap();
 
     // --8<-- [start:delete]
     tbl.delete("id > 24").await.unwrap();
@@ -55,17 +56,14 @@ async fn main() -> Result<()> {
 #[allow(dead_code)]
 async fn open_with_existing_tbl() -> Result<()> {
     let uri = "data/sample-lancedb";
-    let db = connect(uri).await?;
+    let db = connect(uri).execute().await?;
     // --8<-- [start:open_with_existing_file]
-    let _ = db
-        .open_table_with_params("my_table", Default::default())
-        .await
-        .unwrap();
+    let _ = db.open_table("my_table").execute().await.unwrap();
     // --8<-- [end:open_with_existing_file]
     Ok(())
 }
 
-async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
+async fn create_table(db: &Connection) -> Result<TableRef> {
     // --8<-- [start:create_table]
     const TOTAL: usize = 1000;
     const DIM: usize = 128;
@@ -102,7 +100,8 @@ async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
         schema.clone(),
     );
     let tbl = db
-        .create_table("my_table", Box::new(batches), None)
+        .create_table("my_table", Box::new(batches))
+        .execute()
         .await
         .unwrap();
     // --8<-- [end:create_table]
@@ -126,21 +125,21 @@ async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
         schema.clone(),
     );
     // --8<-- [start:add]
-    tbl.add(Box::new(new_batches), None).await.unwrap();
+    tbl.add(Box::new(new_batches), AddDataOptions::default())
+        .await
+        .unwrap();
     // --8<-- [end:add]
 
     Ok(tbl)
 }
 
-async fn create_empty_table(db: Arc<dyn Connection>) -> Result<TableRef> {
+async fn create_empty_table(db: &Connection) -> Result<TableRef> {
     // --8<-- [start:create_empty_table]
     let schema = Arc::new(Schema::new(vec![
         Field::new("id", DataType::Int32, false),
         Field::new("item", DataType::Utf8, true),
     ]));
-    let batches = RecordBatchIterator::new(vec![], schema.clone());
-    db.create_table("empty_table", Box::new(batches), None)
-        .await
+    db.create_empty_table("empty_table", schema).execute().await
     // --8<-- [end:create_empty_table]
 }
@@ -13,14 +13,14 @@
 // limitations under the License.
 
 //! LanceDB Database
 //!
 
 use std::fs::create_dir_all;
 use std::path::Path;
 use std::sync::Arc;
 
-use arrow_array::RecordBatchReader;
-use lance::dataset::WriteParams;
+use arrow_array::{RecordBatchIterator, RecordBatchReader};
+use arrow_schema::SchemaRef;
+use lance::dataset::{ReadParams, WriteMode};
 use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
 use object_store::{
     aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
@@ -29,73 +29,283 @@ use snafu::prelude::*;
 
 use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
 use crate::io::object_store::MirroringObjectStoreWrapper;
-use crate::table::{NativeTable, ReadParams, TableRef};
+use crate::table::{NativeTable, TableRef, WriteOptions};
 
 pub const LANCE_FILE_EXTENSION: &str = "lance";
 
-/// A connection to LanceDB
-#[async_trait::async_trait]
-pub trait Connection: Send + Sync {
-    /// Get the names of all tables in the database.
-    async fn table_names(&self) -> Result<Vec<String>>;
+pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
 
-    /// Create a new table in the database.
+/// Describes what happens when creating a table and a table with
+/// the same name already exists
+pub enum CreateTableMode {
+    /// If the table already exists, an error is returned
+    Create,
+    /// If the table already exists, it is opened. Any provided data is
+    /// ignored. The function will be passed an OpenTableBuilder to customize
+    /// how the table is opened
+    ExistOk(TableBuilderCallback),
+    /// If the table already exists, it is overwritten
+    Overwrite,
+}
+
+impl CreateTableMode {
+    pub fn exist_ok(
+        callback: impl FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send + 'static,
+    ) -> Self {
+        Self::ExistOk(Box::new(callback))
+    }
+}
+
+impl Default for CreateTableMode {
+    fn default() -> Self {
+        Self::Create
+    }
+}
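
A minimal sketch of how a caller chooses between these three modes with the builders introduced in this change; the database path, table name, and one-column schema are invented for illustration:

```rust
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};
use vectordb::connection::CreateTableMode;

async fn mode_demo() -> vectordb::Result<()> {
    let db = vectordb::connect("data/mode-demo").execute().await?;
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));

    // Create (the default): fails with TableAlreadyExists on a name collision.
    db.create_empty_table("t", schema.clone()).execute().await?;

    // ExistOk: on collision the existing table is opened instead; the callback
    // receives the OpenTableBuilder so the open can be customized.
    db.create_empty_table("t", schema.clone())
        .mode(CreateTableMode::exist_ok(|open| open.index_cache_size(16)))
        .execute()
        .await?;

    // Overwrite: the existing table is replaced with the new (empty) one.
    db.create_empty_table("t", schema)
        .mode(CreateTableMode::Overwrite)
        .execute()
        .await?;
    Ok(())
}
```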
+/// Describes what happens when a vector either contains NaN or
+/// does not have enough values
+#[derive(Clone, Debug, Default)]
+enum BadVectorHandling {
+    /// An error is returned
+    #[default]
+    Error,
+    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
+    /// The offending row is dropped
+    Drop,
+    #[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
+    /// The invalid/missing items are replaced by fill_value
+    Fill(f32),
+}
+
+/// A builder for configuring a [`Connection::create_table`] operation
+pub struct CreateTableBuilder<const HAS_DATA: bool> {
+    parent: Arc<dyn ConnectionInternal>,
+    name: String,
+    data: Option<Box<dyn RecordBatchReader + Send>>,
+    schema: Option<SchemaRef>,
+    mode: CreateTableMode,
+    write_options: WriteOptions,
+}
+
+// Builder methods that only apply when we have initial data
+impl CreateTableBuilder<true> {
+    fn new(
+        parent: Arc<dyn ConnectionInternal>,
+        name: String,
+        data: Box<dyn RecordBatchReader + Send>,
+    ) -> Self {
+        Self {
+            parent,
+            name,
+            data: Some(data),
+            schema: None,
+            mode: CreateTableMode::default(),
+            write_options: WriteOptions::default(),
+        }
+    }
+
+    /// Apply the given write options when writing the initial data
+    pub fn write_options(mut self, write_options: WriteOptions) -> Self {
+        self.write_options = write_options;
+        self
+    }
+
+    /// Execute the create table operation
+    pub async fn execute(self) -> Result<TableRef> {
+        self.parent.clone().do_create_table(self).await
+    }
+}
+
+// Builder methods that only apply when we do not have initial data
+impl CreateTableBuilder<false> {
+    fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
+        Self {
+            parent,
+            name,
+            data: None,
+            schema: Some(schema),
+            mode: CreateTableMode::default(),
+            write_options: WriteOptions::default(),
+        }
+    }
+
+    /// Execute the create table operation
+    pub async fn execute(self) -> Result<TableRef> {
+        self.parent.clone().do_create_empty_table(self).await
+    }
+}
+
+impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
+    /// Set the mode for creating the table
+    ///
+    /// This controls what happens if a table with the given name already exists
+    pub fn mode(mut self, mode: CreateTableMode) -> Self {
+        self.mode = mode;
+        self
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct OpenTableBuilder {
+    parent: Arc<dyn ConnectionInternal>,
+    name: String,
+    index_cache_size: u32,
+    lance_read_params: Option<ReadParams>,
+}
+
+impl OpenTableBuilder {
+    fn new(parent: Arc<dyn ConnectionInternal>, name: String) -> Self {
+        Self {
+            parent,
+            name,
+            index_cache_size: 256,
+            lance_read_params: None,
+        }
+    }
+
+    /// Set the size of the index cache, specified as a number of entries
+    ///
+    /// The default value is 256
+    ///
+    /// The exact meaning of an "entry" will depend on the type of index:
+    /// * IVF - there is one entry for each IVF partition
+    /// * BTREE - there is one entry for the entire index
+    ///
+    /// This cache applies to the entire opened table, across all indices.
+    /// Setting this value higher will increase performance on larger datasets
+    /// at the expense of more RAM
+    pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
+        self.index_cache_size = index_cache_size;
+        self
+    }
+
+    /// Advanced parameters that can be used to customize table reads
+    ///
+    /// If set, these will take precedence over any overlapping `OpenTableOptions` options
+    pub fn lance_read_params(mut self, params: ReadParams) -> Self {
+        self.lance_read_params = Some(params);
+        self
+    }
+
+    /// Open the table
+    pub async fn execute(self) -> Result<TableRef> {
+        self.parent.clone().do_open_table(self).await
+    }
+}
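
For illustration, a hypothetical call site for this builder (table name and URI invented); per the doc comments above, `lance_read_params`, if supplied, takes precedence over the simpler `index_cache_size` knob:

```rust
use vectordb::table::ReadParams;

async fn open_demo() -> vectordb::Result<()> {
    let db = vectordb::connect("data/sample-lancedb").execute().await?;

    // Tune only the index cache size (number of cached index entries).
    let _tbl = db
        .open_table("my_table")
        .index_cache_size(1024)
        .execute()
        .await?;

    // Or pass full Lance read parameters straight through.
    let _tbl = db
        .open_table("my_table")
        .lance_read_params(ReadParams::default())
        .execute()
        .await?;
    Ok(())
}
```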
+#[async_trait::async_trait]
+trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
+    async fn table_names(&self) -> Result<Vec<String>>;
+    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
+    async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
+    async fn drop_table(&self, name: &str) -> Result<()>;
+    async fn drop_db(&self) -> Result<()>;
+
+    async fn do_create_empty_table(&self, options: CreateTableBuilder<false>) -> Result<TableRef> {
+        let batches = RecordBatchIterator::new(vec![], options.schema.unwrap());
+        let opts = CreateTableBuilder::<true>::new(options.parent, options.name, Box::new(batches))
+            .mode(options.mode)
+            .write_options(options.write_options);
+        self.do_create_table(opts).await
+    }
+}
+
+/// A connection to LanceDB
+#[derive(Clone)]
+pub struct Connection {
+    uri: String,
+    internal: Arc<dyn ConnectionInternal>,
+}
+
+impl Connection {
+    /// Get the URI of the connection
+    pub fn uri(&self) -> &str {
+        self.uri.as_str()
+    }
+
+    /// Get the names of all tables in the database.
+    pub async fn table_names(&self) -> Result<Vec<String>> {
+        self.internal.table_names().await
+    }
+
+    /// Create a new table from data
     ///
     /// # Parameters
     ///
-    /// * `name` - The name of the table.
-    /// * `batches` - The initial data to write to the table.
-    /// * `params` - Optional [`WriteParams`] to create the table.
-    ///
-    /// # Returns
-    /// Created [`TableRef`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
-    async fn create_table(
+    /// * `name` - The name of the table
+    /// * `initial_data` - The initial data to write to the table
+    pub fn create_table(
         &self,
-        name: &str,
-        batches: Box<dyn RecordBatchReader + Send>,
-        params: Option<WriteParams>,
-    ) -> Result<TableRef>;
-
-    async fn open_table(&self, name: &str) -> Result<TableRef> {
-        self.open_table_with_params(name, ReadParams::default())
-            .await
+        name: impl Into<String>,
+        initial_data: Box<dyn RecordBatchReader + Send>,
+    ) -> CreateTableBuilder<true> {
+        CreateTableBuilder::<true>::new(self.internal.clone(), name.into(), initial_data)
     }
 
-    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef>;
+    /// Create an empty table with a given schema
+    ///
+    /// # Parameters
+    ///
+    /// * `name` - The name of the table
+    /// * `schema` - The schema of the table
+    pub fn create_empty_table(
+        &self,
+        name: impl Into<String>,
+        schema: SchemaRef,
+    ) -> CreateTableBuilder<false> {
+        CreateTableBuilder::<false>::new(self.internal.clone(), name.into(), schema)
+    }
+
+    /// Open an existing table in the database
+    ///
+    /// # Arguments
+    /// * `name` - The name of the table
+    ///
+    /// # Returns
+    /// Created [`TableRef`], or [`Error::TableNotFound`] if the table does not exist.
+    pub fn open_table(&self, name: impl Into<String>) -> OpenTableBuilder {
+        OpenTableBuilder::new(self.internal.clone(), name.into())
+    }
 
     /// Drop a table in the database.
     ///
     /// # Arguments
-    /// * `name` - The name of the table.
-    async fn drop_table(&self, name: &str) -> Result<()>;
+    /// * `name` - The name of the table to drop
+    pub async fn drop_table(&self, name: impl AsRef<str>) -> Result<()> {
+        self.internal.drop_table(name.as_ref()).await
+    }
+
+    /// Drop the database
+    ///
+    /// This is the same as dropping all of the tables
+    pub async fn drop_db(&self) -> Result<()> {
+        self.internal.drop_db().await
+    }
+}
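
Because the old `Arc<dyn Connection>` trait object becomes this concrete, cloneable handle, sharing a connection no longer requires wrapping it yourself. A short sketch (names invented):

```rust
async fn share_demo() -> vectordb::Result<()> {
    let db = vectordb::connect("data/sample-lancedb").execute().await?;
    // Cloning copies the cheap handle; both values talk to the same database.
    let db2 = db.clone();

    println!("connected to {}", db.uri());
    for name in db2.table_names().await? {
        println!("table: {}", name);
    }
    Ok(())
}
```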
 
 #[derive(Debug)]
-pub struct ConnectOptions {
+pub struct ConnectBuilder {
     /// Database URI
     ///
-    /// # Accepted URI formats
+    /// ### Accepted URI formats
     ///
     /// - `/path/to/database` - local database on file system.
     /// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
-    /// - `db://dbname` - Lance Cloud
-    pub uri: String,
+    /// - `db://dbname` - LanceDB Cloud
+    uri: String,
 
-    /// Lance Cloud API key
-    pub api_key: Option<String>,
-    /// Lance Cloud region
-    pub region: Option<String>,
-    /// Lance Cloud host override
-    pub host_override: Option<String>,
+    /// LanceDB Cloud API key, required if using Lance Cloud
+    api_key: Option<String>,
+    /// LanceDB Cloud region, required if using Lance Cloud
+    region: Option<String>,
+    /// LanceDB Cloud host override, only required if using an on-premises Lance Cloud instance
+    host_override: Option<String>,
 
     /// User provided AWS credentials
-    pub aws_creds: Option<AwsCredential>,
-
-    /// The maximum number of indices to cache in memory. Defaults to 256.
-    pub index_cache_size: u32,
+    aws_creds: Option<AwsCredential>,
 }
 
-impl ConnectOptions {
+impl ConnectBuilder {
     /// Create a new [`ConnectOptions`] with the given database URI.
     pub fn new(uri: &str) -> Self {
         Self {
@@ -104,7 +314,6 @@ impl ConnectOptions {
             region: None,
             host_override: None,
             aws_creds: None,
-            index_cache_size: 256,
         }
     }
 
@@ -124,15 +333,18 @@ impl ConnectOptions {
     }
 
     /// [`AwsCredential`] to use when connecting to S3.
-    ///
     pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
         self.aws_creds = Some(aws_creds);
         self
     }
 
-    pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
-        self.index_cache_size = index_cache_size;
-        self
+    /// Establishes a connection to the database
+    pub async fn execute(self) -> Result<Connection> {
+        let internal = Arc::new(Database::connect_with_options(&self).await?);
+        Ok(Connection {
+            internal,
+            uri: self.uri,
+        })
     }
 }
 
@@ -140,29 +352,14 @@ impl ConnectOptions {
 ///
 /// # Arguments
 ///
-/// - `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
-///
-/// ## Accepted URI formats
-///
-/// - `/path/to/database` - local database on file system.
-/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
-/// - `db://dbname` - Lance Cloud
-///
-pub async fn connect(uri: &str) -> Result<Arc<dyn Connection>> {
-    let options = ConnectOptions::new(uri);
-    connect_with_options(&options).await
+/// * `uri` - URI where the database is located, can be a local directory, supported remote cloud storage,
+///   or a LanceDB Cloud database. See [ConnectOptions::uri] for a list of accepted formats
+pub fn connect(uri: &str) -> ConnectBuilder {
+    ConnectBuilder::new(uri)
 }
 
-/// Connect with [`ConnectOptions`].
-///
-/// # Arguments
-/// - `options` - [`ConnectOptions`] to connect to the database.
-pub async fn connect_with_options(options: &ConnectOptions) -> Result<Arc<dyn Connection>> {
-    let db = Database::connect(&options.uri).await?;
-    Ok(Arc::new(db))
-}
-
-pub struct Database {
+#[derive(Debug)]
+struct Database {
     object_store: ObjectStore,
     query_string: Option<String>,
 
@@ -179,21 +376,7 @@ const MIRRORED_STORE: &str = "mirroredStore";
 
 /// A connection to LanceDB
 impl Database {
-    /// Connects to LanceDB
-    ///
-    /// # Arguments
-    ///
-    /// * `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
-    ///
-    /// # Returns
-    ///
-    /// * A [Database] object.
-    pub async fn connect(uri: &str) -> Result<Self> {
-        let options = ConnectOptions::new(uri);
-        Self::connect_with_options(&options).await
-    }
-
-    pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> {
+    async fn connect_with_options(options: &ConnectBuilder) -> Result<Self> {
         let uri = &options.uri;
         let parse_res = url::Url::parse(uri);
 
@@ -333,7 +516,7 @@ impl Database {
 }
 
 #[async_trait::async_trait]
-impl Connection for Database {
+impl ConnectionInternal for Database {
     async fn table_names(&self) -> Result<Vec<String>> {
         let mut f = self
             .object_store
@@ -354,40 +537,47 @@
         Ok(f)
     }
 
-    async fn create_table(
-        &self,
-        name: &str,
-        batches: Box<dyn RecordBatchReader + Send>,
-        params: Option<WriteParams>,
-    ) -> Result<TableRef> {
-        let table_uri = self.table_uri(name)?;
+    async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef> {
+        let table_uri = self.table_uri(&options.name)?;
 
-        Ok(Arc::new(
-            NativeTable::create(
-                &table_uri,
-                name,
-                batches,
-                self.store_wrapper.clone(),
-                params,
-            )
-            .await?,
-        ))
+        let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
+        if matches!(&options.mode, CreateTableMode::Overwrite) {
+            write_params.mode = WriteMode::Overwrite;
+        }
+
+        match NativeTable::create(
+            &table_uri,
+            &options.name,
+            options.data.unwrap(),
+            self.store_wrapper.clone(),
+            Some(write_params),
+        )
+        .await
+        {
+            Ok(table) => Ok(Arc::new(table)),
+            Err(Error::TableAlreadyExists { name }) => match options.mode {
+                CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
+                CreateTableMode::ExistOk(callback) => {
+                    let builder = OpenTableBuilder::new(options.parent, options.name);
+                    let builder = (callback)(builder);
+                    builder.execute().await
+                }
+                CreateTableMode::Overwrite => unreachable!(),
+            },
+            Err(err) => Err(err),
        }
     }
 
-    /// Open a table in the database.
-    ///
-    /// # Arguments
-    /// * `name` - The name of the table.
-    /// * `params` - The parameters to open the table.
-    ///
-    /// # Returns
-    ///
-    /// * A [TableRef] object.
-    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef> {
-        let table_uri = self.table_uri(name)?;
+    async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef> {
+        let table_uri = self.table_uri(&options.name)?;
         Ok(Arc::new(
-            NativeTable::open_with_params(&table_uri, name, self.store_wrapper.clone(), params)
-                .await?,
+            NativeTable::open_with_params(
+                &table_uri,
+                &options.name,
+                self.store_wrapper.clone(),
+                options.lance_read_params,
+            )
+            .await?,
         ))
     }
 
@@ -397,12 +587,17 @@ impl ConnectionInternal for Database {
         self.object_store.remove_dir_all(full_path).await?;
         Ok(())
     }
+
+    async fn drop_db(&self) -> Result<()> {
+        todo!()
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use std::fs::create_dir_all;
 
     use arrow_schema::{DataType, Field, Schema};
     use tempfile::tempdir;
 
     use super::*;
@@ -411,7 +606,7 @@ mod tests {
     async fn test_connect() {
         let tmp_dir = tempdir().unwrap();
         let uri = tmp_dir.path().to_str().unwrap();
-        let db = Database::connect(uri).await.unwrap();
+        let db = connect(uri).execute().await.unwrap();
 
         assert_eq!(db.uri, uri);
     }
@@ -429,7 +624,8 @@ mod tests {
         let relative_root = std::path::PathBuf::from(relative_ancestors.join("/"));
         let relative_uri = relative_root.join(&uri);
 
-        let db = Database::connect(relative_uri.to_str().unwrap())
+        let db = connect(relative_uri.to_str().unwrap())
+            .execute()
             .await
             .unwrap();
 
@@ -444,7 +640,7 @@ mod tests {
         create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();
 
         let uri = tmp_dir.path().to_str().unwrap();
-        let db = Database::connect(uri).await.unwrap();
+        let db = connect(uri).execute().await.unwrap();
         let tables = db.table_names().await.unwrap();
         assert_eq!(tables.len(), 2);
         assert!(tables[0].eq(&String::from("table1")));
@@ -462,10 +658,44 @@ mod tests {
         create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
 
         let uri = tmp_dir.path().to_str().unwrap();
-        let db = Database::connect(uri).await.unwrap();
+        let db = connect(uri).execute().await.unwrap();
         db.drop_table("table1").await.unwrap();
 
         let tables = db.table_names().await.unwrap();
         assert_eq!(tables.len(), 0);
     }
+
+    #[tokio::test]
+    async fn test_create_table_already_exists() {
+        let tmp_dir = tempdir().unwrap();
+        let uri = tmp_dir.path().to_str().unwrap();
+        let db = connect(uri).execute().await.unwrap();
+        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
+        db.create_empty_table("test", schema.clone())
+            .execute()
+            .await
+            .unwrap();
+        // TODO: None of the open table options are "inspectable" right now but once one is we
+        // should assert we are passing these options in correctly
+        db.create_empty_table("test", schema)
+            .mode(CreateTableMode::exist_ok(|builder| {
+                builder.index_cache_size(16)
+            }))
+            .execute()
+            .await
+            .unwrap();
+        let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)]));
+        assert!(db
+            .create_empty_table("test", other_schema.clone())
+            .execute()
+            .await
+            .is_err());
+        let overwritten = db
+            .create_empty_table("test", other_schema.clone())
+            .mode(CreateTableMode::Overwrite)
+            .execute()
+            .await
+            .unwrap();
+        assert_eq!(other_schema, overwritten.schema());
+    }
 }
@@ -174,7 +174,6 @@ fn coerce_schema_batch(
 }
 
 /// Coerce the reader (input data) to match the given [Schema].
-///
 pub fn coerce_schema(
     reader: impl RecordBatchReader + Send + 'static,
     schema: Arc<Schema>,
@@ -342,7 +342,7 @@ mod test {
 use object_store::local::LocalFileSystem;
 use tempfile;
 
-use crate::connection::{Connection, Database};
+use crate::{connect, table::WriteOptions};
 
 #[tokio::test]
 async fn test_e2e() {
@@ -354,7 +354,7 @@ mod test {
         secondary: Arc::new(secondary_store),
     });
 
-    let db = Database::connect(dir1.to_str().unwrap()).await.unwrap();
+    let db = connect(dir1.to_str().unwrap()).execute().await.unwrap();
 
     let mut param = WriteParams::default();
     let store_params = ObjectStoreParams {
@@ -368,7 +368,11 @@ mod test {
     datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));
 
     let res = db
-        .create_table("test", Box::new(datagen.batch(100)), Some(param.clone()))
+        .create_table("test", Box::new(datagen.batch(100)))
+        .write_options(WriteOptions {
+            lance_write_params: Some(param),
+        })
+        .execute()
        .await;
 
     // leave this here for easy debugging
@@ -43,10 +43,9 @@
 //! #### Connect to a database.
 //!
 //! ```rust
-//! use vectordb::connect;
 //! # use arrow_schema::{Field, Schema};
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
-//! let db = connect("data/sample-lancedb").await.unwrap();
+//! let db = vectordb::connect("data/sample-lancedb").execute().await.unwrap();
 //! # });
 //! ```
 //!
@@ -56,14 +55,20 @@
 //! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
 //! - `db://dbname` - Lance Cloud
 //!
-//! You can also use [`ConnectOptions`] to configure the connectoin to the database.
+//! You can also use [`ConnectOptions`] to configure the connection to the database.
 //!
 //! ```rust
-//! use vectordb::{connect_with_options, ConnectOptions};
+//! use object_store::aws::AwsCredential;
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
-//! let options = ConnectOptions::new("data/sample-lancedb")
-//!     .index_cache_size(1024);
-//! let db = connect_with_options(&options).await.unwrap();
+//! let db = vectordb::connect("data/sample-lancedb")
+//!     .aws_creds(AwsCredential {
+//!         key_id: "some_key".to_string(),
+//!         secret_key: "some_secret".to_string(),
+//!         token: None,
+//!     })
+//!     .execute()
+//!     .await
+//!     .unwrap();
 //! # });
 //! ```
 //!
@@ -79,31 +84,44 @@
 //!
 //! ```rust
 //! # use std::sync::Arc;
-//! use arrow_schema::{DataType, Schema, Field};
 //! use arrow_array::{RecordBatch, RecordBatchIterator};
+//! use arrow_schema::{DataType, Field, Schema};
 //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
-//! # use vectordb::connection::{Database, Connection};
+//! # use vectordb::connect;
 //!
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
 //! # let tmpdir = tempfile::tempdir().unwrap();
-//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
+//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
 //! let schema = Arc::new(Schema::new(vec![
-//!   Field::new("id", DataType::Int32, false),
-//!   Field::new("vector", DataType::FixedSizeList(
-//!     Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
+//!     Field::new("id", DataType::Int32, false),
+//!     Field::new(
+//!         "vector",
+//!         DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128),
+//!         true,
+//!     ),
 //! ]));
 //! // Create a RecordBatch stream.
-//! let batches = RecordBatchIterator::new(vec![
-//!     RecordBatch::try_new(schema.clone(),
+//! let batches = RecordBatchIterator::new(
+//!     vec![RecordBatch::try_new(
+//!         schema.clone(),
 //!         vec![
-//!             Arc::new(Int32Array::from_iter_values(0..1000)),
-//!             Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
-//!                 (0..1000).map(|_| Some(vec![Some(1.0); 128])), 128)),
-//!         ]).unwrap()
-//!     ].into_iter().map(Ok),
-//!     schema.clone());
-//! db.create_table("my_table", Box::new(batches), None).await.unwrap();
+//!             Arc::new(Int32Array::from_iter_values(0..256)),
+//!             Arc::new(
+//!                 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+//!                     (0..256).map(|_| Some(vec![Some(1.0); 128])),
+//!                     128,
+//!                 ),
+//!             ),
+//!         ],
+//!     )
+//!     .unwrap()]
+//!     .into_iter()
+//!     .map(Ok),
+//!     schema.clone(),
+//! );
+//! db.create_table("my_table", Box::new(batches))
+//!     .execute()
+//!     .await
+//!     .unwrap();
 //! # });
 //! ```
 //!
@@ -111,14 +129,13 @@
 //!
 //! ```no_run
 //! # use std::sync::Arc;
-//! # use vectordb::connect;
 //! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
 //! #     RecordBatchIterator, Int32Array};
 //! # use arrow_schema::{Schema, Field, DataType};
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
 //! # let tmpdir = tempfile::tempdir().unwrap();
-//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
-//! # let tbl = db.open_table("idx_test").await.unwrap();
+//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
+//! # let tbl = db.open_table("idx_test").execute().await.unwrap();
 //! tbl.create_index(&["vector"])
 //!     .ivf_pq()
 //!     .num_partitions(256)
@@ -136,10 +153,9 @@
 //! # use arrow_schema::{DataType, Schema, Field};
 //! # use arrow_array::{RecordBatch, RecordBatchIterator};
 //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
-//! # use vectordb::connection::{Database, Connection};
 //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
 //! # let tmpdir = tempfile::tempdir().unwrap();
-//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
+//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
 //! # let schema = Arc::new(Schema::new(vec![
 //! #   Field::new("id", DataType::Int32, false),
 //! #   Field::new("vector", DataType::FixedSizeList(
@@ -154,8 +170,8 @@
 //! #   ]).unwrap()
 //! #   ].into_iter().map(Ok),
 //! #   schema.clone());
-//! # db.create_table("my_table", Box::new(batches), None).await.unwrap();
-//! # let table = db.open_table("my_table").await.unwrap();
+//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap();
+//! # let table = db.open_table("my_table").execute().await.unwrap();
 //! let results = table
 //!     .search(&[1.0; 128])
 //!     .execute_stream()
@@ -165,8 +181,6 @@
 //!     .await
 //!     .unwrap();
 //! # });
-//!
-//!
 //! ```
 
 pub mod connection;
@@ -179,10 +193,8 @@ pub mod query;
 pub mod table;
 pub mod utils;
 
-pub use connection::{Connection, Database};
 pub use error::{Error, Result};
 pub use table::{Table, TableRef};
-
-/// Connect to a database
-pub use connection::{connect, connect_with_options, ConnectOptions};
-pub use lance::dataset::WriteMode;
+pub use connection::connect;
@@ -60,7 +60,6 @@ impl Query {
     /// # Arguments
     ///
     /// * `dataset` - Lance dataset.
-    ///
     pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
         Self {
             dataset,
@@ -17,7 +17,7 @@
 use std::path::Path;
 use std::sync::{Arc, Mutex};
 
-use arrow_array::RecordBatchReader;
+use arrow_array::{RecordBatchIterator, RecordBatchReader};
 use arrow_schema::{Schema, SchemaRef};
 use async_trait::async_trait;
 use chrono::Duration;
@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
     compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
 };
 pub use lance::dataset::ReadParams;
-use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
+use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteMode, WriteParams};
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::io::WrappingObjectStore;
 use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -38,7 +38,6 @@ use crate::index::vector::{VectorIndex, VectorIndexStatistics};
 use crate::index::IndexBuilder;
 use crate::query::Query;
 use crate::utils::{PatchReadParam, PatchWriteParam};
-use crate::WriteMode;
 
 use self::merge::{MergeInsert, MergeInsertBuilder};
 
@@ -85,6 +84,35 @@ pub struct OptimizeStats {
     pub prune: Option<RemovalStats>,
 }
 
+/// Options to use when writing data
+#[derive(Clone, Debug, Default)]
+pub struct WriteOptions {
+    // Coming soon: https://github.com/lancedb/lancedb/issues/992
+    // /// What behavior to take if the data contains invalid vectors
+    // pub on_bad_vectors: BadVectorHandling,
+    /// Advanced parameters that can be used to customize table creation
+    ///
+    /// If set, these will take precedence over any overlapping `OpenTableOptions` options
+    pub lance_write_params: Option<WriteParams>,
+}
+
+#[derive(Debug, Clone, Default)]
+pub enum AddDataMode {
+    /// Rows will be appended to the table (the default)
+    #[default]
+    Append,
+    /// The existing table will be overwritten with the new data
+    Overwrite,
+}
+
+#[derive(Debug, Default, Clone)]
+pub struct AddDataOptions {
+    /// Whether to add new rows (the default) or replace the existing data
+    pub mode: AddDataMode,
+    /// Options to use when writing the data
+    pub write_options: WriteOptions,
+}
+
 /// A Table is a collection of strongly typed rows.
 ///
 /// The type of each row is defined in Apache Arrow [Schema].
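
A hypothetical call site for the reworked `Table::add` declared in the next hunk: the old `Option<WriteParams>` argument becomes an `AddDataOptions`, with append-versus-overwrite expressed as an explicit `AddDataMode` (`tbl` and `batches` are assumed inputs):

```rust
use arrow_array::RecordBatchReader;
use vectordb::table::{AddDataMode, AddDataOptions, WriteOptions};

async fn add_demo(
    tbl: vectordb::TableRef,
    batches: Box<dyn RecordBatchReader + Send>,
) -> vectordb::Result<()> {
    // Replace the table contents instead of appending; for a plain append,
    // AddDataOptions::default() is enough.
    tbl.add(
        batches,
        AddDataOptions {
            mode: AddDataMode::Overwrite,
            write_options: WriteOptions::default(),
        },
    )
    .await
}
```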
@@ -112,12 +140,12 @@ pub trait Table: std::fmt::Display + Send + Sync {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `batches` RecordBatch to be saved in the Table
|
||||
/// * `params` Append / Overwrite existing records. Default: Append
|
||||
/// * `batches` data to be added to the Table
|
||||
/// * `options` options to control how data is added
|
||||
async fn add(
|
||||
&self,
|
||||
batches: Box<dyn RecordBatchReader + Send>,
|
||||
params: Option<WriteParams>,
|
||||
options: AddDataOptions,
|
||||
) -> Result<()>;
|
||||
|
||||
/// Delete the rows from table that match the predicate.
|
||||
@@ -129,28 +157,43 @@ pub trait Table: std::fmt::Display + Send + Sync {
     ///
     /// ```no_run
     /// # use std::sync::Arc;
-    /// # use vectordb::connection::{Database, Connection};
     /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
     /// #   RecordBatchIterator, Int32Array};
     /// # use arrow_schema::{Schema, Field, DataType};
     /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
     /// let tmpdir = tempfile::tempdir().unwrap();
-    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
+    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
+    ///     .execute()
+    ///     .await
+    ///     .unwrap();
     /// # let schema = Arc::new(Schema::new(vec![
     /// #   Field::new("id", DataType::Int32, false),
     /// #   Field::new("vector", DataType::FixedSizeList(
     /// #     Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
     /// # ]));
-    /// let batches = RecordBatchIterator::new(vec![
-    ///     RecordBatch::try_new(schema.clone(),
-    ///     vec![
-    ///         Arc::new(Int32Array::from_iter_values(0..10)),
-    ///         Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
-    ///             (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
-    ///     ]).unwrap()
-    /// ].into_iter().map(Ok),
-    /// schema.clone());
-    /// let tbl = db.create_table("delete_test", Box::new(batches), None).await.unwrap();
+    /// let batches = RecordBatchIterator::new(
+    ///     vec![RecordBatch::try_new(
+    ///         schema.clone(),
+    ///         vec![
+    ///             Arc::new(Int32Array::from_iter_values(0..10)),
+    ///             Arc::new(
+    ///                 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+    ///                     (0..10).map(|_| Some(vec![Some(1.0); 128])),
+    ///                     128,
+    ///                 ),
+    ///             ),
+    ///         ],
+    ///     )
+    ///     .unwrap()]
+    ///     .into_iter()
+    ///     .map(Ok),
+    ///     schema.clone(),
+    /// );
+    /// let tbl = db
+    ///     .create_table("delete_test", Box::new(batches))
+    ///     .execute()
+    ///     .await
+    ///     .unwrap();
     /// tbl.delete("id > 5").await.unwrap();
     /// # });
     /// ```
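The example above also reflects the connection API migration that runs through this diff: the `Database::connect` constructor gives way to a `vectordb::connect` builder finalized with `.execute()`, and `create_table`/`open_table` follow the same builder pattern. Side by side:

```rust
// Before this change:
// let db = Database::connect(uri).await?;
// let tbl = db.create_table("t", Box::new(batches), None).await?;

// After this change:
let db = vectordb::connect(uri).execute().await?;
let tbl = db.create_table("t", Box::new(batches)).execute().await?;
```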
@@ -162,14 +205,16 @@ pub trait Table: std::fmt::Display + Send + Sync {
     ///
     /// ```no_run
     /// # use std::sync::Arc;
-    /// # use vectordb::connection::{Database, Connection};
     /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
     /// #   RecordBatchIterator, Int32Array};
     /// # use arrow_schema::{Schema, Field, DataType};
     /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
     /// let tmpdir = tempfile::tempdir().unwrap();
-    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
-    /// # let tbl = db.open_table("idx_test").await.unwrap();
+    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
+    ///     .execute()
+    ///     .await
+    ///     .unwrap();
+    /// # let tbl = db.open_table("idx_test").execute().await.unwrap();
     /// tbl.create_index(&["vector"])
     ///     .ivf_pq()
     ///     .num_partitions(256)
@@ -214,32 +259,44 @@ pub trait Table: std::fmt::Display + Send + Sync {
     ///
     /// ```no_run
     /// # use std::sync::Arc;
-    /// # use vectordb::connection::{Database, Connection};
     /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
     /// #   RecordBatchIterator, Int32Array};
     /// # use arrow_schema::{Schema, Field, DataType};
     /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
     /// let tmpdir = tempfile::tempdir().unwrap();
-    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
-    /// # let tbl = db.open_table("idx_test").await.unwrap();
+    /// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
+    ///     .execute()
+    ///     .await
+    ///     .unwrap();
+    /// # let tbl = db.open_table("idx_test").execute().await.unwrap();
     /// # let schema = Arc::new(Schema::new(vec![
     /// #   Field::new("id", DataType::Int32, false),
     /// #   Field::new("vector", DataType::FixedSizeList(
     /// #     Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
     /// # ]));
-    /// let new_data = RecordBatchIterator::new(vec![
-    ///     RecordBatch::try_new(schema.clone(),
-    ///     vec![
-    ///         Arc::new(Int32Array::from_iter_values(0..10)),
-    ///         Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
-    ///             (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
-    ///     ]).unwrap()
-    /// ].into_iter().map(Ok),
-    /// schema.clone());
+    /// let new_data = RecordBatchIterator::new(
+    ///     vec![RecordBatch::try_new(
+    ///         schema.clone(),
+    ///         vec![
+    ///             Arc::new(Int32Array::from_iter_values(0..10)),
+    ///             Arc::new(
+    ///                 FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+    ///                     (0..10).map(|_| Some(vec![Some(1.0); 128])),
+    ///                     128,
+    ///                 ),
+    ///             ),
+    ///         ],
+    ///     )
+    ///     .unwrap()]
+    ///     .into_iter()
+    ///     .map(Ok),
+    ///     schema.clone(),
+    /// );
     /// // Perform an upsert operation
     /// let mut merge_insert = tbl.merge_insert(&["id"]);
-    /// merge_insert.when_matched_update_all(None)
-    ///     .when_not_matched_insert_all();
+    /// merge_insert
+    ///     .when_matched_update_all(None)
+    ///     .when_not_matched_insert_all();
     /// merge_insert.execute(Box::new(new_data)).await.unwrap();
     /// # });
     /// ```
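The upsert shown above combines both clauses: matched rows are updated, unmatched rows are inserted. Using only one clause changes the semantics; for example, an insert-if-absent that never touches existing rows is just (method names as in the example above):

```rust
// Only insert rows whose `id` is not already present; leave matches alone.
let mut insert_if_absent = tbl.merge_insert(&["id"]);
insert_if_absent.when_not_matched_insert_all();
insert_if_absent.execute(Box::new(new_data)).await.unwrap();
```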
@@ -266,7 +323,9 @@ pub trait Table: std::fmt::Display + Send + Sync {
     /// # use futures::TryStreamExt;
     /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
     /// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
-    /// let stream = tbl.query().nearest_to(&[1.0, 2.0, 3.0])
+    /// let stream = tbl
+    ///     .query()
+    ///     .nearest_to(&[1.0, 2.0, 3.0])
     ///     .refine_factor(5)
     ///     .nprobes(10)
     ///     .execute_stream()
@@ -299,11 +358,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
     /// # use futures::TryStreamExt;
     /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
     /// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
-    /// let stream = tbl
-    ///     .query()
-    ///     .execute_stream()
-    ///     .await
-    ///     .unwrap();
+    /// let stream = tbl.query().execute_stream().await.unwrap();
     /// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
     /// # });
     /// ```
@@ -351,7 +406,7 @@ impl NativeTable {
     /// * A [NativeTable] object.
     pub async fn open(uri: &str) -> Result<Self> {
         let name = Self::get_table_name(uri)?;
-        Self::open_with_params(uri, &name, None, ReadParams::default()).await
+        Self::open_with_params(uri, &name, None, None).await
     }

     /// Opens an existing Table
@@ -369,8 +424,9 @@ impl NativeTable {
         uri: &str,
         name: &str,
         write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
-        params: ReadParams,
+        params: Option<ReadParams>,
     ) -> Result<Self> {
+        let params = params.unwrap_or_default();
         // patch the params if we have a write store wrapper
         let params = match write_store_wrapper.clone() {
             Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
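`open_with_params` now takes `Option<ReadParams>` and fills in the default itself, so callers stop spelling out `ReadParams::default()`. A sketch of both call shapes (`index_cache_size` is assumed to be one of lance's `ReadParams` fields):

```rust
use vectordb::table::{NativeTable, ReadParams};

// Default read behaviour.
let tbl = NativeTable::open_with_params(uri, "my_table", None, None).await?;

// Tuned read behaviour.
let params = ReadParams {
    index_cache_size: 512,
    ..Default::default()
};
let tbl = NativeTable::open_with_params(uri, "my_table", None, Some(params)).await?;
```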
@@ -403,7 +459,6 @@ impl NativeTable {
     }

     /// Checkout a specific version of this [NativeTable]
-    ///
     pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
         let name = Self::get_table_name(uri)?;
         Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
@@ -489,13 +544,14 @@ impl NativeTable {
         write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
         params: Option<WriteParams>,
     ) -> Result<Self> {
+        let params = params.unwrap_or_default();
         // patch the params if we have a write store wrapper
         let params = match write_store_wrapper.clone() {
             Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
             None => params,
         };

-        let dataset = Dataset::write(batches, uri, params)
+        let dataset = Dataset::write(batches, uri, Some(params))
             .await
             .map_err(|e| match e {
                 lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -513,6 +569,17 @@ impl NativeTable {
             })
     }

+    pub async fn create_empty(
+        uri: &str,
+        name: &str,
+        schema: SchemaRef,
+        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
+        params: Option<WriteParams>,
+    ) -> Result<Self> {
+        let batches = RecordBatchIterator::new(vec![], schema);
+        Self::create(uri, name, batches, write_store_wrapper, params).await
+    }
+
     /// Version of this Table
     pub fn version(&self) -> u64 {
         self.dataset.lock().expect("lock poison").version().version
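`create_empty` is a thin wrapper: an empty `RecordBatchIterator` still carries a schema, so the table is created with its columns committed but zero rows. A hedged usage sketch (names and paths are illustrative):

```rust
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};
use vectordb::table::NativeTable;

let schema = Arc::new(Schema::new(vec![
    Field::new("id", DataType::Int32, false),
]));
// No store wrapper, default write params.
let tbl = NativeTable::create_empty("/tmp/empty_tbl", "empty_tbl", schema, None, None).await?;
```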
@@ -740,20 +807,26 @@ impl Table for NativeTable {
     async fn add(
         &self,
         batches: Box<dyn RecordBatchReader + Send>,
-        params: Option<WriteParams>,
+        params: AddDataOptions,
     ) -> Result<()> {
-        let params = Some(params.unwrap_or(WriteParams {
-            mode: WriteMode::Append,
-            ..WriteParams::default()
-        }));
+        let lance_params = params
+            .write_options
+            .lance_write_params
+            .unwrap_or(WriteParams {
+                mode: match params.mode {
+                    AddDataMode::Append => WriteMode::Append,
+                    AddDataMode::Overwrite => WriteMode::Overwrite,
+                },
+                ..Default::default()
+            });

         // patch the params if we have a write store wrapper
-        let params = match self.store_wrapper.clone() {
-            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
-            None => params,
+        let lance_params = match self.store_wrapper.clone() {
+            Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?,
+            None => lance_params,
         };

-        self.reset_dataset(Dataset::write(batches, &self.uri, params).await?);
+        self.reset_dataset(Dataset::write(batches, &self.uri, Some(lance_params)).await?);
         Ok(())
     }

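The rewritten body pins down the precedence rule: explicit `lance_write_params` are used verbatim and `mode` is ignored; otherwise `mode` is translated to a `WriteMode`. The overwrite test later in this diff exercises both paths; in miniature:

```rust
use lance::dataset::{WriteMode, WriteParams};
use vectordb::table::{AddDataMode, AddDataOptions, WriteOptions};

// `mode` alone decides when no lance params are given...
let by_mode = AddDataOptions {
    mode: AddDataMode::Overwrite,
    ..Default::default()
};

// ...but explicit WriteParams win, even against mode: Append.
let by_params = AddDataOptions {
    mode: AddDataMode::Append,
    write_options: WriteOptions {
        lance_write_params: Some(WriteParams {
            mode: WriteMode::Overwrite,
            ..Default::default()
        }),
    },
};
// Both of these requests overwrite the table when passed to `add`.
```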
@@ -881,25 +954,6 @@ mod tests {
         assert_eq!(c.to_str().unwrap(), "s3://bucket/path/to/file/subfile");
     }

-    #[tokio::test]
-    async fn test_create_already_exists() {
-        let tmp_dir = tempdir().unwrap();
-        let uri = tmp_dir.path().to_str().unwrap();
-
-        let batches = make_test_batches();
-        let _ = batches.schema().clone();
-        NativeTable::create(uri, "test", batches, None, None)
-            .await
-            .unwrap();
-
-        let batches = make_test_batches();
-        let result = NativeTable::create(uri, "test", batches, None, None).await;
-        assert!(matches!(
-            result.unwrap_err(),
-            Error::TableAlreadyExists { .. }
-        ));
-    }
-
     #[tokio::test]
     async fn test_count_rows() {
         let tmp_dir = tempdir().unwrap();
@@ -940,7 +994,10 @@ mod tests {
             schema.clone(),
         );

-        table.add(Box::new(new_batches), None).await.unwrap();
+        table
+            .add(Box::new(new_batches), AddDataOptions::default())
+            .await
+            .unwrap();
         assert_eq!(table.count_rows(None).await.unwrap(), 20);
         assert_eq!(table.name, "test");
     }
@@ -1003,23 +1060,47 @@ mod tests {
             .unwrap();
         assert_eq!(table.count_rows(None).await.unwrap(), 10);

-        let new_batches = RecordBatchIterator::new(
-            vec![RecordBatch::try_new(
-                schema.clone(),
-                vec![Arc::new(Int32Array::from_iter_values(100..110))],
-            )
-            .unwrap()]
-            .into_iter()
-            .map(Ok),
-            schema.clone(),
-        );
+        let batches = vec![RecordBatch::try_new(
+            schema.clone(),
+            vec![Arc::new(Int32Array::from_iter_values(100..110))],
+        )
+        .unwrap()]
+        .into_iter()
+        .map(Ok);
+
+        let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
+
+        // Can overwrite using AddDataOptions::mode
+        table
+            .add(
+                Box::new(new_batches),
+                AddDataOptions {
+                    mode: AddDataMode::Overwrite,
+                    ..Default::default()
+                },
+            )
+            .await
+            .unwrap();
+        assert_eq!(table.count_rows(None).await.unwrap(), 10);
+        assert_eq!(table.name, "test");
+
+        // Can overwrite using underlying WriteParams (which
+        // take precedence over AddDataOptions::mode)

         let param: WriteParams = WriteParams {
             mode: WriteMode::Overwrite,
             ..Default::default()
         };

-        table.add(Box::new(new_batches), Some(param)).await.unwrap();
+        let opts = AddDataOptions {
+            write_options: WriteOptions {
+                lance_write_params: Some(param),
+            },
+            mode: AddDataMode::Append,
+        };
+
+        let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
+        table.add(Box::new(new_batches), opts).await.unwrap();
         assert_eq!(table.count_rows(None).await.unwrap(), 10);
         assert_eq!(table.name, "test");
     }
@@ -1329,7 +1410,7 @@ mod tests {
             ..Default::default()
         };
         assert!(!wrapper.called());
-        let _ = NativeTable::open_with_params(uri, "test", None, param)
+        let _ = NativeTable::open_with_params(uri, "test", None, Some(param))
             .await
             .unwrap();
         assert!(wrapper.called());

@@ -32,20 +32,17 @@ impl PatchStoreParam for Option<ObjectStoreParams> {
 }

 pub trait PatchWriteParam {
-    fn patch_with_store_wrapper(
-        self,
-        wrapper: Arc<dyn WrappingObjectStore>,
-    ) -> Result<Option<WriteParams>>;
+    fn patch_with_store_wrapper(self, wrapper: Arc<dyn WrappingObjectStore>)
+        -> Result<WriteParams>;
 }

-impl PatchWriteParam for Option<WriteParams> {
+impl PatchWriteParam for WriteParams {
     fn patch_with_store_wrapper(
-        self,
+        mut self,
         wrapper: Arc<dyn WrappingObjectStore>,
-    ) -> Result<Option<WriteParams>> {
-        let mut params = self.unwrap_or_default();
-        params.store_params = params.store_params.patch_with_store_wrapper(wrapper)?;
-        Ok(Some(params))
+    ) -> Result<WriteParams> {
+        self.store_params = self.store_params.patch_with_store_wrapper(wrapper)?;
+        Ok(self)
     }
 }

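Since `PatchWriteParam` now lives on `WriteParams` directly, the unwrap-then-patch dance moves to the caller, which is what the reworked `add` does above. A minimal sketch, written as if inside the `vectordb` crate:

```rust
use std::sync::Arc;
use lance::dataset::WriteParams;
use lance::io::WrappingObjectStore;

use crate::error::Result;
use crate::utils::PatchWriteParam;

// Resolve optional params first, then patch in the store wrapper if any.
fn resolve(
    params: Option<WriteParams>,
    wrapper: Option<Arc<dyn WrappingObjectStore>>,
) -> Result<WriteParams> {
    let params = params.unwrap_or_default();
    match wrapper {
        Some(w) => params.patch_with_store_wrapper(w),
        None => Ok(params),
    }
}
```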