Compare commits

...

17 Commits

Author SHA1 Message Date
Lance Release
bd2d40a927 Bump version: 0.1.12 → 0.1.13 2023-07-13 21:17:35 +00:00
Lei Xu
08944bf4fd [Python] Convert Pydantic Model to Arrow Schema (#291)
Provide utility to automatically convert Pydantic model to Arrow Schema

Closes #256
2023-07-13 11:16:37 -07:00
gsilvestrin
826dc90151 feat(node): add option object to connect method (#286) 2023-07-13 11:03:48 -07:00
Lei Xu
08cc483ec9 [Doc] Describe the difference between ANN and KNN, and how to create indices. (#293) 2023-07-13 08:52:58 -07:00
Lei Xu
ff1d206182 [Doc] Split the python integration into different topics (#292) 2023-07-12 21:26:59 -07:00
gsilvestrin
c385c55629 feat(node): pull node binaries into separate packages (3) (#285) 2023-07-12 16:52:04 -07:00
Lance Release
0a03f7ca5a Bump version: 0.1.11 → 0.1.12 2023-07-12 04:20:34 +00:00
Rob Meng
88be978e87 allow logging in JS (#283)
tested with `RUST_LOG=info npm test`
2023-07-11 22:50:36 -04:00
Rob Meng
98b12caa06 export create table with aws credentials (#282) 2023-07-11 17:21:10 -04:00
Lance Release
091dffb171 Bump version: 0.1.10 → 0.1.11 2023-07-11 20:42:15 +00:00
Rob Meng
ace6aa883a Upgrade lance to 0.5.5, and plumb thru new features from the upgrade (#279)
* upgrade
* fixes for the upgrade
* allow JS users to pass custom AWS credentials
2023-07-11 16:33:39 -04:00
Tevin Wang
80c25f9896 [Docs] uncomment cosine metric (#271)
- Change k value to `10` for js search to keep it consistent with python
docs
- Uncomment now that cosine metric is fixed in lance:
https://github.com/lancedb/lance/pull/1035
2023-07-11 12:30:11 -07:00
gsilvestrin
caf22fdb71 Run rust tests when Cargo.toml changes (#276) 2023-07-11 11:19:06 -07:00
Lei Xu
0e7ae5dfbf [Python] Fix list type conversion to JSON and temporal types (#274) 2023-07-11 11:05:51 -07:00
gsilvestrin
b261e27222 Pin lance version (#275)
we shouldn't auto-upgrade lance
2023-07-11 10:58:15 -07:00
Lei Xu
9f603f73a9 [Python] Schema to JSON (#272) 2023-07-10 18:11:24 -07:00
Lei Xu
9ef846929b [Python] List tables from remote service (#262) 2023-07-09 23:58:03 -07:00
42 changed files with 1709 additions and 300 deletions


@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.10
current_version = 0.1.13
commit = True
message = Bump version: {current_version} → {new_version}
tag = True


@@ -67,8 +67,12 @@ jobs:
- name: Build
run: |
npm ci
npm run build
npm run tsc
npm run build
npm run pack-build
npm install --no-save ./dist/vectordb-*.tgz
# Remove index.node to test with dependency installed
rm index.node
- name: Test
run: npm run test
macos:
@@ -94,8 +98,12 @@ jobs:
- name: Build
run: |
npm ci
npm run build
npm run tsc
npm run build
npm run pack-build
npm install --no-save ./dist/vectordb-*.tgz
# Remove index.node to test with dependency installed
rm index.node
- name: Test
run: |
npm run test

.github/workflows/npm-publish.yml (new file, 137 lines)

@@ -0,0 +1,137 @@
name: NPM Publish
on:
release:
types: [ published ]
jobs:
node:
runs-on: ubuntu-latest
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
defaults:
run:
shell: bash
working-directory: node
steps:
- name: Checkout
uses: actions/checkout@v3
- uses: actions/setup-node@v3
with:
node-version: 20
cache: 'npm'
cache-dependency-path: node/package-lock.json
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Build
run: |
npm ci
npm run tsc
npm pack
- name: Upload Linux Artifacts
uses: actions/upload-artifact@v3
with:
name: node-package
path: |
node/vectordb-*.tgz
node-macos:
runs-on: macos-12
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
strategy:
fail-fast: false
matrix:
target: [x86_64-apple-darwin, aarch64-apple-darwin]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install system dependencies
run: brew install protobuf
- name: Install npm dependencies
run: |
cd node
npm ci
- name: Install rustup target
if: ${{ matrix.target == 'aarch64-apple-darwin' }}
run: rustup target add aarch64-apple-darwin
- name: Build MacOS native node modules
run: bash ci/build_macos_artifacts.sh ${{ matrix.target }}
- name: Upload Darwin Artifacts
uses: actions/upload-artifact@v3
with:
name: darwin-native
path: |
node/dist/vectordb-darwin*.tgz
node-linux:
name: node-linux (${{ matrix.arch}}-unknown-linux-${{ matrix.libc }})
runs-on: ubuntu-latest
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
strategy:
fail-fast: false
matrix:
libc:
- gnu
# TODO: re-enable musl once we have refactored to pre-built containers
# Right now we have to build node from source which is too expensive.
# - musl
arch:
- x86_64
# Building on aarch64 is too slow for now
# - aarch64
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Change owner to root (for npm)
# The docker container is run as root, so we need the files to be owned by root
# Otherwise npm is a nightmare: https://github.com/npm/cli/issues/3773
run: sudo chown -R root:root .
- name: Set up QEMU
if: ${{ matrix.arch == 'aarch64' }}
uses: docker/setup-qemu-action@v2
with:
platforms: arm64
- name: Build Linux GNU native node modules
if: ${{ matrix.libc == 'gnu' }}
run: |
docker run \
-v $(pwd):/io -w /io \
rust:1.70-bookworm \
bash ci/build_linux_artifacts.sh ${{ matrix.arch }}-unknown-linux-gnu
- name: Build musl Linux native node modules
if: ${{ matrix.libc == 'musl' }}
run: |
docker run --platform linux/arm64/v8 \
-v $(pwd):/io -w /io \
quay.io/pypa/musllinux_1_1_${{ matrix.arch }} \
bash ci/build_linux_artifacts.sh ${{ matrix.arch }}-unknown-linux-musl
- name: Upload Linux Artifacts
uses: actions/upload-artifact@v3
with:
name: linux-native
path: |
node/dist/vectordb-linux*.tgz
release:
needs: [node, node-macos, node-linux]
runs-on: ubuntu-latest
# Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/download-artifact@v3
- name: Display structure of downloaded files
run: ls -R
- uses: actions/setup-node@v3
with:
node-version: 20
- name: Publish to NPM
env:
NODE_AUTH_TOKEN: ${{ secrets.LANCEDB_NPM_REGISTRY_TOKEN }}
run: |
for filename in */*.tgz; do
npm publish $filename
done


@@ -6,6 +6,7 @@ on:
- main
pull_request:
paths:
- Cargo.toml
- rust/**
- .github/workflows/rust.yml

.gitignore (+2 lines)

@@ -5,6 +5,8 @@
.DS_Store
venv
.vscode
rust/target
rust/Cargo.lock


@@ -6,9 +6,9 @@ members = [
resolver = "2"
[workspace.dependencies]
lance = "0.5.3"
arrow-array = "40.0"
arrow-data = "40.0"
arrow-schema = "40.0"
arrow-ipc = "40.0"
lance = "=0.5.5"
arrow-array = "42.0"
arrow-data = "42.0"
arrow-schema = "42.0"
arrow-ipc = "42.0"
object_store = "0.6.1"


@@ -0,0 +1,72 @@
#!/bin/bash
# Builds the Linux artifacts (node binaries).
# Usage: ./build_linux_artifacts.sh [target]
# Targets supported:
# - x86_64-unknown-linux-gnu:centos
# - aarch64-unknown-linux-gnu:centos
# - aarch64-unknown-linux-musl
# - x86_64-unknown-linux-musl
# TODO: refactor this into a Docker container we can pull
set -e
setup_dependencies() {
echo "Installing system dependencies..."
if [[ $1 == *musl ]]; then
# musllinux
apk add openssl-dev
else
# rust / debian
apt update
apt install -y libssl-dev protobuf-compiler
fi
}
install_node() {
echo "Installing node..."
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.34.0/install.sh | bash
source "$HOME"/.bashrc
if [[ $1 == *musl ]]; then
# This node version is 15, we need 16 or higher:
# apk add nodejs-current npm
# So instead we install from source (nvm doesn't provide binaries for musl):
nvm install -s --no-progress 17
else
nvm install --no-progress 17 # latest that supports glibc 2.17
fi
}
build_node_binary() {
echo "Building node library for $1..."
pushd node
npm ci
if [[ $1 == *musl ]]; then
# This is needed for cargo to allow build cdylibs with musl
export RUSTFLAGS="-C target-feature=-crt-static"
fi
# Cargo can run out of memory while pulling dependencies, especially when running
# in QEMU. This is a workaround for that.
export CARGO_NET_GIT_FETCH_WITH_CLI=true
# We don't pass in target, since the native target here already matches
# We need to pass OPENSSL_LIB_DIR and OPENSSL_INCLUDE_DIR for static build to work https://github.com/sfackler/rust-openssl/issues/877
OPENSSL_STATIC=1 OPENSSL_LIB_DIR=/usr/lib/x86_64-linux-gnu OPENSSL_INCLUDE_DIR=/usr/include/openssl/ npm run build-release
npm run pack-build
popd
}
TARGET=${1:-x86_64-unknown-linux-gnu}
# Others:
# aarch64-unknown-linux-gnu
# x86_64-unknown-linux-musl
# aarch64-unknown-linux-musl
setup_dependencies $TARGET
install_node $TARGET
build_node_binary $TARGET


@@ -0,0 +1,33 @@
# Builds the macOS artifacts (node binaries).
# Usage: ./ci/build_macos_artifacts.sh [target]
# Targets supported: x86_64-apple-darwin aarch64-apple-darwin
prebuild_rust() {
# Building here for the sake of easier debugging.
pushd rust/ffi/node
echo "Building rust library for $1"
export RUST_BACKTRACE=1
cargo build --release --target $1
popd
}
build_node_binaries() {
pushd node
echo "Building node library for $1"
npm run build-release -- --target $1
npm run pack-build -- --target $1
popd
}
if [ -n "$1" ]; then
targets=$1
else
targets="x86_64-apple-darwin aarch64-apple-darwin"
fi
echo "Building artifacts for targets: $targets"
for target in $targets
do
prebuild_rust $target
build_node_binaries $target
done


@@ -50,13 +50,16 @@ markdown_extensions:
- pymdownx.superfences
- pymdownx.tabbed:
alternate_style: true
- md_in_html
nav:
- Home: index.md
- Basics: basic.md
- Embeddings: embedding.md
- Python full-text search: fts.md
- Python integrations: integrations.md
- Python integrations:
- Pandas and PyArrow: python/arrow.md
- DuckDB: python/duckdb.md
- Python examples:
- YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
- Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb


@@ -1,7 +1,7 @@
# ANN (Approximate Nearest Neighbor) Indexes
You can create an index over your vector data to make search faster.
Vector indexes are faster but less accurate than exhaustive search.
Vector indexes are faster but less accurate than exhaustive search (KNN or Flat Search).
LanceDB provides many parameters to fine-tune the index's size, the speed of queries, and the accuracy of results.
Currently, LanceDB does *not* automatically create the ANN index.
@@ -10,7 +10,18 @@ If you can live with <100ms latency, skipping index creation is a simpler workflow.
In the future we will look to automatically create and configure the ANN index.
## Creating an ANN Index
## Types of Indexes
Lance supports multiple index types; the most widely used is `IVF_PQ`.
* `IVF_PQ`: uses an **Inverted File Index (IVF)** to first divide the dataset into `N` partitions,
and then uses **Product Quantization** to compress the vectors in each partition.
* `DISKANN` (**Experimental**): organizes the vectors as an on-disk graph, where each vertex is
approximately connected to its nearest neighbors.
## Creating an IVF_PQ Index
Lance supports the `IVF_PQ` index type by default.
=== "Python"
Creating indexes is done via the [create_index](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.create_index) method.
@@ -45,15 +56,18 @@ In the future we will look to automatically create and configure the ANN index.
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 256, num_sub_vectors: 96 })
```
Since `create_index` has a training step, it can take a few minutes to finish for large tables. You can control the index
creation by providing the following parameters:
- **metric** (default: "L2"): The distance metric to use. By default it uses euclidean distance "`L2`".
We also support "cosine" and "dot" distance as well.
- **num_partitions** (default: 256): The number of partitions of the index.
- **num_sub_vectors** (default: 96): The number of sub-vectors (M) that will be created during Product Quantization (PQ).
For D dimensional vector, it will be divided into `M` of `D/M` sub-vectors, each of which is presented by
a single PQ code.
<figure markdown>
![IVF PQ](./assets/ivf_pq.png)
<figcaption>IVF_PQ index with <code>num_partitions=2, num_sub_vectors=4</code></figcaption>
</figure>
- **metric** (default: "L2"): The distance metric to use. By default we use Euclidean distance ("L2"); "cosine" distance is also supported.
- **num_partitions** (default: 256): The number of partitions of the index. It should be configured so that each partition holds 3-5K vectors; for example, a table
with ~1M vectors should use 256 partitions. You can specify an arbitrary number of partitions, but powers of 2 are most conventional.
A higher number leads to faster queries but slower index generation.
- **num_sub_vectors** (default: 96): The number of sub-vectors (M) that will be created during Product Quantization (PQ). A larger number makes
search more accurate, but also makes the index larger and slower to build.
## Querying an ANN Index
@@ -138,3 +152,31 @@ You can select the columns returned by the query using a select clause.
.select(["id"])
.execute()
```
## FAQ
### When is it necessary to create an ANN vector index?
`LanceDB` has manually tuned SIMD code for computing vector distances.
In our benchmarks, computing 100K pairs of 1K-dimension vectors takes less than 20ms.
For small datasets (<100K rows) or applications that can accept ~100ms latency, vector indices are usually not necessary.
For large-scale data or higher-dimension vectors, it is beneficial to create a vector index.
### How big is my index, and how much memory will it take?
In LanceDB, all vector indices are disk-based, meaning that when responding to a vector query, only the relevant pages from the index file are loaded from disk and cached in memory. Additionally, each sub-vector is usually encoded into a 1-byte PQ code.
For example, with a 1024-dimension dataset, if we choose `num_sub_vectors=64`, each sub-vector holds `1024 / 64 = 16` float32 numbers.
Product Quantization therefore yields approximately a `16 * sizeof(float32) / 1 = 64`x space reduction.
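As a quick sanity check on that arithmetic, a minimal sketch (the figures are the 1024-dimension example above):
```python
# PQ space reduction for the example above.
dim = 1024            # vector dimension
num_sub_vectors = 64  # PQ sub-vectors (M)

floats_per_sub_vector = dim // num_sub_vectors  # 16 float32 values
bytes_before = floats_per_sub_vector * 4        # 64 bytes uncompressed
bytes_after = 1                                 # one 1-byte PQ code
print(bytes_before // bytes_after)              # -> 64x space reduction
```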
### How do I choose `num_partitions` and `num_sub_vectors` for an `IVF_PQ` index?
`num_partitions` decides how many partitions the first-level `IVF` index uses.
A higher number of partitions can lead to more efficient I/O during queries and better accuracy, but it takes much more time to train.
On the `SIFT-1M` dataset, our benchmark shows that keeping each partition at 1K-4K rows leads to a good latency / recall trade-off.
`num_sub_vectors` decides how many Product Quantization codes to generate for each vector. Because
Product Quantization is a lossy compression of the original vector, a higher `num_sub_vectors` usually results in
less distortion, and thus better accuracy. However, a higher `num_sub_vectors` also causes heavier I/O and
more PQ computation, and therefore higher latency. `dimension / num_sub_vectors` should be a multiple of 8 for better SIMD efficiency.
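Putting this guidance together, a minimal Python sketch (the table of ~1M 1536-dimension vectors is hypothetical; the parameters mirror the `create_index` call shown earlier on this page):
```python
import lancedb

db = lancedb.connect("data/sample-lancedb")
table = db.open_table("my_table")  # assumed: ~1M rows of 1536-dim vectors

# 1M rows / 256 partitions ~= 4K vectors per partition (inside the 1K-4K sweet spot);
# 1536 / 96 = 16 floats per sub-vector, and 16 is a multiple of 8 for SIMD.
table.create_index(
    metric="L2",
    num_partitions=256,
    num_sub_vectors=96,
)
```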

docs/src/assets/ivf_pq.png (new binary file, 266 KiB)


@@ -46,7 +46,7 @@ LanceDB's core is written in Rust 🦀 and is built using <a href="https://githu
const uri = "data/sample-lancedb";
const db = await lancedb.connect(uri);
const table = await db.createTable("my_table",
[{ id: 1, vector: [3.1, 4.1], item: "foo", price: 10.0 },
{ id: 2, vector: [5.9, 26.5], item: "bar", price: 20.0 }])
const results = await table.search([100, 100]).limit(2).execute();
@@ -67,6 +67,6 @@ LanceDB's core is written in Rust 🦀 and is built using <a href="https://githu
* [`Embedding Functions`](embedding.md) - functions for working with embeddings.
* [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
* [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
* [`Ecosystem Integrations`](integrations.md) - integrating LanceDB with python data tooling ecosystem.
* [`Ecosystem Integrations`](python/integration.md) - integrating LanceDB with python data tooling ecosystem.
* [`Python API Reference`](python/python.md) - detailed documentation for the LanceDB Python SDK.
* [`Node API Reference`](javascript/modules.md) - detailed documentation for the LanceDB Node SDK.


@@ -1,116 +0,0 @@
# Integrations
Built on top of Apache Arrow, `LanceDB` is easy to integrate with the Python ecosystem, including Pandas, PyArrow and DuckDB.
## Pandas and PyArrow
First, we need to connect to a `LanceDB` database.
```py
import lancedb
db = lancedb.connect("data/sample-lancedb")
```
And write a `Pandas DataFrame` to LanceDB directly.
```py
import pandas as pd
data = pd.DataFrame({
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0]
})
table = db.create_table("pd_table", data=data)
```
You will find detailed instructions of creating dataset and index in [Basic Operations](basic.md) and [Indexing](ann_indexes.md)
sections.
We can now perform similarity searches via `LanceDB`.
```py
# Open the table previously created.
table = db.open_table("pd_table")
query_vector = [100, 100]
# Pandas DataFrame
df = table.search(query_vector).limit(1).to_df()
print(df)
```
```
vector item price score
0 [5.9, 26.5] bar 20.0 14257.05957
```
If you have a simple filter, it's faster to provide a where clause to `LanceDB`'s search query.
If you have more complex criteria, you can always apply the filter to the resulting pandas `DataFrame` from the search query.
```python
# Apply the filter via LanceDB
results = table.search([100, 100]).where("price < 15").to_df()
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
# Apply the filter via Pandas
df = results = table.search([100, 100]).to_df()
results = df[df.price < 15]
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
```
## DuckDB
`LanceDB` works with `DuckDB` via [PyArrow integration](https://duckdb.org/docs/guides/python/sql_on_arrow).
Let us start with installing `duckdb` and `lancedb`.
```shell
pip install duckdb lancedb
```
We will re-use the dataset created previously
```python
import lancedb
db = lancedb.connect("data/sample-lancedb")
table = db.open_table("pd_table")
arrow_table = table.to_arrow()
```
`DuckDB` can directly query the `arrow_table`:
```python
import duckdb
duckdb.query("SELECT * FROM arrow_table")
```
```
┌─────────────┬─────────┬────────┐
│ vector │ item │ price │
│ float[] │ varchar │ double │
├─────────────┼─────────┼────────┤
│ [3.1, 4.1] │ foo │ 10.0 │
│ [5.9, 26.5] │ bar │ 20.0 │
└─────────────┴─────────┴────────┘
```
```python
duckdb.query("SELECT mean(price) FROM arrow_table")
```
```
Out[16]:
┌─────────────┐
│ mean(price) │
│ double │
├─────────────┤
│ 15.0 │
└─────────────┘
```

docs/src/python/arrow.md (new file, 67 lines)

@@ -0,0 +1,67 @@
# Pandas and PyArrow
Built on top of [Apache Arrow](https://arrow.apache.org/),
`LanceDB` is easy to integrate with the Python ecosystem, including [Pandas](https://pandas.pydata.org/)
and PyArrow.
First, we need to connect to a `LanceDB` database.
```py
import lancedb
db = lancedb.connect("data/sample-lancedb")
```
Afterwards, we write a `Pandas DataFrame` to LanceDB directly.
```py
import pandas as pd
data = pd.DataFrame({
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0]
})
table = db.create_table("pd_table", data=data)
```
You will find detailed instructions on creating datasets and indexes in the
[Basic Operations](basic.md) and [Indexing](ann_indexes.md)
sections.
We can now perform similarity search via `LanceDB` Python API.
```py
# Open the table previously created.
table = db.open_table("pd_table")
query_vector = [100, 100]
# Pandas DataFrame
df = table.search(query_vector).limit(1).to_df()
print(df)
```
```
vector item price score
0 [5.9, 26.5] bar 20.0 14257.05957
```
If you have a simple filter, it's faster to provide a `where` clause to `LanceDB`'s search query.
If you have more complex criteria, you can always apply the filter to the resulting Pandas `DataFrame`.
```python
# Apply the filter via LanceDB
results = table.search([100, 100]).where("price < 15").to_df()
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
# Apply the filter via Pandas
df = table.search([100, 100]).to_df()
results = df[df.price < 15]
assert len(results) == 1
assert results["item"].iloc[0] == "foo"
```

docs/src/python/duckdb.md (new file, 56 lines)

@@ -0,0 +1,56 @@
# DuckDB
`LanceDB` works with `DuckDB` via [PyArrow integration](https://duckdb.org/docs/guides/python/sql_on_arrow).
Let us start by installing `duckdb` and `lancedb`.
```shell
pip install duckdb lancedb
```
We will re-use [the dataset created previously](./arrow.md):
```python
import pandas as pd
import lancedb
db = lancedb.connect("data/sample-lancedb")
data = pd.DataFrame({
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0]
})
table = db.create_table("pd_table", data=data)
arrow_table = table.to_arrow()
```
`DuckDB` can directly query the `arrow_table`:
```python
import duckdb
duckdb.query("SELECT * FROM arrow_table")
```
```
┌─────────────┬─────────┬────────┐
│ vector │ item │ price │
│ float[] │ varchar │ double │
├─────────────┼─────────┼────────┤
│ [3.1, 4.1] │ foo │ 10.0 │
│ [5.9, 26.5] │ bar │ 20.0 │
└─────────────┴─────────┴────────┘
```
```py
duckdb.query("SELECT mean(price) FROM arrow_table")
```
```
┌─────────────┐
│ mean(price) │
│ double │
├─────────────┤
│ 15.0 │
└─────────────┘
```


@@ -0,0 +1,7 @@
# Integration
Built on top of [Apache Arrow](https://arrow.apache.org/),
`LanceDB` is very easy to integrate with the Python ecosystem.
* [Pandas and Arrow Integration](./arrow.md)
* [DuckDB Integration](./duckdb.md)


@@ -43,3 +43,21 @@ pip install lancedb
::: lancedb.fts.populate_index
::: lancedb.fts.search_index
## Utilities
::: lancedb.schema.schema_to_dict
::: lancedb.schema.dict_to_schema
::: lancedb.vector
## Integrations
### Pydantic
::: lancedb.pydantic.pydantic_to_schema
::: lancedb.pydantic.vector


@@ -25,9 +25,9 @@ Currently, we support the following metrics:
### Flat Search
If LanceDB does not create a vector index, LanceDB would need to scan (`Flat Search`) the entire vector column
and compute the distance for each vector in order to find the closest matches.
If no [vector index](ann_indexes.md) is created, LanceDB will just brute-force scan
the vector column and compute the distance for each vector.
<!-- Setup Code
```python
@@ -79,39 +79,43 @@ await db_setup.createTable('my_vectors', data)
const tbl = await db.openTable("my_vectors")
const results_1 = await tbl.search(Array(1536).fill(1.2))
.limit(20)
.limit(10)
.execute()
```
<!-- Commenting out for now since metricType fails for JS on Ubuntu 22.04.
By default, `l2` will be used as `Metric` type. You can customize the metric type
as well.
-->
<!--
=== "Python"
-->
<!-- ```python
```python
df = tbl.search(np.random.random((1536))) \
.metric("cosine") \
.limit(10) \
.to_df()
```
-->
<!--
=== "JavaScript"
-->
<!-- ```javascript
=== "JavaScript"
```javascript
const results_2 = await tbl.search(Array(1536).fill(1.2))
.metricType("cosine")
.limit(20)
.limit(10)
.execute()
```
-->
### Search with Vector Index.
### Approximate Nearest Neighbor (ANN) Search with Vector Index.
To accelerate vector retrievals, it is common to build vector indices.
A vector index is a data structure specifically designed to efficiently organize and
search vector data based on their similarity or distance metrics.
By constructing a vector index, you can reduce the search space and avoid the need
for brute-force scanning of the entire vector column.
However, fast vector search using indices often entails making a trade-off with accuracy to some extent.
This is why it is often called **Approximate Nearest Neighbors (ANN)** search, while the Flat Search (KNN)
always returns 100% recall.
See [ANN Index](ann_indexes.md) for more details.

node/.npmignore (new file, 4 lines)

@@ -0,0 +1,4 @@
gen_test_data.py
index.node
dist/lancedb*.tgz
vectordb*.tgz


@@ -8,6 +8,10 @@ A JavaScript / Node.js library for [LanceDB](https://github.com/lancedb/lancedb)
npm install vectordb
```
This will download the appropriate native library for your platform. We currently
support x86_64 Linux, aarch64 Linux, Intel MacOS, and ARM (M1/M2) MacOS. We do not
yet support Windows or musl-based Linux (such as Alpine Linux).
## Usage
### Basic Example
@@ -26,12 +30,34 @@ The [examples](./examples) folder contains complete examples.
## Development
Run the tests with
To build everything fresh:
```bash
npm install
npm run tsc
npm run build
```
Then you should be able to run the tests with:
```bash
npm test
```
### Rebuilding Rust library
```bash
npm run build
```
### Rebuilding Typescript
```bash
npm run tsc
```
### Fix lints
To run the linter and have it automatically fix all errors
```bash


@@ -12,29 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
const { currentTarget } = require('@neon-rs/load');
let nativeLib;
function getPlatformLibrary() {
if (process.platform === "darwin" && process.arch == "arm64") {
return require('./aarch64-apple-darwin.node');
} else if (process.platform === "darwin" && process.arch == "x64") {
return require('./x86_64-apple-darwin.node');
} else if (process.platform === "linux" && process.arch == "x64") {
return require('./x86_64-unknown-linux-gnu.node');
} else {
throw new Error(`vectordb: unsupported platform ${process.platform}_${process.arch}. Please file a bug report at https://github.com/lancedb/lancedb/issues`)
}
}
try {
nativeLib = require('./index.node')
nativeLib = require(`vectordb-${currentTarget()}`);
} catch (e) {
if (e.code === "MODULE_NOT_FOUND") {
nativeLib = getPlatformLibrary();
} else {
throw new Error('vectordb: failed to load native library. Please file a bug report at https://github.com/lancedb/lancedb/issues');
}
try {
// Might be developing locally, so try that. But don't expose that error
// to the user.
nativeLib = require("./index.node");
} catch {
throw new Error(`vectordb: failed to load native library.
You may need to run \`npm install vectordb-${currentTarget()}\`.
If that does not work, please file a bug report at https://github.com/lancedb/lancedb/issues
Source error: ${e}`);
}
}
module.exports = nativeLib
// Dynamic require for runtime.
module.exports = nativeLib;

node/package-lock.json (generated, 45 lines changed)

@@ -1,18 +1,28 @@
{
"name": "vectordb",
"version": "0.1.9",
"version": "0.1.12",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.1.9",
"version": "0.1.12",
"cpu": [
"x64",
"arm64"
],
"license": "Apache-2.0",
"os": [
"darwin",
"linux"
],
"dependencies": {
"@apache-arrow/ts": "^12.0.0",
"@neon-rs/load": "^0.0.74",
"apache-arrow": "^12.0.0"
},
"devDependencies": {
"@neon-rs/cli": "^0.0.74",
"@types/chai": "^4.3.4",
"@types/chai-as-promised": "^7.1.5",
"@types/mocha": "^10.0.1",
@@ -37,6 +47,12 @@
"typedoc": "^0.24.7",
"typedoc-plugin-markdown": "^3.15.3",
"typescript": "*"
},
"optionalDependencies": {
"vectordb-darwin-arm64": "0.1.12",
"vectordb-darwin-x64": "0.1.12",
"vectordb-linux-arm64-gnu": "0.1.12",
"vectordb-linux-x64-gnu": "0.1.12"
}
},
"node_modules/@apache-arrow/ts": {
@@ -204,6 +220,20 @@
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"node_modules/@neon-rs/cli": {
"version": "0.0.74",
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.74.tgz",
"integrity": "sha512-9lPmNmjej5iKKOTMPryOMubwkgMRyTWRuaq1yokASvI5mPhr2kzPN7UVjdCOjQvpunNPngR9yAHoirpjiWhUHw==",
"dev": true,
"bin": {
"neon": "index.js"
}
},
"node_modules/@neon-rs/load": {
"version": "0.0.74",
"resolved": "https://registry.npmjs.org/@neon-rs/load/-/load-0.0.74.tgz",
"integrity": "sha512-/cPZD907UNz55yrc/ud4wDgQKtU1TvkD9jeqZWG6J4IMmZkp6zgjkQcKA8UvpkZlcpPHvc8J17sGzLFbP/LUYg=="
},
"node_modules/@nodelib/fs.scandir": {
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
@@ -4601,6 +4631,17 @@
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"@neon-rs/cli": {
"version": "0.0.74",
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.74.tgz",
"integrity": "sha512-9lPmNmjej5iKKOTMPryOMubwkgMRyTWRuaq1yokASvI5mPhr2kzPN7UVjdCOjQvpunNPngR9yAHoirpjiWhUHw==",
"dev": true
},
"@neon-rs/load": {
"version": "0.0.74",
"resolved": "https://registry.npmjs.org/@neon-rs/load/-/load-0.0.74.tgz",
"integrity": "sha512-/cPZD907UNz55yrc/ud4wDgQKtU1TvkD9jeqZWG6J4IMmZkp6zgjkQcKA8UvpkZlcpPHvc8J17sGzLFbP/LUYg=="
},
"@nodelib/fs.scandir": {
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",


@@ -1,16 +1,18 @@
{
"name": "vectordb",
"version": "0.1.10",
"version": "0.1.13",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"scripts": {
"tsc": "tsc -b",
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json-render-diagnostics",
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
"build-release": "npm run build -- --release",
"test": "npm run tsc; mocha -recursive dist/test",
"lint": "eslint src --ext .js,.ts",
"clean": "rm -rf node_modules *.node dist/"
"clean": "rm -rf node_modules *.node dist/",
"pack-build": "neon pack-build",
"check-npm": "printenv && which node && which npm && npm --version"
},
"repository": {
"type": "git",
@@ -25,6 +27,7 @@
"author": "Lance Devs",
"license": "Apache-2.0",
"devDependencies": {
"@neon-rs/cli": "^0.0.74",
"@types/chai": "^4.3.4",
"@types/chai-as-promised": "^7.1.5",
"@types/mocha": "^10.0.1",
@@ -52,6 +55,29 @@
},
"dependencies": {
"@apache-arrow/ts": "^12.0.0",
"@neon-rs/load": "^0.0.74",
"apache-arrow": "^12.0.0"
},
"os": [
"darwin",
"linux"
],
"cpu": [
"x64",
"arm64"
],
"neon": {
"targets": {
"x86_64-apple-darwin": "vectordb-darwin-x64",
"aarch64-apple-darwin": "vectordb-darwin-arm64",
"x86_64-unknown-linux-gnu": "vectordb-linux-x64-gnu",
"aarch64-unknown-linux-gnu": "vectordb-linux-arm64-gnu"
}
},
"optionalDependencies": {
"vectordb-darwin-arm64": "0.1.13",
"vectordb-darwin-x64": "0.1.13",
"vectordb-linux-x64-gnu": "0.1.13",
"vectordb-linux-arm64-gnu": "0.1.13"
}
}


@@ -27,13 +27,38 @@ const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, t
export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai'
export interface AwsCredentials {
accessKeyId: string
secretKey: string
sessionToken?: string
}
export interface ConnectionOptions {
uri: string
awsCredentials?: AwsCredentials
}
/**
* Connect to a LanceDB instance at the given URI
* @param uri The uri of the database.
*/
export async function connect (uri: string): Promise<Connection> {
const db = await databaseNew(uri)
return new LocalConnection(db, uri)
export async function connect (uri: string): Promise<Connection>
export async function connect (opts: Partial<ConnectionOptions>): Promise<Connection>
export async function connect (arg: string | Partial<ConnectionOptions>): Promise<Connection> {
let opts: ConnectionOptions
if (typeof arg === 'string') {
opts = { uri: arg }
} else {
// opts = { uri: arg.uri, awsCredentials = arg.awsCredentials }
opts = Object.assign({
uri: '',
awsCredentials: undefined
}, arg)
}
const db = await databaseNew(opts.uri)
return new LocalConnection(db, opts)
}
/**
@@ -126,16 +151,16 @@ export interface Table<T = number[]> {
* A connection to a LanceDB database.
*/
export class LocalConnection implements Connection {
private readonly _uri: string
private readonly _options: ConnectionOptions
private readonly _db: any
constructor (db: any, uri: string) {
this._uri = uri
constructor (db: any, options: ConnectionOptions) {
this._options = options
this._db = db
}
get uri (): string {
return this._uri
return this._options.uri
}
/**
@@ -158,12 +183,13 @@ export class LocalConnection implements Connection {
* @param embeddings An embedding function to use on this Table
*/
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await databaseOpenTable.call(this._db, name)
if (embeddings !== undefined) {
return new LocalTable(tbl, name, embeddings)
return new LocalTable(tbl, name, this._options, embeddings)
} else {
return new LocalTable(tbl, name)
return new LocalTable(tbl, name, this._options)
}
}
@@ -186,15 +212,27 @@ export class LocalConnection implements Connection {
* @param embeddings An embedding function to use on this Table
*/
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
if (mode === undefined) {
mode = WriteMode.Create
}
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase())
const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()]
if (this._options.awsCredentials !== undefined) {
createArgs.push(this._options.awsCredentials.accessKeyId)
createArgs.push(this._options.awsCredentials.secretKey)
if (this._options.awsCredentials.sessionToken !== undefined) {
createArgs.push(this._options.awsCredentials.sessionToken)
}
}
const tbl = await tableCreate.call(...createArgs)
if (embeddings !== undefined) {
return new LocalTable(tbl, name, embeddings)
return new LocalTable(tbl, name, this._options, embeddings)
} else {
return new LocalTable(tbl, name)
return new LocalTable(tbl, name, this._options)
}
}
@@ -217,18 +255,21 @@ export class LocalTable<T = number[]> implements Table<T> {
private readonly _tbl: any
private readonly _name: string
private readonly _embeddings?: EmbeddingFunction<T>
private readonly _options: ConnectionOptions
constructor (tbl: any, name: string)
constructor (tbl: any, name: string, options: ConnectionOptions)
/**
* @param tbl
* @param name
* @param options
* @param embeddings An embedding function to use when interacting with this table
*/
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
constructor (tbl: any, name: string, options: ConnectionOptions, embeddings: EmbeddingFunction<T>)
constructor (tbl: any, name: string, options: ConnectionOptions, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
this._name = name
this._embeddings = embeddings
this._options = options
}
get name (): string {
@@ -250,7 +291,15 @@ export class LocalTable<T = number[]> implements Table<T> {
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
if (this._options.awsCredentials !== undefined) {
callArgs.push(this._options.awsCredentials.accessKeyId)
callArgs.push(this._options.awsCredentials.secretKey)
if (this._options.awsCredentials.sessionToken !== undefined) {
callArgs.push(this._options.awsCredentials.sessionToken)
}
}
return tableAdd.call(...callArgs)
}
/**
@@ -260,6 +309,14 @@ export class LocalTable<T = number[]> implements Table<T> {
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
if (this._options.awsCredentials !== undefined) {
callArgs.push(this._options.awsCredentials.accessKeyId)
callArgs.push(this._options.awsCredentials.secretKey)
if (this._options.awsCredentials.sessionToken !== undefined) {
callArgs.push(this._options.awsCredentials.sessionToken)
}
}
return tableAdd.call(...callArgs)
}


@@ -18,26 +18,48 @@ import { describe } from 'mocha'
import { assert } from 'chai'
import * as lancedb from '../index'
import { type ConnectionOptions } from '../index'
describe('LanceDB S3 client', function () {
if (process.env.TEST_S3_BASE_URL != null) {
const baseUri = process.env.TEST_S3_BASE_URL
it('should have a valid url', async function () {
const uri = `${baseUri}/valid_url`
const table = await createTestDB(uri, 2, 20)
const con = await lancedb.connect(uri)
assert.equal(con.uri, uri)
const opts = { uri: `${baseUri}/valid_url` }
const table = await createTestDB(opts, 2, 20)
const con = await lancedb.connect(opts)
assert.equal(con.uri, opts.uri)
const results = await table.search([0.1, 0.3]).limit(5).execute()
assert.equal(results.length, 5)
})
}).timeout(10_000)
} else {
describe.skip('Skip S3 test', function () {})
}
if (process.env.TEST_S3_BASE_URL != null && process.env.TEST_AWS_ACCESS_KEY_ID != null && process.env.TEST_AWS_SECRET_ACCESS_KEY != null) {
const baseUri = process.env.TEST_S3_BASE_URL
it('use custom credentials', async function () {
const opts: ConnectionOptions = {
uri: `${baseUri}/custom_credentials`,
awsCredentials: {
accessKeyId: process.env.TEST_AWS_ACCESS_KEY_ID as string,
secretKey: process.env.TEST_AWS_SECRET_ACCESS_KEY as string
}
}
const table = await createTestDB(opts, 2, 20)
const con = await lancedb.connect(opts)
assert.equal(con.uri, opts.uri)
const results = await table.search([0.1, 0.3]).limit(5).execute()
assert.equal(results.length, 5)
}).timeout(10_000)
} else {
describe.skip('Skip S3 test', function () {})
}
})
async function createTestDB (uri: string, numDimensions: number = 2, numRows: number = 2): Promise<lancedb.Table> {
const con = await lancedb.connect(uri)
async function createTestDB (opts: ConnectionOptions, numDimensions: number = 2, numRows: number = 2): Promise<lancedb.Table> {
const con = await lancedb.connect(opts)
const data = []
for (let i = 0; i < numRows; i++) {


@@ -18,7 +18,7 @@ import * as chai from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import * as lancedb from '../index'
import { type EmbeddingFunction, MetricType, Query, WriteMode } from '../index'
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode } from '../index'
const expect = chai.expect
const assert = chai.assert
@@ -32,6 +32,22 @@ describe('LanceDB client', function () {
assert.equal(con.uri, uri)
})
it('should accept an options object', async function () {
const uri = await createTestDB()
const con = await lancedb.connect({ uri })
assert.equal(con.uri, uri)
})
it('should accept custom aws credentials', async function () {
const uri = await createTestDB()
const awsCredentials: AwsCredentials = {
accessKeyId: '',
secretKey: ''
}
const con = await lancedb.connect({ uri, awsCredentials })
assert.equal(con.uri, uri)
})
it('should return the existing table names', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)


@@ -15,6 +15,7 @@ from typing import Optional
from .db import URI, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection
from .schema import vector
def connect(

python/lancedb/pydantic.py (new file, 169 lines)

@@ -0,0 +1,169 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pydantic adapter for LanceDB"""
from __future__ import annotations
import inspect
import sys
import types
from abc import ABC, abstractstaticmethod
from typing import Any, List, Type, Union, _GenericAlias
import pyarrow as pa
import pydantic
from pydantic_core import CoreSchema, core_schema
class FixedSizeListMixin(ABC):
@abstractstaticmethod
def dim() -> int:
raise NotImplementedError
@abstractstaticmethod
def value_arrow_type() -> pa.DataType:
raise NotImplementedError
def vector(
dim: int, value_type: pa.DataType = pa.float32()
) -> Type[FixedSizeListMixin]:
"""Pydantic Vector Type.
Note
----
Experimental feature.
Examples
--------
>>> import pydantic
>>> from lancedb.pydantic import vector
...
>>> class MyModel(pydantic.BaseModel):
... vector: vector(756)
... id: int
... description: str
"""
# TODO: make a public parameterized type.
class FixedSizeList(list, FixedSizeListMixin):
@staticmethod
def dim() -> int:
return dim
@staticmethod
def value_arrow_type() -> pa.DataType:
return value_type
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.list_schema(
min_length=dim,
max_length=dim,
items_schema=core_schema.float_schema(),
),
)
return FixedSizeList
def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
"""Convert Python Type to Arrow DataType.
Raises
------
TypeError
If the type is not supported.
"""
if py_type == int:
return pa.int64()
elif py_type == float:
return pa.float64()
elif py_type == str:
return pa.utf8()
elif py_type == bool:
return pa.bool_()
elif py_type == bytes:
return pa.binary()
raise TypeError(
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
)
def _pydantic_model_to_fields(model: pydantic.BaseModel) -> List[pa.Field]:
fields = []
for name, field in model.model_fields.items():
fields.append(_pydantic_to_field(name, field))
return fields
def _pydantic_to_arrow_type(field: pydantic.fields.FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType"""
if isinstance(field.annotation, _GenericAlias) or (
sys.version_info > (3, 9) and isinstance(field.annotation, types.GenericAlias)
):
origin = field.annotation.__origin__
args = field.annotation.__args__
if origin == list:
child = args[0]
return pa.list_(_py_type_to_arrow_type(child))
elif origin == Union:
if len(args) == 2 and args[1] == type(None):
return _py_type_to_arrow_type(args[0])
elif inspect.isclass(field.annotation):
if issubclass(field.annotation, pydantic.BaseModel):
# Struct
fields = _pydantic_model_to_fields(field.annotation)
return pa.struct(fields)
elif issubclass(field.annotation, FixedSizeListMixin):
return pa.list_(field.annotation.value_arrow_type(), field.annotation.dim())
return _py_type_to_arrow_type(field.annotation)
def is_nullable(field: pydantic.fields.FieldInfo) -> bool:
"""Check if a Pydantic FieldInfo is nullable."""
if isinstance(field.annotation, _GenericAlias):
origin = field.annotation.__origin__
args = field.annotation.__args__
if origin == Union:
if len(args) == 2 and args[1] == type(None):
return True
return False
def _pydantic_to_field(name: str, field: pydantic.fields.FieldInfo) -> pa.Field:
"""Convert a Pydantic field to a PyArrow Field."""
dt = _pydantic_to_arrow_type(field)
return pa.field(name, dt, is_nullable(field))
def pydantic_to_schema(model: Type[pydantic.BaseModel]) -> pa.Schema:
"""Convert a Pydantic model to a PyArrow Schema.
Parameters
----------
model : Type[pydantic.BaseModel]
The Pydantic BaseModel to convert to Arrow Schema.
Returns
-------
A PyArrow Schema.
"""
fields = _pydantic_model_to_fields(model)
return pa.schema(fields)
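For reference, a minimal usage sketch of the converter defined above (the `Document` model is hypothetical; `pydantic_to_schema` and `vector` are the functions from this file):
```python
import pydantic
from lancedb.pydantic import pydantic_to_schema, vector

class Document(pydantic.BaseModel):
    id: int
    text: str
    embedding: vector(1536)  # fixed-size float32 list, validated to 1536 items

schema = pydantic_to_schema(Document)
# -> id: int64 not null, text: string not null,
#    embedding: fixed_size_list<item: float>[1536] not null
print(schema)
```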


@@ -13,11 +13,12 @@
import functools
from typing import Dict
from typing import Any, Callable, Dict, Union
import aiohttp
import attr
import pyarrow as pa
from pydantic import BaseModel
from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult
@@ -34,6 +35,12 @@ def _check_not_closed(f):
return wrapped
async def _read_ipc(resp: aiohttp.ClientResponse) -> pa.Table:
resp_body = await resp.read()
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
return reader.read_all()
@attr.define(slots=False)
class RestfulLanceDBClient:
db_name: str
@@ -56,28 +63,67 @@ class RestfulLanceDBClient:
"x-api-key": self.api_key,
}
@staticmethod
async def _check_status(resp: aiohttp.ClientResponse):
if resp.status == 404:
raise LanceDBClientError(f"Not found: {await resp.text()}")
elif 400 <= resp.status < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status}, error: {await resp.text()}"
)
elif 500 <= resp.status < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
)
elif resp.status != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status}, error: {await resp.text()}"
)
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
async def get(self, uri: str, params: Union[Dict[str, Any], BaseModel] = None):
"""Send a GET request and returns the deserialized response payload."""
if isinstance(params, BaseModel):
params: Dict[str, Any] = params.dict(exclude_none=True)
async with self.session.get(uri, params=params, headers=self.headers) as resp:
await self._check_status(resp)
return await resp.json()
@_check_not_closed
async def post(
self,
uri: str,
data: Union[Dict[str, Any], BaseModel],
deserialize: Callable = lambda resp: resp.json(),
) -> Dict[str, Any]:
"""Send a POST request and returns the deserialized response payload.
Parameters
----------
uri : str
The uri to send the POST request to.
data: Union[Dict[str, Any], BaseModel]
"""
if isinstance(data, BaseModel):
data: Dict[str, Any] = data.dict(exclude_none=True)
async with self.session.post(
f"/1/table/{table_name}/",
json=query.dict(exclude_none=True),
uri,
json=data,
headers=self.headers,
) as resp:
resp: aiohttp.ClientResponse = resp
if 400 <= resp.status < 500:
raise LanceDBClientError(
f"Bad Request: {resp.status}, error: {await resp.text()}"
)
if 500 <= resp.status < 600:
raise LanceDBClientError(
f"Internal Server Error: {resp.status}, error: {await resp.text()}"
)
if resp.status != 200:
raise LanceDBClientError(
f"Unknown Error: {resp.status}, error: {await resp.text()}"
)
await self._check_status(resp)
return await deserialize(resp)
resp_body = await resp.read()
with pa.ipc.open_file(pa.BufferReader(resp_body)) as reader:
tbl = reader.read_all()
@_check_not_closed
async def list_tables(self):
"""List all tables in the database."""
json = await self.get("/1/table/", {})
return json["tables"]
@_check_not_closed
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
"""Query a table."""
tbl = await self.post(f"/1/table/{table_name}/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl)


@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import List
from urllib.parse import urlparse
@@ -34,12 +35,18 @@ class RemoteDBConnection(DBConnection):
self.db_name = parsed.netloc
self.api_key = api_key
self._client = RestfulLanceDBClient(self.db_name, region, api_key)
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.get_event_loop()
def __repr__(self) -> str:
return f"RemoteDBConnection(name={self.db_name})"
def table_names(self) -> List[str]:
raise NotImplementedError
"""List the names of all tables in the database."""
result = self._loop.run_until_complete(self._client.list_tables())
return result
def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.


@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Union
import pyarrow as pa
@@ -62,9 +61,5 @@ class RemoteTable(Table):
return LanceQueryBuilder(self, query, vector_column)
def _execute_query(self, query: Query) -> pa.Table:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()
result = self._conn._client.query(self._name, query)
return loop.run_until_complete(result).to_arrow()
return self._conn._loop.run_until_complete(result).to_arrow()
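Together with the `lancedb/__init__.py` and `remote/db.py` changes above, the remote path might be exercised like this (a sketch; the `db://` URI, `api_key`, and `region` values are placeholders, and the keyword names are assumed from `RestfulLanceDBClient(self.db_name, region, api_key)`):
```python
import lancedb

# Hypothetical remote database; requires a LanceDB Cloud endpoint.
db = lancedb.connect("db://my-database", api_key="sk-...", region="us-east-1")

print(db.table_names())  # now served by RestfulLanceDBClient.list_tables()

tbl = db.open_table("my_table")
# _execute_query now reuses the connection's shared event loop.
df = tbl.search([0.1, 0.3]).limit(5).to_df()
```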

python/lancedb/schema.py (new file, 289 lines)

@@ -0,0 +1,289 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Schema related utilities."""
import json
from typing import Any, Dict, Type
import pyarrow as pa
def vector(dimension: int, value_type: pa.DataType = pa.float32()) -> pa.DataType:
"""A help function to create a vector type.
Parameters
----------
dimension: The dimension of the vector.
value_type: pa.DataType, optional
The type of the value in the vector.
Returns
-------
A PyArrow DataType for vectors.
Examples
--------
>>> import pyarrow as pa
>>> import lancedb
>>> schema = pa.schema([
... pa.field("id", pa.int64()),
... pa.field("vector", lancedb.vector(756)),
... ])
"""
return pa.list_(value_type, dimension)
def _type_to_dict(dt: pa.DataType) -> Dict[str, Any]:
if pa.types.is_boolean(dt):
return {"type": "boolean"}
elif pa.types.is_int8(dt):
return {"type": "int8"}
elif pa.types.is_int16(dt):
return {"type": "int16"}
elif pa.types.is_int32(dt):
return {"type": "int32"}
elif pa.types.is_int64(dt):
return {"type": "int64"}
elif pa.types.is_uint8(dt):
return {"type": "uint8"}
elif pa.types.is_uint16(dt):
return {"type": "uint16"}
elif pa.types.is_uint32(dt):
return {"type": "uint32"}
elif pa.types.is_uint64(dt):
return {"type": "uint64"}
elif pa.types.is_float16(dt):
return {"type": "float16"}
elif pa.types.is_float32(dt):
return {"type": "float32"}
elif pa.types.is_float64(dt):
return {"type": "float64"}
elif pa.types.is_date32(dt):
return {"type": f"date32"}
elif pa.types.is_date64(dt):
return {"type": f"date64"}
elif pa.types.is_time32(dt):
return {"type": f"time32:{dt.unit}"}
elif pa.types.is_time64(dt):
return {"type": f"time64:{dt.unit}"}
elif pa.types.is_timestamp(dt):
return {"type": f"timestamp:{dt.unit}:{dt.tz if dt.tz is not None else ''}"}
elif pa.types.is_string(dt):
return {"type": "string"}
elif pa.types.is_binary(dt):
return {"type": "binary"}
elif pa.types.is_large_string(dt):
return {"type": "large_string"}
elif pa.types.is_large_binary(dt):
return {"type": "large_binary"}
elif pa.types.is_fixed_size_binary(dt):
return {"type": "fixed_size_binary", "width": dt.byte_width}
elif pa.types.is_fixed_size_list(dt):
return {
"type": "fixed_size_list",
"width": dt.list_size,
"value_type": _type_to_dict(dt.value_type),
}
elif pa.types.is_list(dt):
return {
"type": "list",
"value_type": _type_to_dict(dt.value_type),
}
elif pa.types.is_struct(dt):
return {
"type": "struct",
"fields": [_field_to_dict(dt.field(i)) for i in range(dt.num_fields)],
}
elif pa.types.is_dictionary(dt):
return {
"type": "dictionary",
"index_type": _type_to_dict(dt.index_type),
"value_type": _type_to_dict(dt.value_type),
}
# TODO: support extension types
raise TypeError(f"Unsupported type: {dt}")
def _field_to_dict(field: pa.field) -> Dict[str, Any]:
ret = {
"name": field.name,
"type": _type_to_dict(field.type),
"nullable": field.nullable,
}
if field.metadata is not None:
ret["metadata"] = field.metadata
return ret
def schema_to_dict(schema: pa.Schema) -> Dict[str, Any]:
"""Convert a PyArrow [Schema](pyarrow.Schema) to a dictionary.
Parameters
----------
schema : pa.Schema
The PyArrow Schema to convert
Returns
-------
A dict of the data type.
Examples
--------
>>> import pyarrow as pa
>>> import lancedb
>>> schema = pa.schema(
... [
... pa.field("id", pa.int64()),
... pa.field("vector", lancedb.vector(512), nullable=False),
... pa.field(
... "struct",
... pa.struct(
... [
... pa.field("a", pa.utf8()),
... pa.field("b", pa.float32()),
... ]
... ),
... True,
... ),
... ],
... metadata={"key": "value"},
... )
>>> json_schema = schema_to_dict(schema)
>>> assert json_schema == {
... "fields": [
... {"name": "id", "type": {"type": "int64"}, "nullable": True},
... {
... "name": "vector",
... "type": {
... "type": "fixed_size_list",
... "value_type": {"type": "float32"},
... "width": 512,
... },
... "nullable": False,
... },
... {
... "name": "struct",
... "type": {
... "type": "struct",
... "fields": [
... {"name": "a", "type": {"type": "string"}, "nullable": True},
... {"name": "b", "type": {"type": "float32"}, "nullable": True},
... ],
... },
... "nullable": True,
... },
... ],
... "metadata": {"key": "value"},
... }
"""
fields = []
for name in schema.names:
field = schema.field(name)
fields.append(_field_to_dict(field))
json_schema = {
"fields": fields,
"metadata": {
k.decode("utf-8"): v.decode("utf-8") for (k, v) in schema.metadata.items()
}
if schema.metadata is not None
else {},
}
return json_schema
def _dict_to_type(dt: Dict[str, Any]) -> pa.DataType:
type_name = dt["type"]
try:
return {
"boolean": pa.bool_(),
"int8": pa.int8(),
"int16": pa.int16(),
"int32": pa.int32(),
"int64": pa.int64(),
"uint8": pa.uint8(),
"uint16": pa.uint16(),
"uint32": pa.uint32(),
"uint64": pa.uint64(),
"float16": pa.float16(),
"float32": pa.float32(),
"float64": pa.float64(),
"string": pa.string(),
"binary": pa.binary(),
"large_string": pa.large_string(),
"large_binary": pa.large_binary(),
"date32": pa.date32(),
"date64": pa.date64(),
}[type_name]
except KeyError:
pass
if type_name == "fixed_size_binary":
return pa.binary(dt["width"])
elif type_name == "fixed_size_list":
return pa.list_(_dict_to_type(dt["value_type"]), dt["width"])
elif type_name == "list":
return pa.list_(_dict_to_type(dt["value_type"]))
elif type_name == "struct":
fields = []
for field in dt["fields"]:
fields.append(_dict_to_field(field))
return pa.struct(fields)
elif type_name == "dictionary":
return pa.dictionary(
_dict_to_type(dt["index_type"]), _dict_to_type(dt["value_type"])
)
elif type_name.startswith("time32:"):
return pa.time32(type_name.split(":")[1])
elif type_name.startswith("time64:"):
return pa.time64(type_name.split(":")[1])
elif type_name.startswith("timestamp:"):
fields = type_name.split(":")
unit = fields[1]
tz = fields[2] if len(fields) > 2 else None
return pa.timestamp(unit, tz)
raise TypeError(f"Unsupported type: {dt}")
def _dict_to_field(field: Dict[str, Any]) -> pa.Field:
name = field["name"]
nullable = field["nullable"] if "nullable" in field else True
dt = _dict_to_type(field["type"])
metadata = field.get("metadata", None)
return pa.field(name, dt, nullable, metadata)
def dict_to_schema(json: Dict[str, Any]) -> pa.Schema:
"""Reconstruct a PyArrow Schema from a JSON dict.
Parameters
----------
json : Dict[str, Any]
The JSON dict to reconstruct Schema from.
Returns
-------
A PyArrow Schema.
"""
fields = []
for field in json["fields"]:
fields.append(_dict_to_field(field))
metadata = {
k.encode("utf-8"): v.encode("utf-8")
for (k, v) in json.get("metadata", {}).items()
}
return pa.schema(fields, metadata)
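Since the point of these helpers is JSON round-tripping (the "[Python] Schema to JSON" commit above), a short sketch:
```python
import json

import pyarrow as pa
from lancedb.schema import dict_to_schema, schema_to_dict, vector

schema = pa.schema([
    pa.field("id", pa.int64()),
    pa.field("vector", vector(512), nullable=False),
])

payload = json.dumps(schema_to_dict(schema))    # serialize for transport/storage
restored = dict_to_schema(json.loads(payload))  # rebuild the Schema on the other side
assert restored.equals(schema)
```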


@@ -1,7 +1,7 @@
[project]
name = "lancedb"
version = "0.1.10"
dependencies = ["pylance~=0.5.0", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic", "attr"]
dependencies = ["pylance~=0.5.0", "ratelimiter", "retry", "tqdm", "aiohttp", "pydantic>=2", "attr"]
description = "lancedb"
authors = [
{ name = "LanceDB Devs", email = "dev@lancedb.com" },


@@ -0,0 +1,155 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from typing import List, Optional
import pyarrow as pa
import pydantic
import pytest
from lancedb.pydantic import pydantic_to_schema, vector
@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="using native type alias requires python3.9 or higher",
)
def test_pydantic_to_arrow():
class StructModel(pydantic.BaseModel):
a: str
b: Optional[float]
class TestModel(pydantic.BaseModel):
id: int
s: str
vec: list[float]
li: List[int]
opt: Optional[str] = None
st: StructModel
# d: dict
m = TestModel(
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
)
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.int64(), False),
pa.field("s", pa.utf8(), False),
pa.field("vec", pa.list_(pa.float64()), False),
pa.field("li", pa.list_(pa.int64()), False),
pa.field("opt", pa.utf8(), True),
pa.field(
"st",
pa.struct(
[pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
),
False,
),
]
)
assert schema == expect_schema
def test_pydantic_to_arrow_py38():
class StructModel(pydantic.BaseModel):
a: str
b: Optional[float]
class TestModel(pydantic.BaseModel):
id: int
s: str
vec: List[float]
li: List[int]
opt: Optional[str] = None
st: StructModel
# d: dict
m = TestModel(
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
)
schema = pydantic_to_schema(TestModel)
expect_schema = pa.schema(
[
pa.field("id", pa.int64(), False),
pa.field("s", pa.utf8(), False),
pa.field("vec", pa.list_(pa.float64()), False),
pa.field("li", pa.list_(pa.int64()), False),
pa.field("opt", pa.utf8(), True),
pa.field(
"st",
pa.struct(
[pa.field("a", pa.utf8(), False), pa.field("b", pa.float64(), True)]
),
False,
),
]
)
assert schema == expect_schema
def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel):
vec: vector(16)
li: List[int]
data = TestModel(vec=list(range(16)), li=[1, 2, 3])
assert json.loads(data.model_dump_json()) == {
"vec": list(range(16)),
"li": [1, 2, 3],
}
schema = pydantic_to_schema(TestModel)
assert schema == pa.schema(
[
pa.field("vec", pa.list_(pa.float32(), 16), False),
pa.field("li", pa.list_(pa.int64()), False),
]
)
json_schema = TestModel.model_json_schema()
assert json_schema == {
"properties": {
"vec": {
"items": {"type": "number"},
"maxItems": 16,
"minItems": 16,
"title": "Vec",
"type": "array",
},
"li": {"items": {"type": "integer"}, "title": "Li", "type": "array"},
},
"required": ["vec", "li"],
"title": "TestModel",
"type": "object",
}
def test_fixed_size_list_validation():
class TestModel(pydantic.BaseModel):
vec: vector(8)
with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(9))
with pytest.raises(pydantic.ValidationError):
TestModel(vec=range(7))
TestModel(vec=range(8))
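A usage sketch tying the pieces together, with a hypothetical model; pa.Table.from_pylist is standard PyArrow. The pydantic model both validates rows and supplies their Arrow schema:

import pyarrow as pa
import pydantic
from lancedb.pydantic import pydantic_to_schema, vector

class Document(pydantic.BaseModel):
    text: str
    vec: vector(2)

rows = [Document(text="hello", vec=[0.1, 0.2]).model_dump()]
table = pa.Table.from_pylist(rows, schema=pydantic_to_schema(Document))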

python/tests/test_schema.py Normal file
View File

@@ -0,0 +1,109 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow as pa
import lancedb
from lancedb.schema import dict_to_schema, schema_to_dict
def test_schema_to_dict():
schema = pa.schema(
[
pa.field("id", pa.int64()),
pa.field("vector", lancedb.vector(512), nullable=False),
pa.field(
"struct",
pa.struct(
[
pa.field("a", pa.utf8()),
pa.field("b", pa.float32()),
]
),
True,
),
pa.field("d", pa.dictionary(pa.int64(), pa.utf8()), False),
],
metadata={"key": "value"},
)
json_schema = schema_to_dict(schema)
assert json_schema == {
"fields": [
{"name": "id", "type": {"type": "int64"}, "nullable": True},
{
"name": "vector",
"type": {
"type": "fixed_size_list",
"value_type": {"type": "float32"},
"width": 512,
},
"nullable": False,
},
{
"name": "struct",
"type": {
"type": "struct",
"fields": [
{"name": "a", "type": {"type": "string"}, "nullable": True},
{"name": "b", "type": {"type": "float32"}, "nullable": True},
],
},
"nullable": True,
},
{
"name": "d",
"type": {
"type": "dictionary",
"index_type": {"type": "int64"},
"value_type": {"type": "string"},
},
"nullable": False,
},
],
"metadata": {"key": "value"},
}
actual_schema = dict_to_schema(json_schema)
assert actual_schema == schema
def test_temporal_types():
schema = pa.schema(
[
pa.field("t32", pa.time32("s")),
pa.field("t32ms", pa.time32("ms")),
pa.field("t64", pa.time64("ns")),
pa.field("ts", pa.timestamp("s")),
pa.field("ts_us_tz", pa.timestamp("us", tz="America/New_York")),
],
)
json_schema = schema_to_dict(schema)
assert json_schema == {
"fields": [
{"name": "t32", "type": {"type": "time32:s"}, "nullable": True},
{"name": "t32ms", "type": {"type": "time32:ms"}, "nullable": True},
{"name": "t64", "type": {"type": "time64:ns"}, "nullable": True},
{"name": "ts", "type": {"type": "timestamp:s:"}, "nullable": True},
{
"name": "ts_us_tz",
"type": {"type": "timestamp:us:America/New_York"},
"nullable": True,
},
],
"metadata": {},
}
actual_schema = dict_to_schema(json_schema)
assert actual_schema == schema
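The temporal encoding packs the unit, and for timestamps an optional timezone, into the type name itself, which is why a timezone-less timestamp serializes as "timestamp:s:" with a trailing colon. A small sketch mirroring the split(":") decoding above:

import pyarrow as pa

parts = "timestamp:us:America/New_York".split(":")
unit = parts[1]
tz = parts[2] if len(parts) > 2 else None
assert pa.timestamp(unit, tz) == pa.timestamp("us", tz="America/New_York")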

View File

@@ -1,6 +1,6 @@
[package]
name = "vectordb-node"
version = "0.1.10"
version = "0.1.13"
description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0"
edition = "2018"
@@ -19,3 +19,6 @@ lance = { workspace = true }
vectordb = { path = "../../vectordb" }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
object_store = { workspace = true, features = ["aws"] }
async-trait = "0"
env_logger = "0"

View File

@@ -13,7 +13,6 @@
// limitations under the License.
use std::io::Cursor;
use std::ops::Deref;
use std::sync::Arc;
use arrow_array::cast::as_list_array;
@@ -25,10 +24,13 @@ use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
let column = record_batch
.column_by_name("vector")
.cloned()
.expect("vector column is missing");
let arr = as_list_array(column.deref());
// TODO: we should just consume the underlying js buffer in the future instead of copying the array around a bunch of times
let arr = as_list_array(column.as_ref());
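// Assumes every row stores a vector of identical length, so the fixed width is total values / row count.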
let list_size = arr.values().len() / record_batch.num_rows();
let r = FixedSizeListArray::try_new(arr.values(), list_size as i32).unwrap();
let r =
FixedSizeListArray::try_new_from_values(arr.values().to_owned(), list_size as i32).unwrap();
let schema = Arc::new(Schema::new(vec![Field::new(
"vector",

View File

@@ -17,19 +17,23 @@ use std::convert::TryFrom;
use std::ops::Deref;
use std::sync::{Arc, Mutex};
use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader};
use arrow_array::{Float32Array, RecordBatchIterator};
use arrow_ipc::writer::FileWriter;
use async_trait::async_trait;
use futures::{TryFutureExt, TryStreamExt};
use lance::dataset::{WriteMode, WriteParams};
use lance::dataset::{ReadParams, WriteMode, WriteParams};
use lance::index::vector::MetricType;
use lance::io::object_store::ObjectStoreParams;
use neon::prelude::*;
use neon::types::buffer::TypedArray;
use object_store::aws::{AwsCredential, AwsCredentialProvider};
use object_store::CredentialProvider;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;
use vectordb::database::Database;
use vectordb::error::Error;
use vectordb::table::Table;
use vectordb::table::{OpenTableParams, Table};
use crate::arrow::arrow_buffer_to_record_batch;
@@ -49,8 +53,38 @@ struct JsTable {
impl Finalize for JsTable {}
// TODO: object_store didn't export this type so I copied it.
// Make a request to object_store to export this type
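// Serves a fixed credential: get_credential just clones the stored Arc on every call.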
#[derive(Debug)]
pub struct StaticCredentialProvider<T> {
credential: Arc<T>,
}
impl<T> StaticCredentialProvider<T> {
pub fn new(credential: T) -> Self {
Self {
credential: Arc::new(credential),
}
}
}
#[async_trait]
impl<T> CredentialProvider for StaticCredentialProvider<T>
where
T: std::fmt::Debug + Send + Sync,
{
type Credential = T;
async fn get_credential(&self) -> object_store::Result<Arc<T>> {
Ok(Arc::clone(&self.credential))
}
}
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
static LOG: OnceCell<()> = OnceCell::new();
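// Initialize the logger exactly once; env_logger reads verbosity from the RUST_LOG environment variable.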
LOG.get_or_init(|| env_logger::init());
RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string())))
}
@@ -97,19 +131,74 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
Ok(promise)
}
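// Reads up to three optional JS string arguments (access key id, secret access key, session token)
// starting at `arg_starting_location`. The key id and secret must be supplied together;
// a token without both of them is rejected as an invalid configuration.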
fn get_aws_creds<T>(
cx: &mut FunctionContext,
arg_starting_location: i32,
) -> Result<Option<AwsCredentialProvider>, NeonResult<T>> {
let secret_key_id = cx
.argument_opt(arg_starting_location)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.map(|v| v.value(cx));
let secret_key = cx
.argument_opt(arg_starting_location + 1)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.map(|v| v.value(cx));
let temp_token = cx
.argument_opt(arg_starting_location + 2)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.map(|v| v.value(cx));
match (secret_key_id, secret_key, temp_token) {
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
StaticCredentialProvider::new(AwsCredential {
key_id: key_id,
secret_key: key,
token: optional_token,
}),
))),
(None, None, None) => Ok(None),
_ => Err(cx.throw_error("Invalid credentials configuration")),
}
}
fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let db = cx
.this()
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = match get_aws_creds(&mut cx, 1) {
Ok(creds) => creds,
Err(err) => return err,
};
let param = ReadParams {
store_options: Some(ObjectStoreParams {
aws_credentials: aws_creds,
..ObjectStoreParams::default()
}),
..ReadParams::default()
};
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let database = db.database.clone();
let (deferred, promise) = cx.promise();
rt.spawn(async move {
let table_rst = database.open_table(&table_name).await;
let table_rst = database
.open_table_with_params(
&table_name,
OpenTableParams {
open_table_params: param,
},
)
.await;
deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new(
@@ -241,8 +330,6 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
"create" => WriteMode::Create,
_ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes"),
};
let mut params = WriteParams::default();
params.mode = mode;
let rt = runtime(&mut cx)?;
let channel = cx.channel();
@@ -250,11 +337,22 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
let (deferred, promise) = cx.promise();
let database = db.database.clone();
let aws_creds = match get_aws_creds(&mut cx, 3) {
Ok(creds) => creds,
Err(err) => return err,
};
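// Thread the optional AWS credentials into the write path via ObjectStoreParams.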
let params = WriteParams {
store_params: Some(ObjectStoreParams {
aws_credentials: aws_creds,
..ObjectStoreParams::default()
}),
mode: mode,
..WriteParams::default()
};
rt.block_on(async move {
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
batches.into_iter().map(Ok),
schema,
));
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let table_rst = database
.create_table(&table_name, batch_reader, Some(params))
.await;
@@ -289,16 +387,27 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
let table = js_table.table.clone();
let write_mode = write_mode_map.get(write_mode.as_str()).cloned();
let aws_creds = match get_aws_creds(&mut cx, 2) {
Ok(creds) => creds,
Err(err) => return err,
};
let params = WriteParams {
store_params: Some(ObjectStoreParams {
aws_credentials: aws_creds,
..ObjectStoreParams::default()
}),
mode: write_mode.unwrap_or(WriteMode::Append),
..WriteParams::default()
};
rt.block_on(async move {
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
batches.into_iter().map(Ok),
schema,
));
let add_result = table.lock().unwrap().add(batch_reader, write_mode).await;
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let add_result = table.lock().unwrap().add(batch_reader, Some(params)).await;
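// Table::add now returns Result<()>, so success is surfaced to JS as a plain boolean.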
deferred.settle_with(&channel, move |mut cx| {
let added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
Ok(cx.number(added as f64))
let _added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
Ok(cx.boolean(true))
});
});
Ok(promise)

View File

@@ -1,6 +1,6 @@
[package]
name = "vectordb"
version = "0.1.10"
version = "0.1.13"
edition = "2021"
description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0"

View File

@@ -100,7 +100,7 @@ impl Database {
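// `batches` is now taken by value as `impl RecordBatchReader`, matching the upgraded
// Dataset::write signature and avoiding a boxed trait object.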
pub async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader>,
batches: impl RecordBatchReader + Send + 'static,
params: Option<WriteParams>,
) -> Result<Table> {
Table::create(&self.uri, name, batches, params).await

View File

@@ -173,10 +173,8 @@ mod tests {
#[tokio::test]
async fn test_setters_getters() {
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
let ds = Dataset::write(&mut batches, "memory://foo", None)
.await
.unwrap();
let batches = make_test_batches();
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = Query::new(Arc::new(ds), vector.clone());
@@ -202,10 +200,8 @@ mod tests {
#[tokio::test]
async fn test_execute() {
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
let ds = Dataset::write(&mut batches, "memory://foo", None)
.await
.unwrap();
let batches = make_test_batches();
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
let vector = Float32Array::from_iter_values([0.1; 128]);
let query = Query::new(Arc::new(ds), vector.clone());
@@ -213,7 +209,7 @@ mod tests {
assert_eq!(result.is_ok(), true);
}
fn make_test_batches() -> Box<dyn RecordBatchReader> {
fn make_test_batches() -> impl RecordBatchReader + Send + 'static {
let dim: usize = 128;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("key", DataType::Int32, false),
@@ -227,11 +223,11 @@ mod tests {
),
ArrowField::new("uri", DataType::Utf8, true),
]));
Box::new(RecordBatchIterator::new(
RecordBatchIterator::new(
vec![RecordBatch::new_empty(schema.clone())]
.into_iter()
.map(Ok),
schema,
))
)
}
}

View File

@@ -22,8 +22,8 @@ use snafu::prelude::*;
use crate::error::{Error, InvalidTableNameSnafu, Result};
use crate::index::vector::VectorIndexBuilder;
use crate::WriteMode;
use crate::query::Query;
use crate::WriteMode;
pub const VECTOR_COLUMN_NAME: &str = "vector";
pub const LANCE_FILE_EXTENSION: &str = "lance";
@@ -117,7 +117,7 @@ impl Table {
pub async fn create(
base_uri: &str,
name: &str,
mut batches: Box<dyn RecordBatchReader>,
batches: impl RecordBatchReader + Send + 'static,
params: Option<WriteParams>,
) -> Result<Self> {
let base_path = Path::new(base_uri);
@@ -127,7 +127,7 @@ impl Table {
.to_str()
.context(InvalidTableNameSnafu { name })?
.to_string();
let dataset = Dataset::write(&mut batches, &uri, params)
let dataset = Dataset::write(batches, &uri, params)
.await
.map_err(|e| match e {
lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -176,14 +176,16 @@ impl Table {
/// * `Ok(())` once the new rows have been written
pub async fn add(
&mut self,
mut batches: Box<dyn RecordBatchReader>,
write_mode: Option<WriteMode>,
) -> Result<usize> {
let mut params = WriteParams::default();
params.mode = write_mode.unwrap_or(WriteMode::Append);
batches: impl RecordBatchReader + Send + 'static,
params: Option<WriteParams>,
) -> Result<()> {
let params = params.unwrap_or(WriteParams {
mode: WriteMode::Append,
..WriteParams::default()
});
self.dataset = Arc::new(Dataset::write(&mut batches, &self.uri, Some(params)).await?);
Ok(batches.count())
self.dataset = Arc::new(Dataset::write(batches, &self.uri, Some(params)).await?);
Ok(())
}
/// Creates a new Query object that can be executed.
@@ -207,12 +209,12 @@ impl Table {
/// Merge new data into this table.
pub async fn merge(
&mut self,
mut batches: Box<dyn RecordBatchReader>,
batches: impl RecordBatchReader + Send + 'static,
left_on: &str,
right_on: &str,
) -> Result<()> {
let mut dataset = self.dataset.as_ref().clone();
dataset.merge(&mut batches, left_on, right_on).await?;
dataset.merge(batches, left_on, right_on).await?;
self.dataset = Arc::new(dataset);
Ok(())
}
@@ -253,8 +255,8 @@ mod tests {
let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
let batches = make_test_batches();
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
.await
.unwrap();
@@ -284,11 +286,11 @@ mod tests {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = make_test_batches();
let batches = make_test_batches();
let _ = batches.schema().clone();
Table::create(&uri, "test", batches, None).await.unwrap();
let batches: Box<dyn RecordBatchReader> = make_test_batches();
let batches = make_test_batches();
let result = Table::create(&uri, "test", batches, None).await;
assert!(matches!(
result.unwrap_err(),
@@ -301,12 +303,12 @@ mod tests {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = make_test_batches();
let batches = make_test_batches();
let schema = batches.schema().clone();
let mut table = Table::create(&uri, "test", batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(100..110))],
@@ -315,7 +317,7 @@ mod tests {
.into_iter()
.map(Ok),
schema.clone(),
));
);
table.add(new_batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 20);
@@ -327,12 +329,12 @@ mod tests {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = make_test_batches();
let batches = make_test_batches();
let schema = batches.schema().clone();
let mut table = Table::create(uri, "test", batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(100..110))],
@@ -341,10 +343,15 @@ mod tests {
.into_iter()
.map(Ok),
schema.clone(),
));
);
let param: WriteParams = WriteParams {
mode: WriteMode::Overwrite,
..Default::default()
};
table
.add(new_batches, Some(WriteMode::Overwrite))
.add(new_batches, Some(param))
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
@@ -357,8 +364,8 @@ mod tests {
let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
let batches = make_test_batches();
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
.await
.unwrap();
@@ -369,7 +376,7 @@ mod tests {
assert_eq!(vector, query.query_vector);
}
#[derive(Default)]
#[derive(Default, Debug)]
struct NoOpCacheWrapper {
called: AtomicBool,
}
@@ -396,8 +403,8 @@ mod tests {
let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
let batches = make_test_batches();
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
.await
.unwrap();
@@ -417,15 +424,15 @@ mod tests {
assert!(wrapper.called());
}
fn make_test_batches() -> Box<dyn RecordBatchReader> {
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
Box::new(RecordBatchIterator::new(
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)],
schema,
))
)
}
#[tokio::test]
@@ -465,9 +472,7 @@ mod tests {
schema,
);
let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches);
let mut table = Table::create(uri, "test", reader, None).await.unwrap();
let mut table = Table::create(uri, "test", batches, None).await.unwrap();
let mut i = IvfPQIndexBuilder::new();
let index_builder = i