Compare commits


2 Commits

Author | SHA1 | Message | Date
ayush chaurasia | 8debf26b81 | update | 2024-02-15 21:46:34 +05:30
ayush chaurasia | d2af9fd81d | update | 2024-02-15 21:40:16 +05:30
35 changed files with 519 additions and 1222 deletions

View File

@@ -33,8 +33,3 @@ rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"
[target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]
# Not all Windows systems have the C runtime installed, so this avoids library
# not found errors on systems that are missing it.
[target.x86_64-pc-windows-msvc]
rustflags = ["-Ctarget-feature=+crt-static"]

View File

@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.9.18", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.18" }
lance-linalg = { "version" = "=0.9.18" }
lance-testing = { "version" = "=0.9.18" }
lance = { "version" = "=0.9.16", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.16" }
lance-linalg = { "version" = "=0.9.16" }
lance-testing = { "version" = "=0.9.16" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"

View File

@@ -57,16 +57,6 @@ plugins:
- https://arrow.apache.org/docs/objects.inv
- https://pandas.pydata.org/docs/objects.inv
- mkdocs-jupyter
- ultralytics:
verbose: True
enabled: True
default_image: "assets/lancedb_and_lance.png" # Default image for all pages
add_image: True # Automatically add meta image
add_keywords: True # Add page keywords in the header tag
add_share_buttons: True # Add social share buttons
add_authors: False # Display page authors
add_desc: False
add_dates: False
markdown_extensions:
- admonition
@@ -109,9 +99,10 @@ nav:
- Configuring Storage: guides/storage.md
- 🧬 Managing embeddings:
- Overview: embeddings/index.md
- Embedding functions: embeddings/embedding_functions.md
- Available models: embeddings/default_embedding_functions.md
- User-defined embedding functions: embeddings/custom_embedding_function.md
- Explicit management: embeddings/embedding_explicit.md
- Implicit management: embeddings/embedding_functions.md
- Available Functions: embeddings/default_embedding_functions.md
- Custom Embedding Functions: embeddings/api.md
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
- 🔌 Integrations:
@@ -173,9 +164,10 @@ nav:
- Configuring Storage: guides/storage.md
- Managing Embeddings:
- Overview: embeddings/index.md
- Embedding functions: embeddings/embedding_functions.md
- Available models: embeddings/default_embedding_functions.md
- User-defined embedding functions: embeddings/custom_embedding_function.md
- Explicit management: embeddings/embedding_explicit.md
- Implicit management: embeddings/embedding_functions.md
- Available Functions: embeddings/default_embedding_functions.md
- Custom Embedding Functions: embeddings/api.md
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
- Integrations:

View File

@@ -2,5 +2,4 @@ mkdocs==1.5.3
mkdocs-jupyter==0.24.1
mkdocs-material==9.5.3
mkdocstrings[python]==0.20.0
pydantic
mkdocs-ultralytics-plugin==0.0.44
pydantic

View File

@@ -0,0 +1,141 @@
In this workflow, you define your own embedding function and pass it as a callable to LanceDB, invoking it in your code to generate the embeddings. Let's look at some examples.
### Hugging Face
!!! note
Currently, the Hugging Face method is only supported in the Python SDK.
=== "Python"
The most popular open source option is to use the [sentence-transformers](https://www.sbert.net/)
library, which can be installed via pip.
```bash
pip install sentence-transformers
```
The example below shows how to use the `paraphrase-albert-small-v2` model to generate embeddings
for a given document.
```python
from sentence_transformers import SentenceTransformer
name="paraphrase-albert-small-v2"
model = SentenceTransformer(name)
# used for both training and querying
def embed_func(batch):
return [model.encode(sentence) for sentence in batch]
```
### OpenAI
Another popular alternative is to use an external API like OpenAI's [embeddings API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
=== "Python"
```python
import openai
import os
# Configuring the environment variable OPENAI_API_KEY
if "OPENAI_API_KEY" not in os.environ:
# OR set the key here as a variable
openai.api_key = "sk-..."
# verify that the API key is working
assert len(openai.Model.list()["data"]) > 0
def embed_func(c):
rs = openai.Embedding.create(input=c, engine="text-embedding-ada-002")
return [record["embedding"] for record in rs["data"]]
```
=== "JavaScript"
```javascript
const lancedb = require("vectordb");
// You need to provide an OpenAI API key
const apiKey = "sk-..."
// The embedding function will create embeddings for the 'text' column
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
```
## Applying an embedding function to data
=== "Python"
Once you have an embedding function, you can apply it to raw data
to generate embeddings for each record.
If you have a pandas DataFrame with a `text` column that you want embedded,
you can use the `with_embeddings` function to generate embeddings and add them to
an existing table.
```python
import pandas as pd
from lancedb.embeddings import with_embeddings
df = pd.DataFrame(
[
{"text": "pepperoni"},
{"text": "pineapple"}
]
)
data = with_embeddings(embed_func, df)
# The output is used to create / append to a table
# db.create_table("my_table", data=data)
```
If your data is in a different column, you can specify the `column` kwarg to `with_embeddings`.
By default, LanceDB calls the function with batches of 1000 rows. This can be configured
using the `batch_size` parameter to `with_embeddings`.
LanceDB automatically wraps the function with retry and rate-limit logic to ensure the OpenAI
API call is reliable.
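For instance, a minimal sketch (the `body` column name here is just a placeholder) might look like:
```python
# embed a DataFrame whose text lives in a "body" column, sending 256 rows
# to embed_func per call instead of the default 1000
df = pd.DataFrame([{"body": "pepperoni"}, {"body": "pineapple"}])
data = with_embeddings(embed_func, df, column="body", batch_size=256)
```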
=== "JavaScript"
Once you have an embedding function, you can apply it to raw data
to generate embeddings for each record.
Simply pass the embedding function created above and LanceDB will use it to generate
embeddings for your data.
```javascript
const db = await lancedb.connect("data/sample-lancedb");
const data = [
{ text: "pepperoni"},
{ text: "pineapple"}
]
const table = await db.createTable("vectors", data, embedding)
```
## Querying using an embedding function
!!! warning
At query time, you **must** use the same embedding function you used to vectorize your data.
If you use a different embedding function, the embeddings will not reside in the same vector
space and the results will be nonsensical.
=== "Python"
```python
query = "What's the best pizza topping?"
query_vector = embed_func([query])[0]
results = (
tbl.search(query_vector)
.limit(10)
.to_pandas()
)
```
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
=== "JavaScript"
```javascript
const results = await table
.search("What's the best pizza topping?")
.limit(10)
.execute()
```
The above snippet returns an array of records with the top 10 nearest neighbors to the query.

View File

@@ -3,126 +3,61 @@ Representing multi-modal data as vector embeddings is becoming a standard practi
For this purpose, LanceDB introduces an **embedding functions API** that you set up once, during the configuration stage of your project. After that, the table remembers it, effectively making the embedding function *disappear into the background*, so you don't have to worry about manually passing callables around and can instead focus on the rest of your data engineering pipeline.
!!! warning
Using the embedding function registry means that you don't have to explicitly generate the embeddings yourself.
However, if your embedding function changes, you'll have to re-configure your table with the new embedding function
and regenerate the embeddings. In the future, we plan to support the ability to change the embedding function via
table metadata and have LanceDB automatically take care of regenerating the embeddings.
Using the implicit embeddings management approach means that you can forget about manually passing embedding
functions around in your code, as long as you don't intend to change the function at a later time. If your embedding function changes,
you'll have to re-configure your table with the new embedding function and regenerate the embeddings.
## 1. Define the embedding function
We have some pre-defined embedding functions in the global registry, with more coming soon. Here, let's use an implementation of CLIP as an example.
```
registry = EmbeddingFunctionRegistry.get_instance()
clip = registry.get("open-clip").create()
=== "Python"
In the LanceDB python SDK, we define a global embedding function registry with
many different embedding models and even more coming soon.
Here, let's use an implementation of CLIP as an example.
```python
from lancedb.embeddings import get_registry
registry = get_registry()
clip = registry.get("open-clip").create()
```
You can also define your own embedding function by implementing the `EmbeddingFunction`
abstract base interface. It subclasses Pydantic Model, which can be used to write complex schemas simply, as we'll see next!
=== "JavaScript""
In the TypeScript SDK, the choices are more limited. For now, only the OpenAI
embedding function is available.
```javascript
const lancedb = require("vectordb");
// You need to provide an OpenAI API key
const apiKey = "sk-..."
// The embedding function will create embeddings for the 'text' column
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
```
```
You can also define your own embedding function by implementing the `EmbeddingFunction` abstract base interface. It subclasses Pydantic Model, which can be used to write complex schemas simply, as we'll see next!
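As a rough sketch of what that looks like, mirroring the built-in implementations elsewhere in this repository (the import paths, registry alias, and method names here should be treated as assumptions, not the definitive interface; see the custom API page for the authoritative details):
```python
# a hedged sketch of a user-defined embedding function; not the official interface
from lancedb.embeddings import EmbeddingFunction, register


@register("my-sentence-transformers")  # hypothetical registry alias
class MySentenceTransformers(EmbeddingFunction):
    name: str = "paraphrase-albert-small-v2"

    def ndims(self):
        # dimensionality of the vectors produced by this ALBERT variant
        return 768

    def compute_query_embeddings(self, query, *args, **kwargs):
        # a query is embedded the same way as source data
        return self.compute_source_embeddings([query])

    def compute_source_embeddings(self, texts, *args, **kwargs):
        from sentence_transformers import SentenceTransformer

        model = SentenceTransformer(self.name)
        return [model.encode(t) for t in texts]
```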
## 2. Define the data model or schema
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how:
=== "Python"
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how:
```python
class Pets(LanceModel):
vector: Vector(clip.ndims) = clip.VectorField()
image_uri: str = clip.SourceField()
```
```python
class Pets(LanceModel):
vector: Vector(clip.ndims) = clip.VectorField()
image_uri: str = clip.SourceField()
```
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
## 3. Create LanceDB table
Now that we have chosen/defined our embedding function and the schema, we can create the table:
=== "JavaScript"
```python
db = lancedb.connect("~/lancedb")
table = db.create_table("pets", schema=Pets)
For the TypeScript SDK, a schema can be inferred from input data, or an explicit
Arrow schema can be provided.
```
## 3. Create table and add data
That's it! We've provided all the information needed to embed the source and query inputs. We can now forget about the model and dimension details and start to build our VectorDB pipeline.
Now that we have chosen/defined our embedding function and the schema,
we can create the table and ingest data without needing to explicitly generate
the embeddings at all:
## 4. Ingest lots of data and query your table
Any new or incoming data can just be added and it'll be vectorized automatically.
=== "Python"
```python
db = lancedb.connect("~/lancedb")
table = db.create_table("pets", schema=Pets)
```python
table.add([{"image_uri": u} for u in uris])
```
table.add([{"image_uri": u} for u in uris])
```
Our OpenCLIP query embedding function supports querying via both text and images:
=== "JavaScript"
```python
result = table.search("dog")
```
```javascript
const db = await lancedb.connect("data/sample-lancedb");
const data = [
{ text: "pepperoni"},
{ text: "pineapple"}
]
Let's query an image:
const table = await db.createTable("vectors", data, embedding)
```
## 4. Querying your table
Not only can you forget about the embeddings during ingestion, you also don't
need to worry about it when you query the table:
=== "Python"
Our OpenCLIP query embedding function supports querying via both text and images:
```python
results = (
table.search("dog")
.limit(10)
.to_pandas()
)
```
Or we can search using an image:
```python
p = Path("path/to/images/samoyed_100.jpg")
query_image = Image.open(p)
results = (
table.search(query_image)
.limit(10)
.to_pandas()
)
```
Both of the above snippets return a pandas DataFrame with the 10 closest vectors to the query.
=== "JavaScript"
```javascript
const results = await table
.search("What's the best pizza topping?")
.limit(10)
.execute()
```
The above snippet returns an array of records with the top 10 nearest neighbors to the query.
```python
p = Path("path/to/images/samoyed_100.jpg")
query_image = Image.open(p)
table.search(query_image)
```
---
@@ -165,5 +100,4 @@ rs[2].image
![](../assets/dog_clip_output.png)
Now that you have the basic idea about LanceDB embedding functions and the embedding function registry,
let's dive deeper into defining your own [custom functions](./custom_embedding_function.md).
Now that you have the basic idea about implicit management via embedding functions, let's dive deeper into a [custom API](./api.md) that you can use to implement your own embedding functions.

View File

@@ -1,14 +1,8 @@
Due to the nature of vector embeddings, they can be used to represent any kind of data, from text to images to audio.
This makes them a very powerful tool for machine learning practitioners.
However, there's no one-size-fits-all solution for generating embeddings - there are many different libraries and APIs
(both commercial and open source) that can be used to generate embeddings from structured/unstructured data.
Due to the nature of vector embeddings, they can be used to represent any kind of data, from text to images to audio. This makes them a very powerful tool for machine learning practitioners. However, there's no one-size-fits-all solution for generating embeddings - there are many different libraries and APIs (both commercial and open source) that can be used to generate embeddings from structured/unstructured data.
LanceDB supports 3 methods of working with embeddings.
LanceDB supports 2 methods of vectorizing your raw data into embeddings.
1. You can manually generate embeddings for the data and queries. This is done outside of LanceDB.
2. You can use the built-in [embedding functions](./embedding_functions.md) to embed the data and queries in the background.
3. For python users, you can define your own [custom embedding function](./custom_embedding_function.md)
that extends the default embedding functions.
1. **Explicit**: By manually calling LanceDB's `with_embeddings` function to vectorize your data via an `embed_func` of your choice
2. **Implicit**: Allow LanceDB to embed the data and queries in the background as they come in, by using the table's `EmbeddingRegistry` information
For python users, there is also a legacy [with_embeddings API](./legacy.md).
It is retained for compatibility and will be removed in a future version.
See the [explicit](embedding_explicit.md) and [implicit](embedding_functions.md) embedding sections for more details.
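Side by side, the two approaches look roughly like this (a hedged sketch: `embed_func`, `df`, the registry alias, and the table names are placeholders):
```python
import lancedb
from lancedb.embeddings import with_embeddings, get_registry
from lancedb.pydantic import LanceModel, Vector

db = lancedb.connect("~/lancedb")

# explicit: call with_embeddings yourself with an embed_func of your choice
data = with_embeddings(embed_func, df)  # df is a pandas DataFrame with a "text" column
explicit_tbl = db.create_table("explicit_table", data=data)

# implicit: the embedding function registered on the schema runs in the background
func = get_registry().get("sentence-transformers").create()

class Words(LanceModel):
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()

implicit_tbl = db.create_table("implicit_table", schema=Words)
implicit_tbl.add([{"text": "hello world"}])  # vectorized automatically
```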

View File

@@ -1,99 +0,0 @@
The legacy `with_embeddings` API is for Python only and is deprecated.
### Hugging Face
The most popular open source option is to use the [sentence-transformers](https://www.sbert.net/)
library, which can be installed via pip.
```bash
pip install sentence-transformers
```
The example below shows how to use the `paraphrase-albert-small-v2` model to generate embeddings
for a given document.
```python
from sentence_transformers import SentenceTransformer
name="paraphrase-albert-small-v2"
model = SentenceTransformer(name)
# used for both training and querying
def embed_func(batch):
return [model.encode(sentence) for sentence in batch]
```
### OpenAI
Another popular alternative is to use an external API like OpenAI's [embeddings API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
```python
import openai
import os
# Configuring the environment variable OPENAI_API_KEY
if "OPENAI_API_KEY" not in os.environ:
# OR set the key here as a variable
openai.api_key = "sk-..."
client = openai.OpenAI()
def embed_func(c):
rs = client.embeddings.create(input=c, model="text-embedding-ada-002")
return [record.embedding for record in rs.data]
```
## Applying an embedding function to data
Once you have an embedding function, you can apply it to raw data
to generate embeddings for each record.
If you have a pandas DataFrame with a `text` column that you want embedded,
you can use the `with_embeddings` function to generate embeddings and add them to
an existing table.
```python
import pandas as pd
from lancedb.embeddings import with_embeddings
df = pd.DataFrame(
[
{"text": "pepperoni"},
{"text": "pineapple"}
]
)
data = with_embeddings(embed_func, df)
# The output is used to create / append to a table
tbl = db.create_table("my_table", data=data)
```
If your data is in a different column, you can specify the `column` kwarg to `with_embeddings`.
By default, LanceDB calls the function with batches of 1000 rows. This can be configured
using the `batch_size` parameter to `with_embeddings`.
LanceDB automatically wraps the function with retry and rate-limit logic to ensure the OpenAI
API call is reliable.
## Querying using an embedding function
!!! warning
At query time, you **must** use the same embedding function you used to vectorize your data.
If you use a different embedding function, the embeddings will not reside in the same vector
space and the results will be nonsensical.
=== "Python"
```python
query = "What's the best pizza topping?"
query_vector = embed_func([query])[0]
results = (
tbl.search(query_vector)
.limit(10)
.to_pandas()
)
```
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.

View File

@@ -1,5 +1,6 @@
import pickle
import re
import sys
import zipfile
from pathlib import Path

View File

@@ -290,7 +290,7 @@
"from lancedb.pydantic import LanceModel, Vector\n",
"\n",
"class Pets(LanceModel):\n",
" vector: Vector(clip.ndims()) = clip.VectorField()\n",
" vector: Vector(clip.ndims) = clip.VectorField()\n",
" image_uri: str = clip.SourceField()\n",
"\n",
" @property\n",
@@ -360,7 +360,7 @@
" table = db.create_table(\"pets\", schema=Pets)\n",
" # use a sampling of 1000 images\n",
" p = Path(\"~/Downloads/images\").expanduser()\n",
" uris = [str(f) for f in p.glob(\"*.jpg\")]\n",
" uris = [str(f) for f in p.iterdir()]\n",
" uris = sample(uris, 1000)\n",
" table.add(pd.DataFrame({\"image_uri\": uris}))"
]
@@ -543,7 +543,7 @@
],
"source": [
"from PIL import Image\n",
"p = Path(\"~/Downloads/images/samoyed_100.jpg\").expanduser()\n",
"p = Path(\"/Users/changshe/Downloads/images/samoyed_100.jpg\")\n",
"query_image = Image.open(p)\n",
"query_image"
]

View File

@@ -23,8 +23,10 @@ from multiprocessing import Pool
import lance
import pyarrow as pa
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
import lancedb
MODEL_ID = "openai/clip-vit-base-patch32"

View File

@@ -1,9 +1,6 @@
# DuckDB
In Python, LanceDB tables can also be queried with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This means you can write complex SQL queries to analyze your data in LanceDB.
This integration is done via [Apache Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow), which provides zero-copy data sharing between LanceDB and DuckDB. DuckDB is capable of passing down column selections and basic filters to LanceDB, reducing the amount of data that needs to be scanned to perform your query. Finally, the integration allows streaming data from LanceDB tables, allowing you to aggregate tables that won't fit into memory. All of this uses the same mechanism described in DuckDB's blog post *[DuckDB quacks Arrow](https://duckdb.org/2021/12/03/duck-arrow.html)*.
LanceDB is very well-integrated with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This integration is done via [Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow).
We can demonstrate this by first installing `duckdb` and `lancedb`.
@@ -22,15 +19,14 @@ data = [
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
]
table = db.create_table("pd_table", data=data)
arrow_table = table.to_arrow()
```
To query the table, first call `to_lance` to convert the table to a "dataset", which is an object that can be queried by DuckDB. Then all you need to do is reference that dataset by the same name in your SQL query.
DuckDB can directly query the `pyarrow.Table` object:
```python
import duckdb
arrow_table = table.to_lance()
duckdb.query("SELECT * FROM arrow_table")
```
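From here you can run ordinary SQL with filters and aggregations against the table. For example (a sketch reusing the `arrow_table` defined above):
```python
# filter and aggregate directly over the Arrow data exposed by LanceDB
duckdb.query("SELECT item, price FROM arrow_table WHERE price > 15").to_df()
duckdb.query("SELECT AVG(price) AS avg_price FROM arrow_table").to_df()
```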

View File

@@ -12,16 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use napi::bindgen_prelude::*;
use napi_derive::*;
use crate::table::Table;
use vectordb::connection::Connection as LanceDBConnection;
use vectordb::connection::{Connection as LanceDBConnection, Database};
use vectordb::ipc::ipc_file_to_batches;
#[napi]
pub struct Connection {
conn: LanceDBConnection,
conn: Arc<dyn LanceDBConnection>,
}
#[napi]
@@ -30,9 +32,9 @@ impl Connection {
#[napi(factory)]
pub async fn new(uri: String) -> napi::Result<Self> {
Ok(Self {
conn: vectordb::connect(&uri).execute().await.map_err(|e| {
conn: Arc::new(Database::connect(&uri).await.map_err(|e| {
napi::Error::from_reason(format!("Failed to connect to database: {}", e))
})?,
})?),
})
}
@@ -57,8 +59,7 @@ impl Connection {
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let tbl = self
.conn
.create_table(&name, Box::new(batches))
.execute()
.create_table(&name, Box::new(batches), None)
.await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
Ok(Table::new(tbl))
@@ -69,7 +70,6 @@ impl Connection {
let tbl = self
.conn
.open_table(&name)
.execute()
.await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
Ok(Table::new(tbl))

View File

@@ -15,7 +15,6 @@
use arrow_ipc::writer::FileWriter;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use vectordb::table::AddDataOptions;
use vectordb::{ipc::ipc_file_to_batches, table::TableRef};
use crate::index::IndexBuilder;
@@ -49,15 +48,12 @@ impl Table {
pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
self.table
.add(Box::new(batches), AddDataOptions::default())
.await
.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to add batches to table {}: {}",
self.table, e
))
})
self.table.add(Box::new(batches), None).await.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to add batches to table {}: {}",
self.table, e
))
})
}
#[napi]

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.5.7
current_version = 0.5.5
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -13,9 +13,8 @@
import importlib.metadata
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Optional, Union
from typing import Optional
__version__ = importlib.metadata.version("lancedb")
@@ -33,7 +32,6 @@ def connect(
region: str = "us-east-1",
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -60,14 +58,7 @@ def connect(
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
request_thread_pool: int or ThreadPoolExecutor, optional
The thread pool to use for making batch requests to the LanceDB Cloud API.
If an integer, then a ThreadPoolExecutor will be created with that
number of threads. If None, then a ThreadPoolExecutor will be created
with the default number of threads. If a ThreadPoolExecutor, then that
executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once).
Examples
--------
@@ -95,9 +86,5 @@ def connect(
api_key = os.environ.get("LANCEDB_API_KEY")
if api_key is None:
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
if isinstance(request_thread_pool, int):
request_thread_pool = ThreadPoolExecutor(request_thread_pool)
return RemoteDBConnection(
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
)
return RemoteDBConnection(uri, api_key, region, host_override)
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
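For illustration, a connection with a read consistency interval might look like this (a minimal sketch; the path is a placeholder):
```python
import lancedb
from datetime import timedelta

# re-check tables for updates at most every 5 seconds on reads
db = lancedb.connect("~/lancedb", read_consistency_interval=timedelta(seconds=5))
```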

View File

@@ -1,172 +0,0 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import List, Union
import numpy as np
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import AUDIO, IMAGES, TEXT
@register("imagebind")
class ImageBindEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the ImageBind API
for generating multi-modal embeddings across
six different modalities: images, text, audio, depth, thermal, and IMU data.
To download the package, run:
`pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
"""
name: str = "imagebind_huge"
device: str = "cpu"
normalize: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._ndims = 1024
self._audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".aac")
self._image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")
@cached_property
def embedding_model(self):
"""
Get the embedding model. This is cached so that the model is only loaded
once per process.
"""
return self.get_embedding_model()
@cached_property
def _data(self):
"""
Get the data module from imagebind
"""
data = attempt_import_or_raise("imagebind.data", "imagebind")
return data
@cached_property
def _ModalityType(self):
"""
Get the ModalityType from imagebind
"""
imagebind = attempt_import_or_raise("imagebind", "imagebind")
return imagebind.imagebind_model.ModalityType
def ndims(self):
return self._ndims
def compute_query_embeddings(
self, query: Union[str], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str]
The query to embed. A query can be either text, image paths or audio paths.
"""
query = self.sanitize_input(query)
if query[0].endswith(self._audio_extensions):
return [self.generate_audio_embeddings(query)]
elif query[0].endswith(self._image_extensions):
return [self.generate_image_embeddings(query)]
else:
return [self.generate_text_embeddings(query)]
def generate_image_embeddings(self, image: IMAGES) -> np.ndarray:
torch = attempt_import_or_raise("torch")
inputs = {
self._ModalityType.VISION: self._data.load_and_transform_vision_data(
image, self.device
)
}
with torch.no_grad():
image_features = self.embedding_model(inputs)[self._ModalityType.VISION]
if self.normalize:
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().squeeze()
def generate_audio_embeddings(self, audio: AUDIO) -> np.ndarray:
torch = attempt_import_or_raise("torch")
inputs = {
self._ModalityType.AUDIO: self._data.load_and_transform_audio_data(
audio, self.device
)
}
with torch.no_grad():
audio_features = self.embedding_model(inputs)[self._ModalityType.AUDIO]
if self.normalize:
audio_features /= audio_features.norm(dim=-1, keepdim=True)
return audio_features.cpu().numpy().squeeze()
def generate_text_embeddings(self, text: TEXT) -> np.ndarray:
torch = attempt_import_or_raise("torch")
inputs = {
self._ModalityType.TEXT: self._data.load_and_transform_text(
text, self.device
)
}
with torch.no_grad():
text_features = self.embedding_model(inputs)[self._ModalityType.TEXT]
if self.normalize:
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features.cpu().numpy().squeeze()
def compute_source_embeddings(
self, source: Union[IMAGES, AUDIO], *args, **kwargs
) -> List[np.array]:
"""
Get the embeddings for the given source field column in the Pydantic model.
"""
source = self.sanitize_input(source)
embeddings = []
if source[0].endswith(self._audio_extensions):
embeddings.extend(self.generate_audio_embeddings(source))
return embeddings
elif source[0].endswith(self._image_extensions):
embeddings.extend(self.generate_image_embeddings(source))
return embeddings
else:
embeddings.extend(self.generate_text_embeddings(source))
return embeddings
def sanitize_input(
self, input: Union[IMAGES, AUDIO]
) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(input, (str, bytes)):
input = [input]
elif isinstance(input, pa.Array):
input = input.to_pylist()
elif isinstance(input, pa.ChunkedArray):
input = input.combine_chunks().to_pylist()
return input
def get_embedding_model(self):
"""
fetches the imagebind embedding model
"""
imagebind = attempt_import_or_raise("imagebind", "imagebind")
model = imagebind.imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(self.device)
return model

View File

@@ -26,7 +26,7 @@ import pyarrow as pa
from lance.vector import vec_to_table
from retry import retry
from ..util import deprecated, safe_import_pandas
from ..util import safe_import_pandas
from ..utils.general import LOGGER
pd = safe_import_pandas()
@@ -36,10 +36,8 @@ TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
IMAGES = Union[
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
]
AUDIO = Union[str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray]
@deprecated
def with_embeddings(
func: Callable,
data: DATA,

View File

@@ -14,7 +14,6 @@
import inspect
import logging
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union
from urllib.parse import urlparse
@@ -40,7 +39,6 @@ class RemoteDBConnection(DBConnection):
api_key: str,
region: str,
host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None,
):
"""Connect to a remote LanceDB database."""
parsed = urlparse(db_url)
@@ -51,7 +49,6 @@ class RemoteDBConnection(DBConnection):
self._client = RestfulLanceDBClient(
self.db_name, region, api_key, host_override
)
self._request_thread_pool = request_thread_pool
def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})"

View File

@@ -13,7 +13,6 @@
import logging
import uuid
from concurrent.futures import Future
from functools import cached_property
from typing import Dict, Optional, Union
@@ -271,28 +270,15 @@ class RemoteTable(Table):
and len(query.vector) > 0
and not isinstance(query.vector[0], float)
):
if self._conn._request_thread_pool is None:
def submit(name, q):
f = Future()
f.set_result(self._conn._client.query(name, q))
return f
else:
def submit(name, q):
return self._conn._request_thread_pool.submit(
self._conn._client.query, name, q
)
results = []
for v in query.vector:
v = list(v)
q = query.copy()
q.vector = v
results.append(submit(self._name, q))
results.append(self._conn._client.query(self._name, q))
return pa.concat_tables(
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
)
else:
result = self._conn._client.query(self._name, query)

View File

@@ -11,11 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import importlib
import os
import pathlib
import warnings
from datetime import date, datetime
from functools import singledispatch
from typing import Tuple, Union
@@ -241,25 +239,3 @@ def _(value: list):
@value_to_sql.register(np.ndarray)
def _(value: np.ndarray):
return value_to_sql(value.tolist())
def deprecated(func):
"""This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
when the function is used."""
@functools.wraps(func)
def new_func(*args, **kwargs):
warnings.simplefilter("always", DeprecationWarning) # turn off filter
warnings.warn(
(
f"Function {func.__name__} is deprecated and will be "
"removed in a future version"
),
category=DeprecationWarning,
stacklevel=2,
)
warnings.simplefilter("default", DeprecationWarning) # reset filter
return func(*args, **kwargs)
return new_func

View File

@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.5.7"
version = "0.5.5"
dependencies = [
"deprecation",
"pylance==0.9.18",
"pylance==0.9.16",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",
@@ -50,7 +50,7 @@ repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "mkdocs-ultralytics-plugin==0.0.44"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]

View File

@@ -28,23 +28,6 @@ from lancedb.pydantic import LanceModel, Vector
# or connection to external api
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
else:
_mlx = None
except Exception:
_mlx = None
try:
if importlib.util.find_spec("imagebind") is not None:
_imagebind = True
else:
_imagebind = None
except Exception:
_imagebind = None
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
def test_basic_text_embeddings(alias, tmp_path):
@@ -175,89 +158,6 @@ def test_openclip(tmp_path):
)
@pytest.mark.skipif(
_imagebind is None,
reason="skip if imagebind not installed.",
)
@pytest.mark.slow
def test_imagebind(tmp_path):
import os
import shutil
import tempfile
import pandas as pd
import requests
import lancedb.embeddings.imagebind
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
with tempfile.TemporaryDirectory() as temp_dir:
print(f"Created temporary directory {temp_dir}")
def download_images(image_uris):
downloaded_image_paths = []
for uri in image_uris:
try:
response = requests.get(uri, stream=True)
if response.status_code == 200:
# Extract image name from URI
image_name = os.path.basename(uri)
image_path = os.path.join(temp_dir, image_name)
with open(image_path, "wb") as out_file:
shutil.copyfileobj(response.raw, out_file)
downloaded_image_paths.append(image_path)
except Exception as e: # noqa: PERF203
print(f"Failed to download {uri}. Error: {e}")
return temp_dir, downloaded_image_paths
db = lancedb.connect(tmp_path)
registry = get_registry()
func = registry.get("imagebind").create(max_retries=0)
class Images(LanceModel):
label: str
image_uri: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
temp_dir, downloaded_images = download_images(uris)
table.add(pd.DataFrame({"label": labels, "image_uri": downloaded_images}))
# text search
actual = (
table.search("man's best friend", vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
# image search
query_image_uri = [
"https://live.staticflickr.com/65535/33336453970_491665f66e_h.jpg"
]
temp_dir, downloaded_images = download_images(query_image_uri)
query_image_uri = downloaded_images[0]
actual = (
table.search(query_image_uri, vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog"
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)
print(f"Deleted temporary directory {temp_dir}")
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
@@ -317,6 +217,13 @@ def test_gemini_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
@pytest.mark.skipif(
_mlx is None,
reason="mlx tests only required for apple users.",

View File

@@ -803,8 +803,10 @@ def test_count_rows(db):
assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db, tmp_path):
db = MockDB(str(tmp_path))
def test_hybrid_search(db):
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
# Probably not being parsed right by the fts
db = MockDB("~/lancedb_")
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()

View File

@@ -22,9 +22,9 @@ use object_store::CredentialProvider;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;
use vectordb::connect;
use vectordb::connection::Connection;
use vectordb::connection::Database;
use vectordb::table::ReadParams;
use vectordb::{ConnectOptions, Connection};
use crate::error::ResultExt;
use crate::query::JsQuery;
@@ -39,7 +39,7 @@ mod query;
mod table;
struct JsDatabase {
database: Connection,
database: Arc<dyn Connection + 'static>,
}
impl Finalize for JsDatabase {}
@@ -89,23 +89,23 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let mut conn_builder = connect(&path);
let mut conn_options = ConnectOptions::new(&path);
if let Some(region) = region {
conn_builder = conn_builder.region(&region);
conn_options = conn_options.region(&region);
}
if let Some(aws_creds) = aws_creds {
conn_builder = conn_builder.aws_creds(AwsCredential {
conn_options = conn_options.aws_creds(AwsCredential {
key_id: aws_creds.key_id,
secret_key: aws_creds.secret_key,
token: aws_creds.token,
});
}
rt.spawn(async move {
let database = conn_builder.execute().await;
let database = Database::connect_with_options(&conn_options).await;
deferred.settle_with(&channel, move |mut cx| {
let db = JsDatabase {
database: database.or_throw(&mut cx)?,
database: Arc::new(database.or_throw(&mut cx)?),
};
Ok(cx.boxed(db))
});
@@ -217,11 +217,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let (deferred, promise) = cx.promise();
rt.spawn(async move {
let table_rst = database
.open_table(&table_name)
.lance_read_params(params)
.execute()
.await;
let table_rst = database.open_table_with_params(&table_name, params).await;
deferred.settle_with(&channel, move |mut cx| {
let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);

View File

@@ -18,7 +18,7 @@ use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::optimize::CompactionOptions;
use lance::dataset::{WriteMode, WriteParams};
use lance::io::ObjectStoreParams;
use vectordb::table::{AddDataOptions, OptimizeAction, WriteOptions};
use vectordb::table::OptimizeAction;
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
use neon::prelude::*;
@@ -80,11 +80,7 @@ impl JsTable {
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let table_rst = database
.create_table(&table_name, Box::new(batch_reader))
.write_options(WriteOptions {
lance_write_params: Some(params),
})
.execute()
.create_table(&table_name, Box::new(batch_reader), Some(params))
.await;
deferred.settle_with(&channel, move |mut cx| {
@@ -125,13 +121,7 @@ impl JsTable {
rt.spawn(async move {
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let opts = AddDataOptions {
write_options: WriteOptions {
lance_write_params: Some(params),
},
..Default::default()
};
let add_result = table.add(Box::new(batch_reader), opts).await;
let add_result = table.add(Box::new(batch_reader), Some(params)).await;
deferred.settle_with(&channel, move |mut cx| {
add_result.or_throw(&mut cx)?;

View File

@@ -19,8 +19,7 @@ use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterat
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use vectordb::connection::Connection;
use vectordb::table::AddDataOptions;
use vectordb::Connection;
use vectordb::{connect, Result, Table, TableRef};
#[tokio::main]
@@ -30,18 +29,18 @@ async fn main() -> Result<()> {
}
// --8<-- [start:connect]
let uri = "data/sample-lancedb";
let db = connect(uri).execute().await?;
let db = connect(uri).await?;
// --8<-- [end:connect]
// --8<-- [start:list_names]
println!("{:?}", db.table_names().await?);
// --8<-- [end:list_names]
let tbl = create_table(&db).await?;
let tbl = create_table(db.clone()).await?;
create_index(tbl.as_ref()).await?;
let batches = search(tbl.as_ref()).await?;
println!("{:?}", batches);
create_empty_table(&db).await.unwrap();
create_empty_table(db.clone()).await.unwrap();
// --8<-- [start:delete]
tbl.delete("id > 24").await.unwrap();
@@ -56,14 +55,17 @@ async fn main() -> Result<()> {
#[allow(dead_code)]
async fn open_with_existing_tbl() -> Result<()> {
let uri = "data/sample-lancedb";
let db = connect(uri).execute().await?;
let db = connect(uri).await?;
// --8<-- [start:open_with_existing_file]
let _ = db.open_table("my_table").execute().await.unwrap();
let _ = db
.open_table_with_params("my_table", Default::default())
.await
.unwrap();
// --8<-- [end:open_with_existing_file]
Ok(())
}
async fn create_table(db: &Connection) -> Result<TableRef> {
async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
// --8<-- [start:create_table]
const TOTAL: usize = 1000;
const DIM: usize = 128;
@@ -100,8 +102,7 @@ async fn create_table(db: &Connection) -> Result<TableRef> {
schema.clone(),
);
let tbl = db
.create_table("my_table", Box::new(batches))
.execute()
.create_table("my_table", Box::new(batches), None)
.await
.unwrap();
// --8<-- [end:create_table]
@@ -125,21 +126,21 @@ async fn create_table(db: &Connection) -> Result<TableRef> {
schema.clone(),
);
// --8<-- [start:add]
tbl.add(Box::new(new_batches), AddDataOptions::default())
.await
.unwrap();
tbl.add(Box::new(new_batches), None).await.unwrap();
// --8<-- [end:add]
Ok(tbl)
}
async fn create_empty_table(db: &Connection) -> Result<TableRef> {
async fn create_empty_table(db: Arc<dyn Connection>) -> Result<TableRef> {
// --8<-- [start:create_empty_table]
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("item", DataType::Utf8, true),
]));
db.create_empty_table("empty_table", schema).execute().await
let batches = RecordBatchIterator::new(vec![], schema.clone());
db.create_table("empty_table", Box::new(batches), None)
.await
// --8<-- [end:create_empty_table]
}

View File

@@ -13,14 +13,14 @@
// limitations under the License.
//! LanceDB Database
//!
use std::fs::create_dir_all;
use std::path::Path;
use std::sync::Arc;
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::SchemaRef;
use lance::dataset::{ReadParams, WriteMode};
use arrow_array::RecordBatchReader;
use lance::dataset::WriteParams;
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
use object_store::{
aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
@@ -29,283 +29,73 @@ use snafu::prelude::*;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, TableRef, WriteOptions};
use crate::table::{NativeTable, ReadParams, TableRef};
pub const LANCE_FILE_EXTENSION: &str = "lance";
pub type TableBuilderCallback = Box<dyn FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send>;
/// Describes what happens when creating a table and a table with
/// the same name already exists
pub enum CreateTableMode {
/// If the table already exists, an error is returned
Create,
/// If the table already exists, it is opened. Any provided data is
/// ignored. The function will be passed an OpenTableBuilder to customize
/// how the table is opened
ExistOk(TableBuilderCallback),
/// If the table already exists, it is overwritten
Overwrite,
}
impl CreateTableMode {
pub fn exist_ok(
callback: impl FnOnce(OpenTableBuilder) -> OpenTableBuilder + Send + 'static,
) -> Self {
Self::ExistOk(Box::new(callback))
}
}
impl Default for CreateTableMode {
fn default() -> Self {
Self::Create
}
}
/// Describes what happens when a vector either contains NaN or
/// does not have enough values
#[derive(Clone, Debug, Default)]
enum BadVectorHandling {
/// An error is returned
#[default]
Error,
#[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
/// The offending row is dropped
Drop,
#[allow(dead_code)] // https://github.com/lancedb/lancedb/issues/992
/// The invalid/missing items are replaced by fill_value
Fill(f32),
}
/// A builder for configuring a [`Connection::create_table`] operation
pub struct CreateTableBuilder<const HAS_DATA: bool> {
parent: Arc<dyn ConnectionInternal>,
name: String,
data: Option<Box<dyn RecordBatchReader + Send>>,
schema: Option<SchemaRef>,
mode: CreateTableMode,
write_options: WriteOptions,
}
// Builder methods that only apply when we have initial data
impl CreateTableBuilder<true> {
fn new(
parent: Arc<dyn ConnectionInternal>,
name: String,
data: Box<dyn RecordBatchReader + Send>,
) -> Self {
Self {
parent,
name,
data: Some(data),
schema: None,
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
}
}
/// Apply the given write options when writing the initial data
pub fn write_options(mut self, write_options: WriteOptions) -> Self {
self.write_options = write_options;
self
}
/// Execute the create table operation
pub async fn execute(self) -> Result<TableRef> {
self.parent.clone().do_create_table(self).await
}
}
// Builder methods that only apply when we do not have initial data
impl CreateTableBuilder<false> {
fn new(parent: Arc<dyn ConnectionInternal>, name: String, schema: SchemaRef) -> Self {
Self {
parent,
name,
data: None,
schema: Some(schema),
mode: CreateTableMode::default(),
write_options: WriteOptions::default(),
}
}
/// Execute the create table operation
pub async fn execute(self) -> Result<TableRef> {
self.parent.clone().do_create_empty_table(self).await
}
}
impl<const HAS_DATA: bool> CreateTableBuilder<HAS_DATA> {
/// Set the mode for creating the table
///
/// This controls what happens if a table with the given name already exists
pub fn mode(mut self, mode: CreateTableMode) -> Self {
self.mode = mode;
self
}
}
#[derive(Clone, Debug)]
pub struct OpenTableBuilder {
parent: Arc<dyn ConnectionInternal>,
name: String,
index_cache_size: u32,
lance_read_params: Option<ReadParams>,
}
impl OpenTableBuilder {
fn new(parent: Arc<dyn ConnectionInternal>, name: String) -> Self {
Self {
parent,
name,
index_cache_size: 256,
lance_read_params: None,
}
}
/// Set the size of the index cache, specified as a number of entries
///
/// The default value is 256
///
/// The exact meaning of an "entry" will depend on the type of index:
/// * IVF - there is one entry for each IVF partition
/// * BTREE - there is one entry for the entire index
///
/// This cache applies to the entire opened table, across all indices.
/// Setting this value higher will increase performance on larger datasets
/// at the expense of more RAM
pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
self.index_cache_size = index_cache_size;
self
}
/// Advanced parameters that can be used to customize table reads
///
/// If set, these will take precedence over any overlapping `OpenTableOptions` options
pub fn lance_read_params(mut self, params: ReadParams) -> Self {
self.lance_read_params = Some(params);
self
}
/// Open the table
pub async fn execute(self) -> Result<TableRef> {
self.parent.clone().do_open_table(self).await
}
}
#[async_trait::async_trait]
trait ConnectionInternal: Send + Sync + std::fmt::Debug + 'static {
async fn table_names(&self) -> Result<Vec<String>>;
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef>;
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef>;
async fn drop_table(&self, name: &str) -> Result<()>;
async fn drop_db(&self) -> Result<()>;
async fn do_create_empty_table(&self, options: CreateTableBuilder<false>) -> Result<TableRef> {
let batches = RecordBatchIterator::new(vec![], options.schema.unwrap());
let opts = CreateTableBuilder::<true>::new(options.parent, options.name, Box::new(batches))
.mode(options.mode)
.write_options(options.write_options);
self.do_create_table(opts).await
}
}
/// A connection to LanceDB
#[derive(Clone)]
pub struct Connection {
uri: String,
internal: Arc<dyn ConnectionInternal>,
}
impl Connection {
/// Get the URI of the connection
pub fn uri(&self) -> &str {
self.uri.as_str()
}
#[async_trait::async_trait]
pub trait Connection: Send + Sync {
/// Get the names of all tables in the database.
pub async fn table_names(&self) -> Result<Vec<String>> {
self.internal.table_names().await
}
async fn table_names(&self) -> Result<Vec<String>>;
/// Create a new table from data
/// Create a new table in the database.
///
/// # Parameters
///
/// * `name` - The name of the table
/// * `initial_data` - The initial data to write to the table
pub fn create_table(
&self,
name: impl Into<String>,
initial_data: Box<dyn RecordBatchReader + Send>,
) -> CreateTableBuilder<true> {
CreateTableBuilder::<true>::new(self.internal.clone(), name.into(), initial_data)
}
/// Create an empty table with a given schema
///
/// # Parameters
///
/// * `name` - The name of the table
/// * `schema` - The schema of the table
pub fn create_empty_table(
&self,
name: impl Into<String>,
schema: SchemaRef,
) -> CreateTableBuilder<false> {
CreateTableBuilder::<false>::new(self.internal.clone(), name.into(), schema)
}
/// Open an existing table in the database
///
/// # Arguments
/// * `name` - The name of the table
/// * `name` - The name of the table.
/// * `batches` - The initial data to write to the table.
/// * `params` - Optional [`WriteParams`] to create the table.
///
/// # Returns
/// Created [`TableRef`], or [`Error::TableNotFound`] if the table does not exist.
pub fn open_table(&self, name: impl Into<String>) -> OpenTableBuilder {
OpenTableBuilder::new(self.internal.clone(), name.into())
/// Created [`TableRef`], or [`Err(Error::TableAlreadyExists)`] if the table already exists.
async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<TableRef>;
async fn open_table(&self, name: &str) -> Result<TableRef> {
self.open_table_with_params(name, ReadParams::default())
.await
}
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef>;
/// Drop a table in the database.
///
/// # Arguments
/// * `name` - The name of the table to drop
pub async fn drop_table(&self, name: impl AsRef<str>) -> Result<()> {
self.internal.drop_table(name.as_ref()).await
}
/// Drop the database
///
/// This is the same as dropping all of the tables
pub async fn drop_db(&self) -> Result<()> {
self.internal.drop_db().await
}
/// * `name` - The name of the table.
async fn drop_table(&self, name: &str) -> Result<()>;
}
#[derive(Debug)]
pub struct ConnectBuilder {
pub struct ConnectOptions {
/// Database URI
///
/// ### Accepted URI formats
/// # Accepted URI formats
///
/// - `/path/to/database` - local database on file system.
/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
/// - `db://dbname` - LanceDB Cloud
uri: String,
/// - `db://dbname` - Lance Cloud
pub uri: String,
/// LanceDB Cloud API key, required if using Lance Cloud
api_key: Option<String>,
/// LanceDB Cloud region, required if using Lance Cloud
region: Option<String>,
/// LanceDB Cloud host override, only required if using an on-premises Lance Cloud instance
host_override: Option<String>,
/// Lance Cloud API key
pub api_key: Option<String>,
/// Lance Cloud region
pub region: Option<String>,
/// Lance Cloud host override
pub host_override: Option<String>,
/// User provided AWS credentials
aws_creds: Option<AwsCredential>,
pub aws_creds: Option<AwsCredential>,
/// The maximum number of indices to cache in memory. Defaults to 256.
pub index_cache_size: u32,
}
impl ConnectBuilder {
impl ConnectOptions {
/// Create a new [`ConnectOptions`] with the given database URI.
pub fn new(uri: &str) -> Self {
Self {
@@ -314,6 +104,7 @@ impl ConnectBuilder {
region: None,
host_override: None,
aws_creds: None,
index_cache_size: 256,
}
}
@@ -333,18 +124,15 @@ impl ConnectBuilder {
}
/// [`AwsCredential`] to use when connecting to S3.
///
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
self.aws_creds = Some(aws_creds);
self
}
/// Establishes a connection to the database
pub async fn execute(self) -> Result<Connection> {
let internal = Arc::new(Database::connect_with_options(&self).await?);
Ok(Connection {
internal,
uri: self.uri,
})
pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
self.index_cache_size = index_cache_size;
self
}
}
@@ -352,14 +140,29 @@ impl ConnectBuilder {
///
/// # Arguments
///
/// * `uri` - URI where the database is located, can be a local directory, supported remote cloud storage,
/// or a LanceDB Cloud database. See [ConnectOptions::uri] for a list of accepted formats
pub fn connect(uri: &str) -> ConnectBuilder {
ConnectBuilder::new(uri)
/// - `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
///
/// ## Accepted URI formats
///
/// - `/path/to/database` - local database on file system.
/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
/// - `db://dbname` - Lance Cloud
///
pub async fn connect(uri: &str) -> Result<Arc<dyn Connection>> {
let options = ConnectOptions::new(uri);
connect_with_options(&options).await
}
#[derive(Debug)]
struct Database {
/// Connect with [`ConnectOptions`].
///
/// # Arguments
/// - `options` - [`ConnectOptions`] to connect to the database.
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Arc<dyn Connection>> {
let db = Database::connect(&options.uri).await?;
Ok(Arc::new(db))
}
pub struct Database {
object_store: ObjectStore,
query_string: Option<String>,
@@ -376,7 +179,21 @@ const MIRRORED_STORE: &str = "mirroredStore";
/// A connection to LanceDB
impl Database {
async fn connect_with_options(options: &ConnectBuilder) -> Result<Self> {
/// Connects to LanceDB
///
/// # Arguments
///
/// * `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
///
/// # Returns
///
/// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Self> {
let options = ConnectOptions::new(uri);
Self::connect_with_options(&options).await
}
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> {
let uri = &options.uri;
let parse_res = url::Url::parse(uri);
@@ -516,7 +333,7 @@ impl Database {
}
#[async_trait::async_trait]
impl ConnectionInternal for Database {
impl Connection for Database {
async fn table_names(&self) -> Result<Vec<String>> {
let mut f = self
.object_store
@@ -537,67 +354,55 @@ impl ConnectionInternal for Database {
Ok(f)
}
async fn do_create_table(&self, options: CreateTableBuilder<true>) -> Result<TableRef> {
let table_uri = self.table_uri(&options.name)?;
async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader + Send>,
params: Option<WriteParams>,
) -> Result<TableRef> {
let table_uri = self.table_uri(name)?;
let mut write_params = options.write_options.lance_write_params.unwrap_or_default();
if matches!(&options.mode, CreateTableMode::Overwrite) {
write_params.mode = WriteMode::Overwrite;
}
match NativeTable::create(
&table_uri,
&options.name,
options.data.unwrap(),
self.store_wrapper.clone(),
Some(write_params),
)
.await
{
Ok(table) => Ok(Arc::new(table)),
Err(Error::TableAlreadyExists { name }) => match options.mode {
CreateTableMode::Create => Err(Error::TableAlreadyExists { name }),
CreateTableMode::ExistOk(callback) => {
let builder = OpenTableBuilder::new(options.parent, options.name);
let builder = (callback)(builder);
builder.execute().await
}
CreateTableMode::Overwrite => unreachable!(),
},
Err(err) => Err(err),
}
}
async fn do_open_table(&self, options: OpenTableBuilder) -> Result<TableRef> {
let table_uri = self.table_uri(&options.name)?;
Ok(Arc::new(
NativeTable::open_with_params(
NativeTable::create(
&table_uri,
&options.name,
name,
batches,
self.store_wrapper.clone(),
options.lance_read_params,
params,
)
.await?,
))
}
/// Open a table in the database.
///
/// # Arguments
/// * `name` - The name of the table.
/// * `params` - The parameters to open the table.
///
/// # Returns
///
/// * A [TableRef] object.
async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef> {
let table_uri = self.table_uri(name)?;
Ok(Arc::new(
NativeTable::open_with_params(&table_uri, name, self.store_wrapper.clone(), params)
.await?,
))
}
async fn drop_table(&self, name: &str) -> Result<()> {
let dir_name = format!("{}.{}", name, LANCE_EXTENSION);
let full_path = self.base_path.child(dir_name.clone());
self.object_store.remove_dir_all(full_path).await?;
Ok(())
}
async fn drop_db(&self) -> Result<()> {
todo!()
}
}
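
Putting the `Connection` methods above together, a minimal end-to-end sketch, assuming `create_table`, `open_table`, and `drop_table` as shown in this file (table name and data are illustrative; errors are just unwrapped):

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use vectordb::connect;
use vectordb::connection::Connection;

#[tokio::main]
async fn main() {
    let db = connect("data/example-lancedb").await.unwrap();

    // A single Int32 column is enough to exercise create/open/drop.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batches = RecordBatchIterator::new(
        vec![RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from_iter_values(0..10))],
        )
        .unwrap()]
        .into_iter()
        .map(Ok),
        schema.clone(),
    );

    db.create_table("example", Box::new(batches), None).await.unwrap();
    let table = db.open_table("example").await.unwrap();
    println!("{}", table); // Table implements Display
    db.drop_table("example").await.unwrap();
}
```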
#[cfg(test)]
mod tests {
use std::fs::create_dir_all;
use arrow_schema::{DataType, Field, Schema};
use tempfile::tempdir;
use super::*;
@@ -606,7 +411,7 @@ mod tests {
async fn test_connect() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let db = Database::connect(uri).await.unwrap();
assert_eq!(db.uri, uri);
}
@@ -624,8 +429,7 @@ mod tests {
let relative_root = std::path::PathBuf::from(relative_ancestors.join("/"));
let relative_uri = relative_root.join(&uri);
let db = connect(relative_uri.to_str().unwrap())
.execute()
let db = Database::connect(relative_uri.to_str().unwrap())
.await
.unwrap();
@@ -640,7 +444,7 @@ mod tests {
create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let db = Database::connect(uri).await.unwrap();
let tables = db.table_names().await.unwrap();
assert_eq!(tables.len(), 2);
assert!(tables[0].eq(&String::from("table1")));
@@ -658,44 +462,10 @@ mod tests {
create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let db = Database::connect(uri).await.unwrap();
db.drop_table("table1").await.unwrap();
let tables = db.table_names().await.unwrap();
assert_eq!(tables.len(), 0);
}
#[tokio::test]
async fn test_create_table_already_exists() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
db.create_empty_table("test", schema.clone())
.execute()
.await
.unwrap();
// TODO: None of the open table options are "inspectable" right now but once one is we
// should assert we are passing these options in correctly
db.create_empty_table("test", schema)
.mode(CreateTableMode::exist_ok(|builder| {
builder.index_cache_size(16)
}))
.execute()
.await
.unwrap();
let other_schema = Arc::new(Schema::new(vec![Field::new("y", DataType::Int32, false)]));
assert!(db
.create_empty_table("test", other_schema.clone())
.execute()
.await
.is_err());
let overwritten = db
.create_empty_table("test", other_schema.clone())
.mode(CreateTableMode::Overwrite)
.execute()
.await
.unwrap();
assert_eq!(other_schema, overwritten.schema());
}
}

View File

@@ -174,6 +174,7 @@ fn coerce_schema_batch(
}
/// Coerce the reader (input data) to match the given [Schema].
///
pub fn coerce_schema(
reader: impl RecordBatchReader + Send + 'static,
schema: Arc<Schema>,

View File

@@ -342,7 +342,7 @@ mod test {
use object_store::local::LocalFileSystem;
use tempfile;
use crate::{connect, table::WriteOptions};
use crate::connection::{Connection, Database};
#[tokio::test]
async fn test_e2e() {
@@ -354,7 +354,7 @@ mod test {
secondary: Arc::new(secondary_store),
});
let db = connect(dir1.to_str().unwrap()).execute().await.unwrap();
let db = Database::connect(dir1.to_str().unwrap()).await.unwrap();
let mut param = WriteParams::default();
let store_params = ObjectStoreParams {
@@ -368,11 +368,7 @@ mod test {
datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));
let res = db
.create_table("test", Box::new(datagen.batch(100)))
.write_options(WriteOptions {
lance_write_params: Some(param),
})
.execute()
.create_table("test", Box::new(datagen.batch(100)), Some(param.clone()))
.await;
// leave this here for easy debugging

View File

@@ -43,9 +43,10 @@
//! #### Connect to a database.
//!
//! ```rust
//! use vectordb::connect;
//! # use arrow_schema::{Field, Schema};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = vectordb::connect("data/sample-lancedb").execute().await.unwrap();
//! let db = connect("data/sample-lancedb").await.unwrap();
//! # });
//! ```
//!
@@ -55,20 +56,14 @@
//! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
//! - `db://dbname` - Lance Cloud
//!
//! You can also use [`ConnectOptions`] to configure the connection to the database.
//! You can also use [`ConnectOptions`] to configure the connection to the database.
//!
//! ```rust
//! use object_store::aws::AwsCredential;
//! use vectordb::{connect_with_options, ConnectOptions};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = vectordb::connect("data/sample-lancedb")
//! .aws_creds(AwsCredential {
//! key_id: "some_key".to_string(),
//! secret_key: "some_secret".to_string(),
//! token: None,
//! })
//! .execute()
//! .await
//! .unwrap();
//! let options = ConnectOptions::new("data/sample-lancedb")
//! .index_cache_size(1024);
//! let db = connect_with_options(&options).await.unwrap();
//! # });
//! ```
//!
@@ -84,44 +79,31 @@
//!
//! ```rust
//! # use std::sync::Arc;
//! use arrow_schema::{DataType, Schema, Field};
//! use arrow_array::{RecordBatch, RecordBatchIterator};
//! use arrow_schema::{DataType, Field, Schema};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::connection::{Database, Connection};
//! # use vectordb::connect;
//!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! let schema = Arc::new(Schema::new(vec![
//! Field::new("id", DataType::Int32, false),
//! Field::new(
//! "vector",
//! DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 128),
//! true,
//! ),
//! Field::new("id", DataType::Int32, false),
//! Field::new("vector", DataType::FixedSizeList(
//! Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
//! ]));
//! // Create a RecordBatch stream.
//! let batches = RecordBatchIterator::new(
//! vec![RecordBatch::try_new(
//! schema.clone(),
//! let batches = RecordBatchIterator::new(vec![
//! RecordBatch::try_new(schema.clone(),
//! vec![
//! Arc::new(Int32Array::from_iter_values(0..256)),
//! Arc::new(
//! FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
//! (0..256).map(|_| Some(vec![Some(1.0); 128])),
//! 128,
//! ),
//! ),
//! ],
//! )
//! .unwrap()]
//! .into_iter()
//! .map(Ok),
//! schema.clone(),
//! );
//! db.create_table("my_table", Box::new(batches))
//! .execute()
//! .await
//! .unwrap();
//! Arc::new(Int32Array::from_iter_values(0..1000)),
//! Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
//! (0..1000).map(|_| Some(vec![Some(1.0); 128])), 128)),
//! ]).unwrap()
//! ].into_iter().map(Ok),
//! schema.clone());
//! db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! # });
//! ```
//!
@@ -129,13 +111,14 @@
//!
//! ```no_run
//! # use std::sync::Arc;
//! # use vectordb::connect;
//! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
//! # RecordBatchIterator, Int32Array};
//! # use arrow_schema::{Schema, Field, DataType};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! # let tbl = db.open_table("idx_test").execute().await.unwrap();
//! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let tbl = db.open_table("idx_test").await.unwrap();
//! tbl.create_index(&["vector"])
//! .ivf_pq()
//! .num_partitions(256)
@@ -153,9 +136,10 @@
//! # use arrow_schema::{DataType, Schema, Field};
//! # use arrow_array::{RecordBatch, RecordBatchIterator};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::connection::{Database, Connection};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = vectordb::connect(tmpdir.path().to_str().unwrap()).execute().await.unwrap();
//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let schema = Arc::new(Schema::new(vec![
//! # Field::new("id", DataType::Int32, false),
//! # Field::new("vector", DataType::FixedSizeList(
@@ -170,8 +154,8 @@
//! # ]).unwrap()
//! # ].into_iter().map(Ok),
//! # schema.clone());
//! # db.create_table("my_table", Box::new(batches)).execute().await.unwrap();
//! # let table = db.open_table("my_table").execute().await.unwrap();
//! # db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! # let table = db.open_table("my_table").await.unwrap();
//! let results = table
//! .search(&[1.0; 128])
//! .execute_stream()
@@ -181,6 +165,8 @@
//! .await
//! .unwrap();
//! # });
//!
//!
//! ```
pub mod connection;
@@ -193,8 +179,10 @@ pub mod query;
pub mod table;
pub mod utils;
pub use connection::{Connection, Database};
pub use error::{Error, Result};
pub use table::{Table, TableRef};
/// Connect to a database
pub use connection::connect;
pub use connection::{connect, connect_with_options, ConnectOptions};
pub use lance::dataset::WriteMode;

View File

@@ -60,6 +60,7 @@ impl Query {
/// # Arguments
///
/// * `dataset` - Lance dataset.
///
pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
Self {
dataset,

View File

@@ -17,7 +17,7 @@
use std::path::Path;
use std::sync::{Arc, Mutex};
use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_array::RecordBatchReader;
use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use chrono::Duration;
@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteMode, WriteParams};
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -38,6 +38,7 @@ use crate::index::vector::{VectorIndex, VectorIndexStatistics};
use crate::index::IndexBuilder;
use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode;
use self::merge::{MergeInsert, MergeInsertBuilder};
@@ -84,35 +85,6 @@ pub struct OptimizeStats {
pub prune: Option<RemovalStats>,
}
/// Options to use when writing data
#[derive(Clone, Debug, Default)]
pub struct WriteOptions {
// Coming soon: https://github.com/lancedb/lancedb/issues/992
// /// What behavior to take if the data contains invalid vectors
// pub on_bad_vectors: BadVectorHandling,
/// Advanced parameters that can be used to customize table creation
///
/// If set, these will take precedence over any overlapping `OpenTableOptions` options
pub lance_write_params: Option<WriteParams>,
}
#[derive(Debug, Clone, Default)]
pub enum AddDataMode {
/// Rows will be appended to the table (the default)
#[default]
Append,
/// The existing table will be overwritten with the new data
Overwrite,
}
#[derive(Debug, Default, Clone)]
pub struct AddDataOptions {
/// Whether to add new rows (the default) or replace the existing data
pub mode: AddDataMode,
/// Options to use when writing the data
pub write_options: WriteOptions,
}
/// A Table is a collection of strongly typed Rows.
///
/// The type of each row is defined in Apache Arrow [Schema].
@@ -140,12 +112,12 @@ pub trait Table: std::fmt::Display + Send + Sync {
///
/// # Arguments
///
/// * `batches` data to be added to the Table
/// * `options` options to control how data is added
/// * `batches` RecordBatch to be saved in the Table
/// * `params` Append / Overwrite existing records. Default: Append
async fn add(
&self,
batches: Box<dyn RecordBatchReader + Send>,
options: AddDataOptions,
params: Option<WriteParams>,
) -> Result<()>;
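
As a sketch of append-vs-overwrite behaviour with the `Option<WriteParams>` signature above (the helper name is made up here, and the table plus both readers are assumed to be supplied by the caller):

```rust
use arrow_array::RecordBatchReader;
use lance::dataset::{WriteMode, WriteParams};
use vectordb::table::Table;

// Append new rows, then replace the whole table with a second reader.
async fn append_then_overwrite(
    table: &dyn Table,
    new_rows: Box<dyn RecordBatchReader + Send>,
    replacement: Box<dyn RecordBatchReader + Send>,
) {
    // `None` falls back to the default WriteMode::Append.
    table.add(new_rows, None).await.unwrap();

    // Explicit WriteParams switch the write mode to Overwrite.
    let overwrite = WriteParams {
        mode: WriteMode::Overwrite,
        ..Default::default()
    };
    table.add(replacement, Some(overwrite)).await.unwrap();
}
```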
/// Delete the rows from table that match the predicate.
@@ -157,43 +129,28 @@ pub trait Table: std::fmt::Display + Send + Sync {
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
/// .execute()
/// .await
/// .unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let batches = RecordBatchIterator::new(
/// vec![RecordBatch::try_new(
/// schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(
/// FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])),
/// 128,
/// ),
/// ),
/// ],
/// )
/// .unwrap()]
/// .into_iter()
/// .map(Ok),
/// schema.clone(),
/// );
/// let tbl = db
/// .create_table("delete_test", Box::new(batches))
/// .execute()
/// .await
/// .unwrap();
/// let batches = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// let tbl = db.create_table("delete_test", Box::new(batches), None).await.unwrap();
/// tbl.delete("id > 5").await.unwrap();
/// # });
/// ```
@@ -205,16 +162,14 @@ pub trait Table: std::fmt::Display + Send + Sync {
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
/// .execute()
/// .await
/// .unwrap();
/// # let tbl = db.open_table("idx_test").execute().await.unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("idx_test").await.unwrap();
/// tbl.create_index(&["vector"])
/// .ivf_pq()
/// .num_partitions(256)
@@ -259,44 +214,32 @@ pub trait Table: std::fmt::Display + Send + Sync {
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = vectordb::connect(tmpdir.path().to_str().unwrap())
/// .execute()
/// .await
/// .unwrap();
/// # let tbl = db.open_table("idx_test").execute().await.unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("idx_test").await.unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let new_data = RecordBatchIterator::new(
/// vec![RecordBatch::try_new(
/// schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(
/// FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])),
/// 128,
/// ),
/// ),
/// ],
/// )
/// .unwrap()]
/// .into_iter()
/// .map(Ok),
/// schema.clone(),
/// );
/// let new_data = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// // Perform an upsert operation
/// let mut merge_insert = tbl.merge_insert(&["id"]);
/// merge_insert
/// .when_matched_update_all(None)
/// .when_not_matched_insert_all();
/// merge_insert.when_matched_update_all(None)
/// .when_not_matched_insert_all();
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
/// # });
/// ```
@@ -323,9 +266,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// # use futures::TryStreamExt;
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
/// let stream = tbl
/// .query()
/// .nearest_to(&[1.0, 2.0, 3.0])
/// let stream = tbl.query().nearest_to(&[1.0, 2.0, 3.0])
/// .refine_factor(5)
/// .nprobes(10)
/// .execute_stream()
@@ -358,7 +299,11 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// # use futures::TryStreamExt;
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// # let tbl = vectordb::table::NativeTable::open("/tmp/tbl").await.unwrap();
/// let stream = tbl.query().execute_stream().await.unwrap();
/// let stream = tbl
/// .query()
/// .execute_stream()
/// .await
/// .unwrap();
/// let batches: Vec<RecordBatch> = stream.try_collect().await.unwrap();
/// # });
/// ```
@@ -406,7 +351,7 @@ impl NativeTable {
/// * A [NativeTable] object.
pub async fn open(uri: &str) -> Result<Self> {
let name = Self::get_table_name(uri)?;
Self::open_with_params(uri, &name, None, None).await
Self::open_with_params(uri, &name, None, ReadParams::default()).await
}
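
A short sketch of opening a table directly by URI with the helper above (the path is illustrative):

```rust
use vectordb::table::NativeTable;

#[tokio::main]
async fn main() {
    // The table name is derived from the last path component.
    let table = NativeTable::open("data/sample-lancedb/my_table.lance")
        .await
        .unwrap();
    println!("opened table at version {}", table.version());
}
```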
/// Opens an existing Table
@@ -424,9 +369,8 @@ impl NativeTable {
uri: &str,
name: &str,
write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
params: Option<ReadParams>,
params: ReadParams,
) -> Result<Self> {
let params = params.unwrap_or_default();
// patch the params if we have a write store wrapper
let params = match write_store_wrapper.clone() {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
@@ -459,6 +403,7 @@ impl NativeTable {
}
/// Checkout a specific version of this [NativeTable]
///
pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
let name = Self::get_table_name(uri)?;
Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
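
And a sketch of time-travel via `checkout`, assuming the signature above (the path and version number are illustrative):

```rust
use vectordb::table::NativeTable;

#[tokio::main]
async fn main() {
    // Open the table as it existed at version 3.
    let table = NativeTable::checkout("data/sample-lancedb/my_table.lance", 3)
        .await
        .unwrap();
    assert_eq!(table.version(), 3);
}
```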
@@ -544,14 +489,13 @@ impl NativeTable {
write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
params: Option<WriteParams>,
) -> Result<Self> {
let params = params.unwrap_or_default();
// patch the params if we have a write store wrapper
let params = match write_store_wrapper.clone() {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
None => params,
};
let dataset = Dataset::write(batches, uri, Some(params))
let dataset = Dataset::write(batches, uri, params)
.await
.map_err(|e| match e {
lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -569,17 +513,6 @@ impl NativeTable {
})
}
pub async fn create_empty(
uri: &str,
name: &str,
schema: SchemaRef,
write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
params: Option<WriteParams>,
) -> Result<Self> {
let batches = RecordBatchIterator::new(vec![], schema);
Self::create(uri, name, batches, write_store_wrapper, params).await
}
/// Version of this Table
pub fn version(&self) -> u64 {
self.dataset.lock().expect("lock poison").version().version
@@ -807,26 +740,20 @@ impl Table for NativeTable {
async fn add(
&self,
batches: Box<dyn RecordBatchReader + Send>,
params: AddDataOptions,
params: Option<WriteParams>,
) -> Result<()> {
let lance_params = params
.write_options
.lance_write_params
.unwrap_or(WriteParams {
mode: match params.mode {
AddDataMode::Append => WriteMode::Append,
AddDataMode::Overwrite => WriteMode::Overwrite,
},
..Default::default()
});
let params = Some(params.unwrap_or(WriteParams {
mode: WriteMode::Append,
..WriteParams::default()
}));
// patch the params if we have a write store wrapper
let lance_params = match self.store_wrapper.clone() {
Some(wrapper) => lance_params.patch_with_store_wrapper(wrapper)?,
None => lance_params,
let params = match self.store_wrapper.clone() {
Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
None => params,
};
self.reset_dataset(Dataset::write(batches, &self.uri, Some(lance_params)).await?);
self.reset_dataset(Dataset::write(batches, &self.uri, params).await?);
Ok(())
}
@@ -954,6 +881,25 @@ mod tests {
assert_eq!(c.to_str().unwrap(), "s3://bucket/path/to/file/subfile");
}
#[tokio::test]
async fn test_create_already_exists() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches = make_test_batches();
let _ = batches.schema().clone();
NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
let batches = make_test_batches();
let result = NativeTable::create(uri, "test", batches, None, None).await;
assert!(matches!(
result.unwrap_err(),
Error::TableAlreadyExists { .. }
));
}
#[tokio::test]
async fn test_count_rows() {
let tmp_dir = tempdir().unwrap();
@@ -994,10 +940,7 @@ mod tests {
schema.clone(),
);
table
.add(Box::new(new_batches), AddDataOptions::default())
.await
.unwrap();
table.add(Box::new(new_batches), None).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 20);
assert_eq!(table.name, "test");
}
@@ -1060,47 +1003,23 @@ mod tests {
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
let batches = vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.unwrap()]
.into_iter()
.map(Ok);
let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
// Can overwrite using AddDataOptions::mode
table
.add(
Box::new(new_batches),
AddDataOptions {
mode: AddDataMode::Overwrite,
..Default::default()
},
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(table.name, "test");
// Can overwrite using underlying WriteParams (which
// take precedence over AddDataOptions::mode)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
let param: WriteParams = WriteParams {
mode: WriteMode::Overwrite,
..Default::default()
};
let opts = AddDataOptions {
write_options: WriteOptions {
lance_write_params: Some(param),
},
mode: AddDataMode::Append,
};
let new_batches = RecordBatchIterator::new(batches.clone(), schema.clone());
table.add(Box::new(new_batches), opts).await.unwrap();
table.add(Box::new(new_batches), Some(param)).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(table.name, "test");
}
@@ -1410,7 +1329,7 @@ mod tests {
..Default::default()
};
assert!(!wrapper.called());
let _ = NativeTable::open_with_params(uri, "test", None, Some(param))
let _ = NativeTable::open_with_params(uri, "test", None, param)
.await
.unwrap();
assert!(wrapper.called());

View File

@@ -32,17 +32,20 @@ impl PatchStoreParam for Option<ObjectStoreParams> {
}
pub trait PatchWriteParam {
fn patch_with_store_wrapper(self, wrapper: Arc<dyn WrappingObjectStore>)
-> Result<WriteParams>;
fn patch_with_store_wrapper(
self,
wrapper: Arc<dyn WrappingObjectStore>,
) -> Result<Option<WriteParams>>;
}
impl PatchWriteParam for WriteParams {
impl PatchWriteParam for Option<WriteParams> {
fn patch_with_store_wrapper(
mut self,
self,
wrapper: Arc<dyn WrappingObjectStore>,
) -> Result<WriteParams> {
self.store_params = self.store_params.patch_with_store_wrapper(wrapper)?;
Ok(self)
) -> Result<Option<WriteParams>> {
let mut params = self.unwrap_or_default();
params.store_params = params.store_params.patch_with_store_wrapper(wrapper)?;
Ok(Some(params))
}
}