Compare commits

..

8 Commits

Author SHA1 Message Date
ayush chaurasia
887ac0d79d Merge branch 'main' of https://github.com/lancedb/lancedb into hybrid_query 2024-03-01 11:14:06 +05:30
ayush chaurasia
3ad4992282 update 2024-02-23 14:11:58 +05:30
ayush chaurasia
51cc422799 update 2024-02-23 14:03:21 +05:30
ayush chaurasia
a696dbc8f4 update 2024-02-23 13:54:44 +05:30
ayush chaurasia
9ca0260d54 update 2024-02-23 03:03:39 +05:30
ayush chaurasia
6486ec870b update 2024-02-23 03:02:05 +05:30
ayush chaurasia
64db2393f7 update 2024-02-22 16:28:17 +05:30
ayush chaurasia
bd4e8341fe update 2024-02-21 21:43:23 +05:30
141 changed files with 3503 additions and 12013 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.4.13 current_version = 0.4.11
commit = True commit = True
message = Bump version: {current_version} → {new_version} message = Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -24,14 +24,10 @@ jobs:
environment: environment:
name: github-pages name: github-pages
url: ${{ steps.deployment.outputs.page_url }} url: ${{ steps.deployment.outputs.page_url }}
runs-on: buildjet-8vcpu-ubuntu-2204 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependecies needed for ubuntu
run: |
sudo apt install -y protobuf-compiler libssl-dev
rustup update && rustup default
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:

View File

@@ -24,22 +24,16 @@ env:
jobs: jobs:
test-python: test-python:
name: Test doc python code name: Test doc python code
runs-on: "buildjet-8vcpu-ubuntu-2204" runs-on: "ubuntu-latest"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependecies needed for ubuntu
run: |
sudo apt install -y protobuf-compiler libssl-dev
rustup update && rustup default
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: 3.11 python-version: 3.11
cache: "pip" cache: "pip"
cache-dependency-path: "docs/test/requirements.txt" cache-dependency-path: "docs/test/requirements.txt"
- name: Rust cache
uses: swatinem/rust-cache@v2
- name: Build Python - name: Build Python
working-directory: docs/test working-directory: docs/test
run: run:
@@ -54,8 +48,8 @@ jobs:
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
test-node: test-node:
name: Test doc nodejs code name: Test doc nodejs code
runs-on: "buildjet-8vcpu-ubuntu-2204" runs-on: "ubuntu-latest"
timeout-minutes: 60 timeout-minutes: 45
strategy: strategy:
fail-fast: false fail-fast: false
steps: steps:
@@ -71,7 +65,6 @@ jobs:
- name: Install dependecies needed for ubuntu - name: Install dependecies needed for ubuntu
run: | run: |
sudo apt install -y protobuf-compiler libssl-dev sudo apt install -y protobuf-compiler libssl-dev
rustup update && rustup default
- name: Rust cache - name: Rust cache
uses: swatinem/rust-cache@v2 uses: swatinem/rust-cache@v2
- name: Install node dependencies - name: Install node dependencies

View File

@@ -24,6 +24,27 @@ env:
RUST_BACKTRACE: "1" RUST_BACKTRACE: "1"
jobs: jobs:
lint:
name: Lint
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
working-directory: node
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
lfs: true
- uses: actions/setup-node@v3
with:
node-version: 20
cache: 'npm'
cache-dependency-path: node/package-lock.json
- name: Lint
run: |
npm ci
npm run lint
linux: linux:
name: Linux (Node ${{ matrix.node-version }}) name: Linux (Node ${{ matrix.node-version }})
timeout-minutes: 30 timeout-minutes: 30

View File

@@ -49,7 +49,6 @@ jobs:
cargo clippy --all --all-features -- -D warnings cargo clippy --all --all-features -- -D warnings
npm ci npm ci
npm run lint npm run lint
npm run chkformat
linux: linux:
name: Linux (NodeJS ${{ matrix.node-version }}) name: Linux (NodeJS ${{ matrix.node-version }})
timeout-minutes: 30 timeout-minutes: 30
@@ -112,3 +111,4 @@ jobs:
- name: Test - name: Test
run: | run: |
npm run test npm run test

View File

@@ -66,7 +66,7 @@ jobs:
- name: Install - name: Install
run: | run: |
pip install -e .[tests,dev,embeddings] pip install -e .[tests,dev,embeddings]
pip install tantivy pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install mlx pip install mlx
- name: Doctest - name: Doctest
run: pytest --doctest-modules python/lancedb run: pytest --doctest-modules python/lancedb
@@ -188,6 +188,6 @@ jobs:
run: | run: |
pip install "pydantic<2" pip install "pydantic<2"
pip install -e .[tests] pip install -e .[tests]
pip install tantivy pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
- name: Run tests - name: Run tests
run: pytest -m "not slow" -x -v --durations=30 python/tests run: pytest -m "not slow" -x -v --durations=30 python/tests

View File

@@ -0,0 +1,37 @@
name: LanceDb Cloud Integration Test
on:
workflow_run:
workflows: [Rust]
types:
- completed
env:
LANCEDB_PROJECT: ${{ secrets.LANCEDB_PROJECT }}
LANCEDB_API_KEY: ${{ secrets.LANCEDB_API_KEY }}
LANCEDB_REGION: ${{ secrets.LANCEDB_REGION }}
jobs:
test:
timeout-minutes: 30
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
working-directory: rust
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
lfs: true
- uses: Swatinem/rust-cache@v2
with:
workspaces: rust
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Build
run: cargo build --all-features
- name: Run Integration test
run: cargo test --tests -- --ignored

View File

@@ -10,9 +10,3 @@ repos:
rev: v0.2.2 rev: v0.2.2
hooks: hooks:
- id: ruff - id: ruff
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
hooks:
- id: prettier
files: "nodejs/.*"
exclude: nodejs/lancedb/native.d.ts|nodejs/dist/.*

View File

@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"] categories = ["database-implementations"]
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=0.10.5", "features" = ["dynamodb"] } lance = { "version" = "=0.10.1", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.5" } lance-index = { "version" = "=0.10.1" }
lance-linalg = { "version" = "=0.10.5" } lance-linalg = { "version" = "=0.10.1" }
lance-testing = { "version" = "=0.10.5" } lance-testing = { "version" = "=0.10.1" }
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false } arrow = { version = "50.0", optional = false }
arrow-array = "50.0" arrow-array = "50.0"
@@ -28,14 +28,13 @@ arrow-schema = "50.0"
arrow-arith = "50.0" arrow-arith = "50.0"
arrow-cast = "50.0" arrow-cast = "50.0"
async-trait = "0" async-trait = "0"
chrono = "0.4.35" chrono = "0.4.23"
half = { "version" = "=2.3.1", default-features = false, features = [ half = { "version" = "=2.3.1", default-features = false, features = [
"num-traits", "num-traits",
] } ] }
futures = "0" futures = "0"
log = "0.4" log = "0.4"
object_store = "0.9.0" object_store = "0.9.0"
pin-project = "1.0.7"
snafu = "0.7.4" snafu = "0.7.4"
url = "2" url = "2"
num-traits = "0.2" num-traits = "0.2"

View File

@@ -27,6 +27,7 @@ theme:
- content.tabs.link - content.tabs.link
- content.action.edit - content.action.edit
- toc.follow - toc.follow
# - toc.integrate
- navigation.top - navigation.top
- navigation.tabs - navigation.tabs
- navigation.tabs.sticky - navigation.tabs.sticky
@@ -139,14 +140,12 @@ nav:
- Serverless Website Chatbot: examples/serverless_website_chatbot.md - Serverless Website Chatbot: examples/serverless_website_chatbot.md
- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md - YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
- 🦀 Rust:
- Overview: examples/examples_rust.md
- 🔧 CLI & Config: cli_config.md - 🔧 CLI & Config: cli_config.md
- 💭 FAQs: faq.md - 💭 FAQs: faq.md
- ⚙️ API reference: - ⚙️ API reference:
- 🐍 Python: python/python.md - 🐍 Python: python/python.md
- 👾 JavaScript: javascript/modules.md - 👾 JavaScript: javascript/modules.md
- 🦀 Rust: https://docs.rs/lancedb/latest/lancedb/ - 🦀 Rust: https://docs.rs/vectordb/latest/vectordb/
- ☁️ LanceDB Cloud: - ☁️ LanceDB Cloud:
- Overview: cloud/index.md - Overview: cloud/index.md
- API reference: - API reference:
@@ -190,21 +189,21 @@ nav:
- Pydantic: python/pydantic.md - Pydantic: python/pydantic.md
- Voxel51: integrations/voxel51.md - Voxel51: integrations/voxel51.md
- PromptTools: integrations/prompttools.md - PromptTools: integrations/prompttools.md
- Examples: - Python examples:
- examples/index.md - examples/index.md
- YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb - YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
- Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb - Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb
- Multimodal search using CLIP: notebooks/multimodal_search.ipynb - Multimodal search using CLIP: notebooks/multimodal_search.ipynb
- Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md - Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md
- Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md - Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md
- YouTube Transcript Search (JS): examples/youtube_transcript_bot_with_nodejs.md - Javascript examples:
- Overview: examples/examples_js.md
- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
- Serverless Chatbot from any website: examples/serverless_website_chatbot.md - Serverless Chatbot from any website: examples/serverless_website_chatbot.md
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
- API reference: - API reference:
- Overview: api_reference.md
- Python: python/python.md - Python: python/python.md
- Javascript: javascript/modules.md - Javascript: javascript/modules.md
- Rust: https://docs.rs/lancedb/latest/lancedb/index.html
- LanceDB Cloud: - LanceDB Cloud:
- Overview: cloud/index.md - Overview: cloud/index.md
- API reference: - API reference:

View File

@@ -7,11 +7,20 @@ for brute-force scanning of the entire vector space.
A vector index is faster but less accurate than exhaustive search (kNN or flat search). A vector index is faster but less accurate than exhaustive search (kNN or flat search).
LanceDB provides many parameters to fine-tune the index's size, the speed of queries, and the accuracy of results. LanceDB provides many parameters to fine-tune the index's size, the speed of queries, and the accuracy of results.
## Disk-based Index Currently, LanceDB does _not_ automatically create the ANN index.
LanceDB has optimized code for kNN as well. For many use-cases, datasets under 100K vectors won't require index creation at all.
If you can live with <100ms latency, skipping index creation is a simpler workflow while guaranteeing 100% recall.
Lance provides an `IVF_PQ` disk-based index. It uses **Inverted File Index (IVF)** to first divide In the future we will look to automatically create and configure the ANN index as data comes in.
the dataset into `N` partitions, and then applies **Product Quantization** to compress vectors in each partition.
See the [indexing](concepts/index_ivfpq.md) concepts guide for more information on how this works. ## Types of Index
Lance can support multiple index types, the most widely used one is `IVF_PQ`.
- `IVF_PQ`: use **Inverted File Index (IVF)** to first divide the dataset into `N` partitions,
and then use **Product Quantization** to compress vectors in each partition.
- `DiskANN` (**Experimental**): organize the vector as a on-disk graph, where the vertices approximately
represent the nearest neighbors of each vector.
## Creating an IVF_PQ Index ## Creating an IVF_PQ Index
@@ -46,34 +55,12 @@ Lance supports `IVF_PQ` index type by default.
--8<-- "docs/src/ann_indexes.ts:ingest" --8<-- "docs/src/ann_indexes.ts:ingest"
``` ```
=== "Rust" - **metric** (default: "L2"): The distance metric to use. By default it uses euclidean distance "`L2`".
```rust
--8<-- "rust/lancedb/examples/ivf_pq.rs:create_index"
```
IVF_PQ index parameters are more fully defined in the [crate docs](https://docs.rs/lancedb/latest/lancedb/index/vector/struct.IvfPqIndexBuilder.html).
The following IVF_PQ paramters can be specified:
- **distance_type**: The distance metric to use. By default it uses euclidean distance "`L2`".
We also support "cosine" and "dot" distance as well. We also support "cosine" and "dot" distance as well.
- **num_partitions**: The number of partitions in the index. The default is the square root - **num_partitions** (default: 256): The number of partitions of the index.
of the number of rows. - **num_sub_vectors** (default: 96): The number of sub-vectors (M) that will be created during Product Quantization (PQ).
For D dimensional vector, it will be divided into `M` of `D/M` sub-vectors, each of which is presented by
!!! note a single PQ code.
In the synchronous python SDK and node's `vectordb` the default is 256. This default has
changed in the asynchronous python SDK and node's `lancedb`.
- **num_sub_vectors**: The number of sub-vectors (M) that will be created during Product Quantization (PQ).
For D dimensional vector, it will be divided into `M` subvectors with dimension `D/M`, each of which is replaced by
a single PQ code. The default is the dimension of the vector divided by 16.
!!! note
In the synchronous python SDK and node's `vectordb` the default is currently 96. This default has
changed in the asynchronous python SDK and node's `lancedb`.
<figure markdown> <figure markdown>
![IVF PQ](./assets/ivf_pq.png) ![IVF PQ](./assets/ivf_pq.png)
@@ -101,7 +88,7 @@ You can specify the GPU device to train IVF partitions via
) )
``` ```
=== "MacOS" === "Macos"
<!-- skip-test --> <!-- skip-test -->
```python ```python
@@ -113,7 +100,7 @@ You can specify the GPU device to train IVF partitions via
) )
``` ```
Troubleshooting: Trouble shootings:
If you see `AssertionError: Torch not compiled with CUDA enabled`, you need to [install If you see `AssertionError: Torch not compiled with CUDA enabled`, you need to [install
PyTorch with CUDA support](https://pytorch.org/get-started/locally/). PyTorch with CUDA support](https://pytorch.org/get-started/locally/).
@@ -156,14 +143,6 @@ There are a couple of parameters that can be used to fine-tune the search:
--8<-- "docs/src/ann_indexes.ts:search1" --8<-- "docs/src/ann_indexes.ts:search1"
``` ```
=== "Rust"
```rust
--8<-- "rust/lancedb/examples/ivf_pq.rs:search1"
```
Vector search options are more fully defined in the [crate docs](https://docs.rs/lancedb/latest/lancedb/query/struct.Query.html#method.nearest_to).
The search will return the data requested in addition to the distance of each item. The search will return the data requested in addition to the distance of each item.
### Filtering (where clause) ### Filtering (where clause)
@@ -208,21 +187,13 @@ You can select the columns returned by the query using a select clause.
## FAQ ## FAQ
### Why do I need to manually create an index?
Currently, LanceDB does _not_ automatically create the ANN index.
LanceDB is well-optimized for kNN (exhaustive search) via a disk-based index. For many use-cases,
datasets of the order of ~100K vectors don't require index creation. If you can live with up to
100ms latency, skipping index creation is a simpler workflow while guaranteeing 100% recall.
### When is it necessary to create an ANN vector index? ### When is it necessary to create an ANN vector index?
`LanceDB` comes out-of-the-box with highly optimized SIMD code for computing vector similarity. `LanceDB` has manually-tuned SIMD code for computing vector distances.
In our benchmarks, computing distances for 100K pairs of 1K dimension vectors takes **less than 20ms**. In our benchmarks, computing 100K pairs of 1K dimension vectors takes **less than 20ms**.
We observe that for small datasets (~100K rows) or for applications that can accept 100ms latency, For small datasets (< 100K rows) or applications that can accept 100ms latency, vector indices are usually not necessary.
vector indices are usually not necessary.
For large-scale or higher dimension vectors, it can beneficial to create vector index for performance. For large-scale or higher dimension vectors, it is beneficial to create vector index.
### How big is my index, and how many memory will it take? ### How big is my index, and how many memory will it take?

View File

@@ -1,7 +0,0 @@
# API Reference
The API reference for the LanceDB client SDKs are available at the following locations:
- [Python](python/python.md)
- [JavaScript](javascript/modules.md)
- [Rust](https://docs.rs/lancedb/latest/lancedb/index.html)

View File

@@ -3,7 +3,7 @@
!!! info "LanceDB can be run in a number of ways:" !!! info "LanceDB can be run in a number of ways:"
* Embedded within an existing backend (like your Django, Flask, Node.js or FastAPI application) * Embedded within an existing backend (like your Django, Flask, Node.js or FastAPI application)
* Directly from a client application like a Jupyter notebook for analytical workloads * Connected to directly from a client application like a Jupyter notebook for analytical workloads
* Deployed as a remote serverless database * Deployed as a remote serverless database
![](assets/lancedb_embedded_explanation.png) ![](assets/lancedb_embedded_explanation.png)
@@ -24,11 +24,13 @@
=== "Rust" === "Rust"
!!! warning "Rust SDK is experimental, might introduce breaking changes in the near future"
```shell ```shell
cargo add lancedb cargo add vectordb
``` ```
!!! info "To use the lancedb create, you first need to install protobuf." !!! info "To use the vectordb create, you first need to install protobuf."
=== "macOS" === "macOS"
@@ -42,9 +44,9 @@
sudo apt install -y protobuf-compiler libssl-dev sudo apt install -y protobuf-compiler libssl-dev
``` ```
!!! info "Please also make sure you're using the same version of Arrow as in the [lancedb crate](https://github.com/lancedb/lancedb/blob/main/Cargo.toml)" !!! info "Please also make sure you're using the same version of Arrow as in the [vectordb crate](https://github.com/lancedb/lancedb/blob/main/Cargo.toml)"
## Connect to a database ## How to connect to a database
=== "Python" === "Python"
@@ -67,23 +69,17 @@
```rust ```rust
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
--8<-- "rust/lancedb/examples/simple.rs:connect" --8<-- "rust/vectordb/examples/simple.rs:connect"
} }
``` ```
!!! info "See [examples/simple.rs](https://github.com/lancedb/lancedb/tree/main/rust/lancedb/examples/simple.rs) for a full working example." !!! info "See [examples/simple.rs](https://github.com/lancedb/lancedb/tree/main/rust/vectordb/examples/simple.rs) for a full working example."
LanceDB will create the directory if it doesn't exist (including parent directories). LanceDB will create the directory if it doesn't exist (including parent directories).
If you need a reminder of the uri, you can call `db.uri()`. If you need a reminder of the uri, you can call `db.uri()`.
## Create a table ## How to create a table
### Create a table from initial data
If you have data to insert into the table at creation time, you can simultaneously create a
table and insert the data into it. The schema of the data will be used as the schema of the
table.
=== "Python" === "Python"
@@ -119,27 +115,20 @@ table.
=== "Rust" === "Rust"
```rust ```rust
--8<-- "rust/lancedb/examples/simple.rs:create_table" use arrow_schema::{DataType, Schema, Field};
use arrow_array::{RecordBatch, RecordBatchIterator};
--8<-- "rust/vectordb/examples/simple.rs:create_table"
``` ```
If the table already exists, LanceDB will raise an error by default. See If the table already exists, LanceDB will raise an error by default.
[the mode option](https://docs.rs/lancedb/latest/lancedb/connection/struct.CreateTableBuilder.html#method.mode)
for details on how to overwrite (or open) existing tables instead.
!!! Providing table records in Rust !!! info "Under the hood, LanceDB is converting the input data into an Apache Arrow table and persisting it to disk in [Lance format](https://www.github.com/lancedb/lance)."
The Rust SDK currently expects data to be provided as an Arrow ### Creating an empty table
[RecordBatchReader](https://docs.rs/arrow-array/latest/arrow_array/trait.RecordBatchReader.html)
Support for additional formats (such as serde or polars) is on the roadmap.
!!! info "Under the hood, LanceDB reads in the Apache Arrow data and persists it to disk using the [Lance format](https://www.github.com/lancedb/lance)."
### Create an empty table
Sometimes you may not have the data to insert into the table at creation time. Sometimes you may not have the data to insert into the table at creation time.
In this case, you can create an empty table and specify the schema, so that you can add In this case, you can create an empty table and specify the schema.
data to the table at a later time (as long as it conforms to the schema). This is
similar to a `CREATE TABLE` statement in SQL.
=== "Python" === "Python"
@@ -158,12 +147,12 @@ similar to a `CREATE TABLE` statement in SQL.
=== "Rust" === "Rust"
```rust ```rust
--8<-- "rust/lancedb/examples/simple.rs:create_empty_table" --8<-- "rust/vectordb/examples/simple.rs:create_empty_table"
``` ```
## Open an existing table ## How to open an existing table
Once created, you can open a table as follows: Once created, you can open a table using the following code:
=== "Python" === "Python"
@@ -180,7 +169,7 @@ Once created, you can open a table as follows:
=== "Rust" === "Rust"
```rust ```rust
--8<-- "rust/lancedb/examples/simple.rs:open_existing_tbl" --8<-- "rust/vectordb/examples/simple.rs:open_with_existing_file"
``` ```
If you forget the name of your table, you can always get a listing of all table names: If you forget the name of your table, you can always get a listing of all table names:
@@ -200,12 +189,12 @@ If you forget the name of your table, you can always get a listing of all table
=== "Rust" === "Rust"
```rust ```rust
--8<-- "rust/lancedb/examples/simple.rs:list_names" --8<-- "rust/vectordb/examples/simple.rs:list_names"
``` ```
## Add data to a table ## How to add data to a table
After a table has been created, you can always add more data to it as follows: After a table has been created, you can always add more data to it using
=== "Python" === "Python"
@@ -230,12 +219,12 @@ After a table has been created, you can always add more data to it as follows:
=== "Rust" === "Rust"
```rust ```rust
--8<-- "rust/lancedb/examples/simple.rs:add" --8<-- "rust/vectordb/examples/simple.rs:add"
``` ```
## Search for nearest neighbors ## How to search for (approximate) nearest neighbors
Once you've embedded the query, you can find its nearest neighbors as follows: Once you've embedded the query, you can find its nearest neighbors using the following code:
=== "Python" === "Python"
@@ -256,20 +245,11 @@ Once you've embedded the query, you can find its nearest neighbors as follows:
```rust ```rust
use futures::TryStreamExt; use futures::TryStreamExt;
--8<-- "rust/lancedb/examples/simple.rs:search" --8<-- "rust/vectordb/examples/simple.rs:search"
``` ```
!!! Query vectors in Rust
Rust does not yet support automatic execution of embedding functions. You will need to
calculate embeddings yourself. Support for this is on the roadmap and can be tracked at
https://github.com/lancedb/lancedb/issues/994
Query vectors can be provided as Arrow arrays or a Vec/slice of Rust floats.
Support for additional formats (e.g. `polars::series::Series`) is on the roadmap.
By default, LanceDB runs a brute-force scan over dataset to find the K nearest neighbours (KNN). By default, LanceDB runs a brute-force scan over dataset to find the K nearest neighbours (KNN).
For tables with more than 50K vectors, creating an ANN index is recommended to speed up search performance. For tables with more than 50K vectors, creating an ANN index is recommended to speed up search performance.
LanceDB allows you to create an ANN index on a table as follows:
=== "Python" === "Python"
@@ -286,17 +266,12 @@ LanceDB allows you to create an ANN index on a table as follows:
=== "Rust" === "Rust"
```rust ```rust
--8<-- "rust/lancedb/examples/simple.rs:create_index" --8<-- "rust/vectordb/examples/simple.rs:create_index"
``` ```
!!! note "Why do I need to create an index manually?" Check [Approximate Nearest Neighbor (ANN) Indexes](/ann_indices.md) section for more details.
LanceDB does not automatically create the ANN index for two reasons. The first is that it's optimized
for really fast retrievals via a disk-based index, and the second is that data and query workloads can
be very diverse, so there's no one-size-fits-all index configuration. LanceDB provides many parameters
to fine-tune index size, query latency and accuracy. See the section on
[ANN indexes](ann_indexes.md) for more details.
## Delete rows from a table ## How to delete rows from a table
Use the `delete()` method on tables to delete rows from a table. To choose Use the `delete()` method on tables to delete rows from a table. To choose
which rows to delete, provide a filter that matches on the metadata columns. which rows to delete, provide a filter that matches on the metadata columns.
@@ -317,13 +292,12 @@ This can delete any number of rows that match the filter.
=== "Rust" === "Rust"
```rust ```rust
--8<-- "rust/lancedb/examples/simple.rs:delete" --8<-- "rust/vectordb/examples/simple.rs:delete"
``` ```
The deletion predicate is a SQL expression that supports the same expressions The deletion predicate is a SQL expression that supports the same expressions
as the `where()` clause (`only_if()` in Rust) on a search. They can be as as the `where()` clause on a search. They can be as simple or complex as needed.
simple or complex as needed. To see what expressions are supported, see the To see what expressions are supported, see the [SQL filters](sql.md) section.
[SQL filters](sql.md) section.
=== "Python" === "Python"
@@ -333,11 +307,7 @@ simple or complex as needed. To see what expressions are supported, see the
Read more: [vectordb.Table.delete](javascript/interfaces/Table.md#delete) Read more: [vectordb.Table.delete](javascript/interfaces/Table.md#delete)
=== "Rust" ## How to remove a table
Read more: [lancedb::Table::delete](https://docs.rs/lancedb/latest/lancedb/table/struct.Table.html#method.delete)
## Drop a table
Use the `drop_table()` method on the database to remove a table. Use the `drop_table()` method on the database to remove a table.
@@ -363,7 +333,7 @@ Use the `drop_table()` method on the database to remove a table.
=== "Rust" === "Rust"
```rust ```rust
--8<-- "rust/lancedb/examples/simple.rs:drop_table" --8<-- "rust/vectordb/examples/simple.rs:drop_table"
``` ```
!!! note "Bundling `vectordb` apps with Webpack" !!! note "Bundling `vectordb` apps with Webpack"

View File

@@ -31,7 +31,7 @@ As an example, consider starting with 128-dimensional vector consisting of 32-bi
While PQ helps with reducing the size of the index, IVF primarily addresses search performance. The primary purpose of an inverted file index is to facilitate rapid and effective nearest neighbor search by narrowing down the search space. While PQ helps with reducing the size of the index, IVF primarily addresses search performance. The primary purpose of an inverted file index is to facilitate rapid and effective nearest neighbor search by narrowing down the search space.
In IVF, the PQ vector space is divided into *Voronoi cells*, which are essentially partitions that consist of all the points in the space that are within a threshold distance of the given region's seed point. These seed points are initialized by running K-means over the stored vectors. The centroids of K-means turn into the seed points which then each define a region. These regions are then are used to create an inverted index that correlates each centroid with a list of vectors in the space, allowing a search to be restricted to just a subset of vectors in the index. In IVF, the PQ vector space is divided into *Voronoi cells*, which are essentially partitions that consist of all the points in the space that are within a threshold distance of the given region's seed point. These seed points are used to create an inverted index that correlates each centroid with a list of vectors in the space, allowing a search to be restricted to just a subset of vectors in the index.
![](../assets/ivfpq_ivf_desc.webp) ![](../assets/ivfpq_ivf_desc.webp)
@@ -81,4 +81,24 @@ The above query will perform a search on the table `tbl` using the given query v
* `to_pandas()`: Convert the results to a pandas DataFrame * `to_pandas()`: Convert the results to a pandas DataFrame
And there you have it! You now understand what an IVF-PQ index is, and how to create and query it in LanceDB. And there you have it! You now understand what an IVF-PQ index is, and how to create and query it in LanceDB.
To see how to create an IVF-PQ index in LanceDB, take a look at the [ANN indexes](../ann_indexes.md) section.
## FAQ
### When is it necessary to create a vector index?
LanceDB has manually-tuned SIMD code for computing vector distances. In our benchmarks, computing 100K pairs of 1K dimension vectors takes **<20ms**. For small datasets (<100K rows) or applications that can accept up to 100ms latency, vector indices are usually not necessary.
For large-scale or higher dimension vectors, it is beneficial to create vector index.
### How big is my index, and how much memory will it take?
In LanceDB, all vector indices are disk-based, meaning that when responding to a vector query, only the relevant pages from the index file are loaded from disk and cached in memory. Additionally, each sub-vector is usually encoded into 1 byte PQ code.
For example, with 1024-dimension vectors, if we choose `num_sub_vectors = 64`, each sub-vector has `1024 / 64 = 16` float32 numbers. Product quantization can lead to approximately `16 * sizeof(float32) / 1 = 64` times of space reduction.
### How to choose `num_partitions` and `num_sub_vectors` for IVF_PQ index?
`num_partitions` is used to decide how many partitions the first level IVF index uses. Higher number of partitions could lead to more efficient I/O during queries and better accuracy, but it takes much more time to train. On SIFT-1M dataset, our benchmark shows that keeping each partition 1K-4K rows lead to a good latency/recall.
`num_sub_vectors` specifies how many PQ short codes to generate on each vector. Because PQ is a lossy compression of the original vector, a higher `num_sub_vectors` usually results in less space distortion, and thus yields better accuracy. However, a higher `num_sub_vectors` also causes heavier I/O and more PQ computation, and thus, higher latency. `dimension / num_sub_vectors` should be a multiple of 8 for optimum SIMD efficiency.

View File

@@ -47,7 +47,6 @@ LanceDB registers the OpenAI embeddings function in the registry by default, as
| Parameter | Type | Default Value | Description | | Parameter | Type | Default Value | Description |
|---|---|---|---| |---|---|---|---|
| `name` | `str` | `"text-embedding-ada-002"` | The name of the model. | | `name` | `str` | `"text-embedding-ada-002"` | The name of the model. |
| `dim` | `int` | Model default | For OpenAI's newer text-embedding-3 model, we can specify a dimensionality that is smaller than the 1536 size. This feature supports it |
```python ```python
@@ -176,8 +175,7 @@ Supported Embedding modelIDs are:
* `cohere.embed-english-v3` * `cohere.embed-english-v3`
* `cohere.embed-multilingual-v3` * `cohere.embed-multilingual-v3`
Supported parameters (to be passed in `create` method) are: Supported paramters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description | | Parameter | Type | Default Value | Description |
|---|---|---|---| |---|---|---|---|
| **name** | str | "amazon.titan-embed-text-v1" | The model ID of the bedrock model to use. Supported base models for Text Embeddings: amazon.titan-embed-text-v1, cohere.embed-english-v3, cohere.embed-multilingual-v3 | | **name** | str | "amazon.titan-embed-text-v1" | The model ID of the bedrock model to use. Supported base models for Text Embeddings: amazon.titan-embed-text-v1, cohere.embed-english-v3, cohere.embed-multilingual-v3 |
@@ -224,6 +222,7 @@ This embedding function supports ingesting images as both bytes and urls. You ca
!!! info !!! info
LanceDB supports ingesting images directly from accessible links. LanceDB supports ingesting images directly from accessible links.
```python ```python
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)
@@ -289,67 +288,4 @@ print(actual.label)
``` ```
### Imagebind embeddings
We have support for [imagebind](https://github.com/facebookresearch/ImageBind) model embeddings. You can download our version of the packaged model via - `pip install imagebind-packaged==0.1.2`.
This function is registered as `imagebind` and supports Audio, Video and Text modalities(extending to Thermal,Depth,IMU data):
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"imagebind_huge"` | Name of the model. |
| `device` | `str` | `"cpu"` | The device to run the model on. Can be `"cpu"` or `"gpu"`. |
| `normalize` | `bool` | `False` | set to `True` to normalize your inputs before model ingestion. |
Below is an example demonstrating how the API works:
```python
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
func = registry.get("imagebind").create()
class ImageBindModel(LanceModel):
text: str
image_uri: str = func.SourceField()
audio_path: str
vector: Vector(func.ndims()) = func.VectorField()
# add locally accessible image paths
text_list=["A dog.", "A car", "A bird"]
image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
audio_paths=[".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"]
# Load data
inputs = [
{"text": a, "audio_path": b, "image_uri": c}
for a, b, c in zip(text_list, audio_paths, image_paths)
]
#create table and add data
table = db.create_table("img_bind", schema=ImageBindModel)
table.add(inputs)
```
Now, we can search using any modality:
#### image search
```python
query_image = "./assets/dog_image2.jpg" #download an image and enter that path here
actual = table.search(query_image).limit(1).to_pydantic(ImageBindModel)[0]
print(actual.text == "dog")
```
#### audio search
```python
query_audio = "./assets/car_audio2.wav" #download an audio clip and enter path here
actual = table.search(query_audio).limit(1).to_pydantic(ImageBindModel)[0]
print(actual.text == "car")
```
#### Text search
You can add any input query and fetch the result as follows:
```python
query = "an animal which flies and tweets"
actual = table.search(query).limit(1).to_pydantic(ImageBindModel)[0]
print(actual.text == "bird")
```
If you have any questions about the embeddings API, supported models, or see a relevant model missing, please raise an issue [on GitHub](https://github.com/lancedb/lancedb/issues). If you have any questions about the embeddings API, supported models, or see a relevant model missing, please raise an issue [on GitHub](https://github.com/lancedb/lancedb/issues).

View File

@@ -1,3 +0,0 @@
# Examples: Rust
Our Rust SDK is now stable. Examples are coming soon.

View File

@@ -43,7 +43,7 @@ pip install lancedb
We also need to install a specific commit of `tantivy`, a dependency of the LanceDB full text search engine we will use later in this guide: We also need to install a specific commit of `tantivy`, a dependency of the LanceDB full text search engine we will use later in this guide:
``` ```
pip install tantivy pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
``` ```
Create a new Python file and add the following code: Create a new Python file and add the following code:

View File

@@ -2,11 +2,10 @@
## Recipes and example code ## Recipes and example code
LanceDB provides language APIs, allowing you to embed a database in your language of choice. LanceDB provides language APIs, allowing you to embed a database in your language of choice. We currently provide Python and Javascript APIs, with the Rust API and examples actively being worked on and will be available soon.
* 🐍 [Python](examples_python.md) examples * 🐍 [Python](examples_python.md) examples
* 👾 [JavaScript](examples_js.md) examples * 👾 [JavaScript](exampled_js.md) examples
* 🦀 Rust examples (coming soon)
## Applications powered by LanceDB ## Applications powered by LanceDB

View File

@@ -16,7 +16,7 @@ As we mention in our talk titled “[Lance, a modern columnar data format](https
### Why build in Rust? 🦀 ### Why build in Rust? 🦀
We believe that the Rust ecosystem has attained mainstream maturity and that Rust will form the underpinnings of large parts of the data and ML landscape in a few years. Performance, latency and reliability are paramount to a vector DB, and building in Rust allows us to iterate and release updates more rapidly due to Rusts safety guarantees. Both Lance (the data format) and LanceDB (the database) are written entirely in Rust. We also provide Python, JavaScript, and Rust client libraries to interact with the database. We believe that the Rust ecosystem has attained mainstream maturity and that Rust will form the underpinnings of large parts of the data and ML landscape in a few years. Performance, latency and reliability are paramount to a vector DB, and building in Rust allows us to iterate and release updates more rapidly due to Rusts safety guarantees. Both Lance (the data format) and LanceDB (the database) are written entirely in Rust. We also provide Python and JavaScript client libraries to interact with the database. Our Rust API is a little rough around the edges right now, but is fast becoming on par with the Python and JS APIs.
### What is the difference between LanceDB OSS and LanceDB Cloud? ### What is the difference between LanceDB OSS and LanceDB Cloud?
@@ -40,11 +40,11 @@ LanceDB and its underlying data format, Lance, are built to scale to really larg
No. LanceDB is blazing fast (due to its disk-based index) for even brute force kNN search, within reason. In our benchmarks, computing 100K pairs of 1000-dimension vectors takes less than 20ms. For small datasets of ~100K records or applications that can accept ~100ms latency, an ANN index is usually not necessary. No. LanceDB is blazing fast (due to its disk-based index) for even brute force kNN search, within reason. In our benchmarks, computing 100K pairs of 1000-dimension vectors takes less than 20ms. For small datasets of ~100K records or applications that can accept ~100ms latency, an ANN index is usually not necessary.
For large-scale (>1M) or higher dimension vectors, it is beneficial to create an ANN index. See the [ANN indexes](ann_indexes.md) section for more details. For large-scale (>1M) or higher dimension vectors, it is beneficial to create an ANN index.
### Does LanceDB support full-text search? ### Does LanceDB support full-text search?
Yes, LanceDB supports full-text search (FTS) via [Tantivy](https://github.com/quickwit-oss/tantivy). Our current FTS integration is Python-only, and our goal is to push it down to the Rust level in future versions to enable much more powerful search capabilities available to our Python, JavaScript and Rust clients. Follow along in the [Github issue](https://github.com/lancedb/lance/issues/1195) Yes, LanceDB supports full-text search (FTS) via [Tantivy](https://github.com/quickwit-oss/tantivy). Our current FTS integration is Python-only, and our goal is to push it down to the Rust level in future versions to enable much more powerful search capabilities available to our Python, JavaScript and Rust clients.
### How can I speed up data inserts? ### How can I speed up data inserts?

View File

@@ -1,6 +1,6 @@
# Full-text search # Full-text search
LanceDB provides support for full-text search via [Tantivy](https://github.com/quickwit-oss/tantivy) (currently Python only), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions. Our goal is to push the FTS integration down to the Rust level in the future, so that it's available for Rust and JavaScript users as well. Follow along at [this Github issue](https://github.com/lancedb/lance/issues/1195) LanceDB provides support for full-text search via [Tantivy](https://github.com/quickwit-oss/tantivy) (currently Python only), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions. Our goal is to push the FTS integration down to the Rust level in the future, so that it's available for JavaScript users as well.
A hybrid search solution combining vector and full-text search is also on the way. A hybrid search solution combining vector and full-text search is also on the way.
@@ -75,70 +75,21 @@ applied on top of the full text search results. This can be invoked via the fami
table.search("puppy").limit(10).where("meta='foo'").to_list() table.search("puppy").limit(10).where("meta='foo'").to_list()
``` ```
## Sorting ## Syntax
You can pre-sort the documents by specifying `ordering_field_names` when For full-text search you can perform either a phrase query like "the old man and the sea",
creating the full-text search index. Once pre-sorted, you can then specify or a structured search query like "(Old AND Man) AND Sea".
`ordering_field_name` while searching to return results sorted by the given Double quotes are used to disambiguate.
field. For example,
``` For example:
table.create_fts_index(["text_field"], ordering_field_names=["sort_by_field"])
(table.search("terms", ordering_field_name="sort_by_field") If you intended "they could have been dogs OR cats" as a phrase query, this actually
.limit(20) raises a syntax error since `OR` is a recognized operator. If you make `or` lower case,
.to_list()) this avoids the syntax error. However, it is cumbersome to have to remember what will
``` conflict with the query syntax. Instead, if you search using
`table.search('"they could have been dogs OR cats"')`, then the syntax checker avoids
checking inside the quotes.
!!! note
If you wish to specify an ordering field at query time, you must also
have specified it during indexing time. Otherwise at query time, an
error will be raised that looks like `ValueError: The field does not exist: xxx`
!!! note
The fields to sort on must be of typed unsigned integer, or else you will see
an error during indexing that looks like
`TypeError: argument 'value': 'float' object cannot be interpreted as an integer`.
!!! note
You can specify multiple fields for ordering at indexing time.
But at query time only one ordering field is supported.
## Phrase queries vs. terms queries
For full-text search you can specify either a **phrase** query like `"the old man and the sea"`,
or a **terms** search query like `"(Old AND Man) AND Sea"`. For more details on the terms
query syntax, see Tantivy's [query parser rules](https://docs.rs/tantivy/latest/tantivy/query/struct.QueryParser.html).
!!! tip "Note"
The query parser will raise an exception on queries that are ambiguous. For example, in the query `they could have been dogs OR cats`, `OR` is capitalized so it's considered a keyword query operator. But it's ambiguous how the left part should be treated. So if you submit this search query as is, you'll get `Syntax Error: they could have been dogs OR cats`.
```py
# This raises a syntax error
table.search("they could have been dogs OR cats")
```
On the other hand, lowercasing `OR` to `or` will work, because there are no capitalized logical operators and
the query is treated as a phrase query.
```py
# This works!
table.search("they could have been dogs or cats")
```
It can be cumbersome to have to remember what will cause a syntax error depending on the type of
query you want to perform. To make this simpler, when you want to perform a phrase query, you can
enforce it in one of two ways:
1. Place the double-quoted query inside single quotes. For example, `table.search('"they could have been dogs OR cats"')` is treated as
a phrase query.
2. Explicitly declare the `phrase_query()` method. This is useful when you have a phrase query that
itself contains double quotes. For example, `table.search('the cats OR dogs were not really "pets" at all').phrase_query()`
is treated as a phrase query.
In general, a query that's declared as a phrase query will be wrapped in double quotes during parsing, with nested
double quotes replaced by single quotes.
## Configurations ## Configurations
@@ -161,3 +112,4 @@ table.create_fts_index(["text1", "text2"], writer_heap_size=heap, replace=True)
2. We currently only support local filesystem paths for the FTS index. 2. We currently only support local filesystem paths for the FTS index.
This is a tantivy limitation. We've implemented an object store plugin This is a tantivy limitation. We've implemented an object store plugin
but there's no way in tantivy-py to specify to use it. but there's no way in tantivy-py to specify to use it.

View File

@@ -168,24 +168,24 @@ This guide will show how to create tables, insert data into them, and update the
--8<-- "docs/src/basic_legacy.ts:create_f16_table" --8<-- "docs/src/basic_legacy.ts:create_f16_table"
``` ```
### From Pydantic Models ### From Pydantic Models
When you create an empty table without data, you must specify the table schema. When you create an empty table without data, you must specify the table schema.
LanceDB supports creating tables by specifying a PyArrow schema or a specialized LanceDB supports creating tables by specifying a PyArrow schema or a specialized
Pydantic model called `LanceModel`. Pydantic model called `LanceModel`.
For example, the following Content model specifies a table with 5 columns: For example, the following Content model specifies a table with 5 columns:
`movie_id`, `vector`, `genres`, `title`, and `imdb_id`. When you create a table, you can `movie_id`, `vector`, `genres`, `title`, and `imdb_id`. When you create a table, you can
pass the class as the value of the `schema` parameter to `create_table`. pass the class as the value of the `schema` parameter to `create_table`.
The `vector` column is a `Vector` type, which is a specialized Pydantic type that The `vector` column is a `Vector` type, which is a specialized Pydantic type that
can be configured with the vector dimensions. It is also important to note that can be configured with the vector dimensions. It is also important to note that
LanceDB only understands subclasses of `lancedb.pydantic.LanceModel` LanceDB only understands subclasses of `lancedb.pydantic.LanceModel`
(which itself derives from `pydantic.BaseModel`). (which itself derives from `pydantic.BaseModel`).
```python ```python
from lancedb.pydantic import Vector, LanceModel from lancedb.pydantic import Vector, LanceModel
class Content(LanceModel): class Content(LanceModel):
movie_id: int movie_id: int
vector: Vector(128) vector: Vector(128)
genres: str genres: str
@@ -196,65 +196,65 @@ class Content(LanceModel):
def imdb_url(self) -> str: def imdb_url(self) -> str:
return f"https://www.imdb.com/title/tt{self.imdb_id}" return f"https://www.imdb.com/title/tt{self.imdb_id}"
import pyarrow as pa import pyarrow as pa
db = lancedb.connect("~/.lancedb") db = lancedb.connect("~/.lancedb")
table_name = "movielens_small" table_name = "movielens_small"
table = db.create_table(table_name, schema=Content) table = db.create_table(table_name, schema=Content)
``` ```
#### Nested schemas #### Nested schemas
Sometimes your data model may contain nested objects. Sometimes your data model may contain nested objects.
For example, you may want to store the document string For example, you may want to store the document string
and the document soure name as a nested Document object: and the document soure name as a nested Document object:
```python ```python
class Document(BaseModel): class Document(BaseModel):
content: str content: str
source: str source: str
``` ```
This can be used as the type of a LanceDB table column: This can be used as the type of a LanceDB table column:
```python ```python
class NestedSchema(LanceModel): class NestedSchema(LanceModel):
id: str id: str
vector: Vector(1536) vector: Vector(1536)
document: Document document: Document
tbl = db.create_table("nested_table", schema=NestedSchema, mode="overwrite") tbl = db.create_table("nested_table", schema=NestedSchema, mode="overwrite")
``` ```
This creates a struct column called "document" that has two subfields This creates a struct column called "document" that has two subfields
called "content" and "source": called "content" and "source":
``` ```
In [28]: tbl.schema In [28]: tbl.schema
Out[28]: Out[28]:
id: string not null id: string not null
vector: fixed_size_list<item: float>[1536] not null vector: fixed_size_list<item: float>[1536] not null
child 0, item: float child 0, item: float
document: struct<content: string not null, source: string not null> not null document: struct<content: string not null, source: string not null> not null
child 0, content: string not null child 0, content: string not null
child 1, source: string not null child 1, source: string not null
``` ```
#### Validators #### Validators
Note that neither Pydantic nor PyArrow automatically validates that input data Note that neither Pydantic nor PyArrow automatically validates that input data
is of the correct timezone, but this is easy to add as a custom field validator: is of the correct timezone, but this is easy to add as a custom field validator:
```python ```python
from datetime import datetime from datetime import datetime
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from lancedb.pydantic import LanceModel from lancedb.pydantic import LanceModel
from pydantic import Field, field_validator, ValidationError, ValidationInfo from pydantic import Field, field_validator, ValidationError, ValidationInfo
tzname = "America/New_York" tzname = "America/New_York"
tz = ZoneInfo(tzname) tz = ZoneInfo(tzname)
class TestModel(LanceModel): class TestModel(LanceModel):
dt_with_tz: datetime = Field(json_schema_extra={"tz": tzname}) dt_with_tz: datetime = Field(json_schema_extra={"tz": tzname})
@field_validator('dt_with_tz') @field_validator('dt_with_tz')
@@ -263,35 +263,35 @@ class TestModel(LanceModel):
assert dt.tzinfo == tz assert dt.tzinfo == tz
return dt return dt
ok = TestModel(dt_with_tz=datetime.now(tz)) ok = TestModel(dt_with_tz=datetime.now(tz))
try: try:
TestModel(dt_with_tz=datetime.now(ZoneInfo("Asia/Shanghai"))) TestModel(dt_with_tz=datetime.now(ZoneInfo("Asia/Shanghai")))
assert 0 == 1, "this should raise ValidationError" assert 0 == 1, "this should raise ValidationError"
except ValidationError: except ValidationError:
print("A ValidationError was raised.") print("A ValidationError was raised.")
pass pass
``` ```
When you run this code it should print "A ValidationError was raised." When you run this code it should print "A ValidationError was raised."
#### Pydantic custom types #### Pydantic custom types
LanceDB does NOT yet support converting pydantic custom types. If this is something you need, LanceDB does NOT yet support converting pydantic custom types. If this is something you need,
please file a feature request on the [LanceDB Github repo](https://github.com/lancedb/lancedb/issues/new). please file a feature request on the [LanceDB Github repo](https://github.com/lancedb/lancedb/issues/new).
### Using Iterators / Writing Large Datasets ### Using Iterators / Writing Large Datasets
It is recommended to use iterators to add large datasets in batches when creating your table in one go. This does not create multiple versions of your dataset unlike manually adding batches using `table.add()` It is recommended to use iterators to add large datasets in batches when creating your table in one go. This does not create multiple versions of your dataset unlike manually adding batches using `table.add()`
LanceDB additionally supports PyArrow's `RecordBatch` Iterators or other generators producing supported data types. LanceDB additionally supports PyArrow's `RecordBatch` Iterators or other generators producing supported data types.
Here's an example using using `RecordBatch` iterator for creating tables. Here's an example using using `RecordBatch` iterator for creating tables.
```python ```python
import pyarrow as pa import pyarrow as pa
def make_batches(): def make_batches():
for i in range(5): for i in range(5):
yield pa.RecordBatch.from_arrays( yield pa.RecordBatch.from_arrays(
[ [
@@ -303,16 +303,16 @@ def make_batches():
["vector", "item", "price"], ["vector", "item", "price"],
) )
schema = pa.schema([ schema = pa.schema([
pa.field("vector", pa.list_(pa.float32(), 4)), pa.field("vector", pa.list_(pa.float32(), 4)),
pa.field("item", pa.utf8()), pa.field("item", pa.utf8()),
pa.field("price", pa.float32()), pa.field("price", pa.float32()),
]) ])
db.create_table("batched_tale", make_batches(), schema=schema) db.create_table("batched_tale", make_batches(), schema=schema)
``` ```
You can also use iterators of other types like Pandas DataFrame or Pylists directly in the above example. You can also use iterators of other types like Pandas DataFrame or Pylists directly in the above example.
## Open existing tables ## Open existing tables

View File

@@ -28,7 +28,7 @@ LanceDB **Cloud** is a SaaS (software-as-a-service) solution that runs serverles
* Fast production-scale vector similarity, full-text & hybrid search and a SQL query interface (via [DataFusion](https://github.com/apache/arrow-datafusion)) * Fast production-scale vector similarity, full-text & hybrid search and a SQL query interface (via [DataFusion](https://github.com/apache/arrow-datafusion))
* Python, Javascript/Typescript, and Rust support * Native Python and Javascript/Typescript support
* Store, query & manage multi-modal data (text, images, videos, point clouds, etc.), not just the embeddings and metadata * Store, query & manage multi-modal data (text, images, videos, point clouds, etc.), not just the embeddings and metadata
@@ -54,4 +54,3 @@ The following pages go deeper into the internal of LanceDB and how to use it.
* [Ecosystem Integrations](integrations/index.md): Integrate LanceDB with other tools in the data ecosystem * [Ecosystem Integrations](integrations/index.md): Integrate LanceDB with other tools in the data ecosystem
* [Python API Reference](python/python.md): Python OSS and Cloud API references * [Python API Reference](python/python.md): Python OSS and Cloud API references
* [JavaScript API Reference](javascript/modules.md): JavaScript OSS and Cloud API references * [JavaScript API Reference](javascript/modules.md): JavaScript OSS and Cloud API references
* [Rust API Reference](https://docs.rs/lancedb/latest/lancedb/index.html): Rust API reference

View File

@@ -13,7 +13,7 @@ Get started using these examples and quick links.
| Integrations | | | Integrations | |
|---|---:| |---|---:|
| <h3> LlamaIndex </h3>LlamaIndex is a simple, flexible data framework for connecting custom data sources to large language models. Llama index integrates with LanceDB as the serverless VectorDB. <h3>[Lean More](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html) </h3> |<img src="../assets/llama-index.jpg" alt="image" width="150" height="auto">| | <h3> LlamaIndex </h3>LlamaIndex is a simple, flexible data framework for connecting custom data sources to large language models. Llama index integrates with LanceDB as the serverless VectorDB. <h3>[Lean More](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html) </h3> |<img src="../assets/llama-index.jpg" alt="image" width="150" height="auto">|
| <h3>Langchain</h3>Langchain allows building applications with LLMs through composability <h3>[Lean More](https://python.langchain.com/docs/integrations/vectorstores/lancedb) | <img src="../assets/langchain.png" alt="image" width="150" height="auto">| | <h3>Langchain</h3>Langchain allows building applications with LLMs through composability <h3>[Lean More](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lancedb.html) | <img src="../assets/langchain.png" alt="image" width="150" height="auto">|
| <h3>Langchain TS</h3> Javascript bindings for Langchain. It integrates with LanceDB's serverless vectordb allowing you to build powerful AI applications through composibility using only serverless functions. <h3>[Learn More]( https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/lancedb) | <img src="../assets/langchain.png" alt="image" width="150" height="auto">| | <h3>Langchain TS</h3> Javascript bindings for Langchain. It integrates with LanceDB's serverless vectordb allowing you to build powerful AI applications through composibility using only serverless functions. <h3>[Learn More]( https://js.langchain.com/docs/modules/data_connection/vectorstores/integrations/lancedb) | <img src="../assets/langchain.png" alt="image" width="150" height="auto">|
| <h3>Voxel51</h3> It is an open source toolkit that enables you to build better computer vision workflows by improving the quality of your datasets and delivering insights about your models.<h3>[Learn More](./voxel51.md) | <img src="../assets/voxel.gif" alt="image" width="150" height="auto">| | <h3>Voxel51</h3> It is an open source toolkit that enables you to build better computer vision workflows by improving the quality of your datasets and delivering insights about your models.<h3>[Learn More](./voxel51.md) | <img src="../assets/voxel.gif" alt="image" width="150" height="auto">|
| <h3>PromptTools</h3> Offers a set of free, open-source tools for testing and experimenting with models, prompts, and configurations. The core idea is to enable developers to evaluate prompts using familiar interfaces like code and notebooks. You can use it to experiment with different configurations of LanceDB, and test how LanceDB integrates with the LLM of your choice.<h3>[Learn More](./prompttools.md) | <img src="../assets/prompttools.jpeg" alt="image" width="150" height="auto">| | <h3>PromptTools</h3> Offers a set of free, open-source tools for testing and experimenting with models, prompts, and configurations. The core idea is to enable developers to evaluate prompts using familiar interfaces like code and notebooks. You can use it to experiment with different configurations of LanceDB, and test how LanceDB integrates with the LLM of your choice.<h3>[Learn More](./prompttools.md) | <img src="../assets/prompttools.jpeg" alt="image" width="150" height="auto">|

File diff suppressed because one or more lines are too long

View File

@@ -24,12 +24,6 @@ pip install lancedb
::: lancedb.query.LanceQueryBuilder ::: lancedb.query.LanceQueryBuilder
::: lancedb.query.LanceVectorQueryBuilder
::: lancedb.query.LanceFtsQueryBuilder
::: lancedb.query.LanceHybridQueryBuilder
## Embeddings ## Embeddings
::: lancedb.embeddings.registry.EmbeddingFunctionRegistry ::: lancedb.embeddings.registry.EmbeddingFunctionRegistry
@@ -68,22 +62,10 @@ pip install lancedb
## Integrations ## Integrations
## Pydantic ### Pydantic
::: lancedb.pydantic.pydantic_to_schema ::: lancedb.pydantic.pydantic_to_schema
::: lancedb.pydantic.vector ::: lancedb.pydantic.vector
::: lancedb.pydantic.LanceModel ::: lancedb.pydantic.LanceModel
## Reranking
::: lancedb.rerankers.linear_combination.LinearCombinationReranker
::: lancedb.rerankers.cohere.CohereReranker
::: lancedb.rerankers.colbert.ColbertReranker
::: lancedb.rerankers.cross_encoder.CrossEncoderReranker
::: lancedb.rerankers.openai.OpenaiReranker

View File

@@ -22,7 +22,7 @@ Currently, LanceDB supports the following metrics:
## Exhaustive search (kNN) ## Exhaustive search (kNN)
If you do not create a vector index, LanceDB exhaustively scans the _entire_ vector space If you do not create a vector index, LanceDB exhaustively scans the _entire_ vector space
and computes the distance to every vector in order to find the exact nearest neighbors. This is effectively a kNN search. and compute the distance to every vector in order to find the exact nearest neighbors. This is effectively a kNN search.
<!-- Setup Code <!-- Setup Code
```python ```python
@@ -85,7 +85,7 @@ To perform scalable vector retrieval with acceptable latencies, it's common to b
While the exhaustive search is guaranteed to always return 100% recall, the approximate nature of While the exhaustive search is guaranteed to always return 100% recall, the approximate nature of
an ANN search means that using an index often involves a trade-off between recall and latency. an ANN search means that using an index often involves a trade-off between recall and latency.
See the [IVF_PQ index](./concepts/index_ivfpq.md) for a deeper description of how `IVF_PQ` See the [IVF_PQ index](./concepts/index_ivfpq.md.md) for a deeper description of how `IVF_PQ`
indexes work in LanceDB. indexes work in LanceDB.
## Output search results ## Output search results
@@ -184,3 +184,4 @@ Let's create a LanceDB table with a nested schema:
Note that in this case the extra `_distance` field is discarded since Note that in this case the extra `_distance` field is discarded since
it's not part of the LanceSchema. it's not part of the LanceSchema.

View File

@@ -13,10 +13,5 @@ module.exports = {
}, },
rules: { rules: {
"@typescript-eslint/method-signature-style": "off", "@typescript-eslint/method-signature-style": "off",
"@typescript-eslint/quotes": "off",
"@typescript-eslint/semi": "off",
"@typescript-eslint/explicit-function-return-type": "off",
"@typescript-eslint/space-before-function-paren": "off",
"@typescript-eslint/indent": "off",
} }
} }

87
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.4.13", "version": "0.4.11",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.4.13", "version": "0.4.11",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -18,7 +18,9 @@
"win32" "win32"
], ],
"dependencies": { "dependencies": {
"@apache-arrow/ts": "^14.0.2",
"@neon-rs/load": "^0.0.74", "@neon-rs/load": "^0.0.74",
"apache-arrow": "^14.0.2",
"axios": "^1.4.0" "axios": "^1.4.0"
}, },
"devDependencies": { "devDependencies": {
@@ -31,7 +33,6 @@
"@types/temp": "^0.9.1", "@types/temp": "^0.9.1",
"@types/uuid": "^9.0.3", "@types/uuid": "^9.0.3",
"@typescript-eslint/eslint-plugin": "^5.59.1", "@typescript-eslint/eslint-plugin": "^5.59.1",
"apache-arrow-old": "npm:apache-arrow@13.0.0",
"cargo-cp-artifact": "^0.1", "cargo-cp-artifact": "^0.1",
"chai": "^4.3.7", "chai": "^4.3.7",
"chai-as-promised": "^7.1.1", "chai-as-promised": "^7.1.1",
@@ -52,15 +53,11 @@
"uuid": "^9.0.0" "uuid": "^9.0.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.13", "@lancedb/vectordb-darwin-arm64": "0.4.11",
"@lancedb/vectordb-darwin-x64": "0.4.13", "@lancedb/vectordb-darwin-x64": "0.4.11",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.13", "@lancedb/vectordb-linux-arm64-gnu": "0.4.11",
"@lancedb/vectordb-linux-x64-gnu": "0.4.13", "@lancedb/vectordb-linux-x64-gnu": "0.4.11",
"@lancedb/vectordb-win32-x64-msvc": "0.4.13" "@lancedb/vectordb-win32-x64-msvc": "0.4.11"
},
"peerDependencies": {
"@apache-arrow/ts": "^14.0.2",
"apache-arrow": "^14.0.2"
} }
}, },
"node_modules/@75lb/deep-merge": { "node_modules/@75lb/deep-merge": {
@@ -96,7 +93,6 @@
"version": "14.0.2", "version": "14.0.2",
"resolved": "https://registry.npmjs.org/@apache-arrow/ts/-/ts-14.0.2.tgz", "resolved": "https://registry.npmjs.org/@apache-arrow/ts/-/ts-14.0.2.tgz",
"integrity": "sha512-CtwAvLkK0CZv7xsYeCo91ml6PvlfzAmAJZkRYuz2GNBwfYufj5SVi0iuSMwIMkcU/szVwvLdzORSLa5PlF/2ug==", "integrity": "sha512-CtwAvLkK0CZv7xsYeCo91ml6PvlfzAmAJZkRYuz2GNBwfYufj5SVi0iuSMwIMkcU/szVwvLdzORSLa5PlF/2ug==",
"peer": true,
"dependencies": { "dependencies": {
"@types/command-line-args": "5.2.0", "@types/command-line-args": "5.2.0",
"@types/command-line-usage": "5.0.2", "@types/command-line-usage": "5.0.2",
@@ -113,8 +109,7 @@
"node_modules/@apache-arrow/ts/node_modules/@types/node": { "node_modules/@apache-arrow/ts/node_modules/@types/node": {
"version": "20.3.0", "version": "20.3.0",
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.3.0.tgz", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.3.0.tgz",
"integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ==", "integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ=="
"peer": true
}, },
"node_modules/@cargo-messages/android-arm-eabi": { "node_modules/@cargo-messages/android-arm-eabi": {
"version": "0.0.160", "version": "0.0.160",
@@ -334,9 +329,9 @@
} }
}, },
"node_modules/@lancedb/vectordb-darwin-arm64": { "node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.4.13", "version": "0.4.11",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.13.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.11.tgz",
"integrity": "sha512-JfroNCG8yKIU931Y+x8d0Fp8C9DHUSC5j+CjI+e5err7rTWtie4j3JbsXlWAnPFaFEOg0Xk3BWkSikCvhPGJGg==", "integrity": "sha512-JDOKmFnuJPFkA7ZmrzBJolROwSjWr7yMvAbi40uLBc25YbbVezodd30u2EFtIwWwtk1GqNYRZ49FZOElKYeC/Q==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -346,9 +341,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-darwin-x64": { "node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.4.13", "version": "0.4.11",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.13.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.11.tgz",
"integrity": "sha512-dG6IMvfpHpnHdbJ0UffzJ7cZfMiC02MjIi6YJzgx+hKz2UNXWNBIfTvvhqli85mZsGRXL1OYDdYv0K1YzNjXlA==", "integrity": "sha512-iy6r+8tp2v1EFgJV52jusXtxgO6NY6SkpOdX41xPqN2mQWMkfUAR9Xtks1mgknjPOIKH4MRc8ZS0jcW/UWmilQ==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -358,9 +353,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-arm64-gnu": { "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.4.13", "version": "0.4.11",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.13.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.11.tgz",
"integrity": "sha512-BRR1VzaMviXby7qmLm0axNZM8eUZF3ZqfvnDKdVRpC3LaRueD6pMXHuC2IUKaFkn7xktf+8BlDZb6foFNEj8bQ==", "integrity": "sha512-5K6IVcTMuH0SZBjlqB5Gg39WC889FpTwIWKufxzQMMXrzxo5J3lKUHVoR28RRlNhDF2d9kZXBEyCpIfDFsV9iQ==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -370,9 +365,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-x64-gnu": { "node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.4.13", "version": "0.4.11",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.13.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.11.tgz",
"integrity": "sha512-WnekZ7ZMlria+NODZ6aBCljCFQSe2bBNUS9ZpyFl/Y1vHduSQPuBxM6V7vp2QubC0daq/rifgjDob89DF+x3xw==", "integrity": "sha512-hF9ZChsdqKqqnivOzd9mE7lC3PmhZadXtwThi2RrsPiOLoEaGDfmr6Ni3amVQnB3bR8YEJtTxdQxe0NC4uW/8g==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -382,9 +377,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-win32-x64-msvc": { "node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.4.13", "version": "0.4.11",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.13.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.11.tgz",
"integrity": "sha512-3NDpMWBL2ksDHXAraXhowiLqQcNWM5bdbeHwze4+InYMD54hyQ2ODNc+4usxp63Nya9biVnFS27yXULqkzIEqQ==", "integrity": "sha512-0+9ut1ccKoqIyGxsVixwx3771Z+DXpl5WfSmOeA8kf3v3jlOg2H+0YUahiXLDid2ju+yeLPrAUYm7A1gKHVhew==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -953,7 +948,6 @@
"version": "14.0.2", "version": "14.0.2",
"resolved": "https://registry.npmjs.org/apache-arrow/-/apache-arrow-14.0.2.tgz", "resolved": "https://registry.npmjs.org/apache-arrow/-/apache-arrow-14.0.2.tgz",
"integrity": "sha512-EBO2xJN36/XoY81nhLcwCJgFwkboDZeyNQ+OPsG7bCoQjc2BT0aTyH/MR6SrL+LirSNz+cYqjGRlupMMlP1aEg==", "integrity": "sha512-EBO2xJN36/XoY81nhLcwCJgFwkboDZeyNQ+OPsG7bCoQjc2BT0aTyH/MR6SrL+LirSNz+cYqjGRlupMMlP1aEg==",
"peer": true,
"dependencies": { "dependencies": {
"@types/command-line-args": "5.2.0", "@types/command-line-args": "5.2.0",
"@types/command-line-usage": "5.0.2", "@types/command-line-usage": "5.0.2",
@@ -970,39 +964,10 @@
"arrow2csv": "bin/arrow2csv.js" "arrow2csv": "bin/arrow2csv.js"
} }
}, },
"node_modules/apache-arrow-old": {
"name": "apache-arrow",
"version": "13.0.0",
"resolved": "https://registry.npmjs.org/apache-arrow/-/apache-arrow-13.0.0.tgz",
"integrity": "sha512-3gvCX0GDawWz6KFNC28p65U+zGh/LZ6ZNKWNu74N6CQlKzxeoWHpi4CgEQsgRSEMuyrIIXi1Ea2syja7dwcHvw==",
"dev": true,
"dependencies": {
"@types/command-line-args": "5.2.0",
"@types/command-line-usage": "5.0.2",
"@types/node": "20.3.0",
"@types/pad-left": "2.1.1",
"command-line-args": "5.2.1",
"command-line-usage": "7.0.1",
"flatbuffers": "23.5.26",
"json-bignum": "^0.0.3",
"pad-left": "^2.1.0",
"tslib": "^2.5.3"
},
"bin": {
"arrow2csv": "bin/arrow2csv.js"
}
},
"node_modules/apache-arrow-old/node_modules/@types/node": {
"version": "20.3.0",
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.3.0.tgz",
"integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ==",
"dev": true
},
"node_modules/apache-arrow/node_modules/@types/node": { "node_modules/apache-arrow/node_modules/@types/node": {
"version": "20.3.0", "version": "20.3.0",
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.3.0.tgz", "resolved": "https://registry.npmjs.org/@types/node/-/node-20.3.0.tgz",
"integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ==", "integrity": "sha512-cumHmIAf6On83X7yP+LrsEyUOf/YlociZelmpRYaGFydoaPdxdt80MAbu6vWerQT2COCp2nPvHdsbD7tHn/YlQ=="
"peer": true
}, },
"node_modules/arg": { "node_modules/arg": {
"version": "4.1.3", "version": "4.1.3",

View File

@@ -1,6 +1,6 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.4.13", "version": "0.4.11",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js", "main": "dist/index.js",
"types": "dist/index.d.ts", "types": "dist/index.d.ts",
@@ -41,7 +41,6 @@
"@types/temp": "^0.9.1", "@types/temp": "^0.9.1",
"@types/uuid": "^9.0.3", "@types/uuid": "^9.0.3",
"@typescript-eslint/eslint-plugin": "^5.59.1", "@typescript-eslint/eslint-plugin": "^5.59.1",
"apache-arrow-old": "npm:apache-arrow@13.0.0",
"cargo-cp-artifact": "^0.1", "cargo-cp-artifact": "^0.1",
"chai": "^4.3.7", "chai": "^4.3.7",
"chai-as-promised": "^7.1.1", "chai-as-promised": "^7.1.1",
@@ -88,10 +87,10 @@
} }
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.13", "@lancedb/vectordb-darwin-arm64": "0.4.11",
"@lancedb/vectordb-darwin-x64": "0.4.13", "@lancedb/vectordb-darwin-x64": "0.4.11",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.13", "@lancedb/vectordb-linux-arm64-gnu": "0.4.11",
"@lancedb/vectordb-linux-x64-gnu": "0.4.13", "@lancedb/vectordb-linux-x64-gnu": "0.4.11",
"@lancedb/vectordb-win32-x64-msvc": "0.4.13" "@lancedb/vectordb-win32-x64-msvc": "0.4.11"
} }
} }

View File

@@ -20,20 +20,19 @@ import {
type Vector, type Vector,
FixedSizeList, FixedSizeList,
vectorFromArray, vectorFromArray,
Schema, type Schema,
Table as ArrowTable, Table as ArrowTable,
RecordBatchStreamWriter, RecordBatchStreamWriter,
List, List,
RecordBatch, RecordBatch,
makeData, makeData,
Struct, Struct,
Float, type Float,
DataType, DataType,
Binary, Binary,
Float32 Float32
} from 'apache-arrow' } from 'apache-arrow'
import { type EmbeddingFunction } from './index' import { type EmbeddingFunction } from './index'
import { sanitizeSchema } from './sanitize'
/* /*
* Options to control how a column should be converted to a vector array * Options to control how a column should be converted to a vector array
@@ -202,13 +201,10 @@ export function makeArrowTable (
} }
const opt = new MakeArrowTableOptions(options !== undefined ? options : {}) const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema)
}
const columns: Record<string, Vector> = {} const columns: Record<string, Vector> = {}
// TODO: sample dataset to find missing columns // TODO: sample dataset to find missing columns
// Prefer the field ordering of the schema, if present // Prefer the field ordering of the schema, if present
const columnNames = ((opt.schema) != null) ? (opt.schema.names as string[]) : Object.keys(data[0]) const columnNames = ((options?.schema) != null) ? (options?.schema?.names as string[]) : Object.keys(data[0])
for (const colName of columnNames) { for (const colName of columnNames) {
if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) { if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) {
// The field is present in the schema, but not in the data, skip it // The field is present in the schema, but not in the data, skip it
@@ -333,9 +329,6 @@ async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunc
if (embeddings == null) { if (embeddings == null) {
return table return table
} }
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema)
}
// Convert from ArrowTable to Record<String, Vector> // Convert from ArrowTable to Record<String, Vector>
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => { const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
@@ -446,9 +439,6 @@ export async function fromRecordsToBuffer<T> (
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>,
schema?: Schema schema?: Schema
): Promise<Buffer> { ): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema)
}
const table = await convertToTable(data, embeddings, { schema }) const table = await convertToTable(data, embeddings, { schema })
const writer = RecordBatchFileWriter.writeAll(table) const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
@@ -466,9 +456,6 @@ export async function fromRecordsToStreamBuffer<T> (
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>,
schema?: Schema schema?: Schema
): Promise<Buffer> { ): Promise<Buffer> {
if (schema !== null && schema !== undefined) {
schema = sanitizeSchema(schema)
}
const table = await convertToTable(data, embeddings, { schema }) const table = await convertToTable(data, embeddings, { schema })
const writer = RecordBatchStreamWriter.writeAll(table) const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
@@ -487,9 +474,6 @@ export async function fromTableToBuffer<T> (
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>,
schema?: Schema schema?: Schema
): Promise<Buffer> { ): Promise<Buffer> {
if (schema !== null && schema !== undefined) {
schema = sanitizeSchema(schema)
}
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema) const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings) const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
@@ -508,9 +492,6 @@ export async function fromTableToStreamBuffer<T> (
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>,
schema?: Schema schema?: Schema
): Promise<Buffer> { ): Promise<Buffer> {
if (schema !== null && schema !== undefined) {
schema = sanitizeSchema(schema)
}
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema) const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings) const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
@@ -547,5 +528,5 @@ function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
// Creates an empty Arrow Table // Creates an empty Arrow Table
export function createEmptyTable (schema: Schema): ArrowTable { export function createEmptyTable (schema: Schema): ArrowTable {
return new ArrowTable(sanitizeSchema(schema)) return new ArrowTable(schema)
} }

View File

@@ -176,10 +176,6 @@ export async function connect (
opts = { uri: arg } opts = { uri: arg }
} else { } else {
// opts = { uri: arg.uri, awsCredentials = arg.awsCredentials } // opts = { uri: arg.uri, awsCredentials = arg.awsCredentials }
const keys = Object.keys(arg)
if (keys.length === 1 && keys[0] === 'uri' && typeof arg.uri === 'string') {
opts = { uri: arg.uri }
} else {
opts = Object.assign( opts = Object.assign(
{ {
uri: '', uri: '',
@@ -191,7 +187,6 @@ export async function connect (
arg arg
) )
} }
}
if (opts.uri.startsWith('db://')) { if (opts.uri.startsWith('db://')) {
// Remote connection // Remote connection

View File

@@ -1,508 +0,0 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The utilities in this file help sanitize data from the user's arrow
// library into the types expected by vectordb's arrow library. Node
// generally allows for mulitple versions of the same library (and sometimes
// even multiple copies of the same version) to be installed at the same
// time. However, arrow-js uses instanceof which expected that the input
// comes from the exact same library instance. This is not always the case
// and so we must sanitize the input to ensure that it is compatible.
import {
Field,
Utf8,
FixedSizeBinary,
FixedSizeList,
Schema,
List,
Struct,
Float,
Bool,
Date_,
Decimal,
DataType,
Dictionary,
Binary,
Float32,
Interval,
Map_,
Duration,
Union,
Time,
Timestamp,
Type,
Null,
Int,
type Precision,
type DateUnit,
Int8,
Int16,
Int32,
Int64,
Uint8,
Uint16,
Uint32,
Uint64,
Float16,
Float64,
DateDay,
DateMillisecond,
DenseUnion,
SparseUnion,
TimeNanosecond,
TimeMicrosecond,
TimeMillisecond,
TimeSecond,
TimestampNanosecond,
TimestampMicrosecond,
TimestampMillisecond,
TimestampSecond,
IntervalDayTime,
IntervalYearMonth,
DurationNanosecond,
DurationMicrosecond,
DurationMillisecond,
DurationSecond,
} from "apache-arrow";
import type { IntBitWidth, TimeBitWidth } from "apache-arrow/type";
function sanitizeMetadata(
metadataLike?: unknown,
): Map<string, string> | undefined {
if (metadataLike === undefined || metadataLike === null) {
return undefined;
}
if (!(metadataLike instanceof Map)) {
throw Error("Expected metadata, if present, to be a Map<string, string>");
}
for (const item of metadataLike) {
if (!(typeof item[0] === "string" || !(typeof item[1] === "string"))) {
throw Error(
"Expected metadata, if present, to be a Map<string, string> but it had non-string keys or values",
);
}
}
return metadataLike as Map<string, string>;
}
function sanitizeInt(typeLike: object) {
if (
!("bitWidth" in typeLike) ||
typeof typeLike.bitWidth !== "number" ||
!("isSigned" in typeLike) ||
typeof typeLike.isSigned !== "boolean"
) {
throw Error(
"Expected an Int Type to have a `bitWidth` and `isSigned` property",
);
}
return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
}
function sanitizeFloat(typeLike: object) {
if (!("precision" in typeLike) || typeof typeLike.precision !== "number") {
throw Error("Expected a Float Type to have a `precision` property");
}
return new Float(typeLike.precision as Precision);
}
function sanitizeDecimal(typeLike: object) {
if (
!("scale" in typeLike) ||
typeof typeLike.scale !== "number" ||
!("precision" in typeLike) ||
typeof typeLike.precision !== "number" ||
!("bitWidth" in typeLike) ||
typeof typeLike.bitWidth !== "number"
) {
throw Error(
"Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties",
);
}
return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
}
function sanitizeDate(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Date type to have a `unit` property");
}
return new Date_(typeLike.unit as DateUnit);
}
function sanitizeTime(typeLike: object) {
if (
!("unit" in typeLike) ||
typeof typeLike.unit !== "number" ||
!("bitWidth" in typeLike) ||
typeof typeLike.bitWidth !== "number"
) {
throw Error(
"Expected a Time type to have `unit` and `bitWidth` properties",
);
}
return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
}
function sanitizeTimestamp(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Timestamp type to have a `unit` property");
}
let timezone = null;
if ("timezone" in typeLike && typeof typeLike.timezone === "string") {
timezone = typeLike.timezone;
}
return new Timestamp(typeLike.unit, timezone);
}
function sanitizeTypedTimestamp(
typeLike: object,
Datatype:
| typeof TimestampNanosecond
| typeof TimestampMicrosecond
| typeof TimestampMillisecond
| typeof TimestampSecond,
) {
let timezone = null;
if ("timezone" in typeLike && typeof typeLike.timezone === "string") {
timezone = typeLike.timezone;
}
return new Datatype(timezone);
}
function sanitizeInterval(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected an Interval type to have a `unit` property");
}
return new Interval(typeLike.unit);
}
function sanitizeList(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a List type to have an array-like `children` property",
);
}
if (typeLike.children.length !== 1) {
throw Error("Expected a List type to have exactly one child");
}
return new List(sanitizeField(typeLike.children[0]));
}
function sanitizeStruct(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Struct type to have an array-like `children` property",
);
}
return new Struct(typeLike.children.map((child) => sanitizeField(child)));
}
function sanitizeUnion(typeLike: object) {
if (
!("typeIds" in typeLike) ||
!("mode" in typeLike) ||
typeof typeLike.mode !== "number"
) {
throw Error(
"Expected a Union type to have `typeIds` and `mode` properties",
);
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Union type to have an array-like `children` property",
);
}
return new Union(
typeLike.mode,
typeLike.typeIds as any,
typeLike.children.map((child) => sanitizeField(child)),
);
}
function sanitizeTypedUnion(
typeLike: object,
UnionType: typeof DenseUnion | typeof SparseUnion,
) {
if (!("typeIds" in typeLike)) {
throw Error(
"Expected a DenseUnion/SparseUnion type to have a `typeIds` property",
);
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a DenseUnion/SparseUnion type to have an array-like `children` property",
);
}
return new UnionType(
typeLike.typeIds as any,
typeLike.children.map((child) => sanitizeField(child)),
);
}
function sanitizeFixedSizeBinary(typeLike: object) {
if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
throw Error(
"Expected a FixedSizeBinary type to have a `byteWidth` property",
);
}
return new FixedSizeBinary(typeLike.byteWidth);
}
function sanitizeFixedSizeList(typeLike: object) {
if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") {
throw Error("Expected a FixedSizeList type to have a `listSize` property");
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a FixedSizeList type to have an array-like `children` property",
);
}
if (typeLike.children.length !== 1) {
throw Error("Expected a FixedSizeList type to have exactly one child");
}
return new FixedSizeList(
typeLike.listSize,
sanitizeField(typeLike.children[0]),
);
}
function sanitizeMap(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Map type to have an array-like `children` property",
);
}
if (!("keysSorted" in typeLike) || typeof typeLike.keysSorted !== "boolean") {
throw Error("Expected a Map type to have a `keysSorted` property");
}
return new Map_(
typeLike.children.map((field) => sanitizeField(field)) as any,
typeLike.keysSorted,
);
}
function sanitizeDuration(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Duration type to have a `unit` property");
}
return new Duration(typeLike.unit);
}
function sanitizeDictionary(typeLike: object) {
if (!("id" in typeLike) || typeof typeLike.id !== "number") {
throw Error("Expected a Dictionary type to have an `id` property");
}
if (!("indices" in typeLike) || typeof typeLike.indices !== "object") {
throw Error("Expected a Dictionary type to have an `indices` property");
}
if (!("dictionary" in typeLike) || typeof typeLike.dictionary !== "object") {
throw Error("Expected a Dictionary type to have an `dictionary` property");
}
if (!("isOrdered" in typeLike) || typeof typeLike.isOrdered !== "boolean") {
throw Error("Expected a Dictionary type to have an `isOrdered` property");
}
return new Dictionary(
sanitizeType(typeLike.dictionary),
sanitizeType(typeLike.indices) as any,
typeLike.id,
typeLike.isOrdered,
);
}
function sanitizeType(typeLike: unknown): DataType<any> {
if (typeof typeLike !== "object" || typeLike === null) {
throw Error("Expected a Type but object was null/undefined");
}
if (!("typeId" in typeLike) || !(typeof typeLike.typeId !== "function")) {
throw Error("Expected a Type to have a typeId function");
}
let typeId: Type;
if (typeof typeLike.typeId === "function") {
typeId = (typeLike.typeId as () => unknown)() as Type;
} else if (typeof typeLike.typeId === "number") {
typeId = typeLike.typeId as Type;
} else {
throw Error("Type's typeId property was not a function or number");
}
switch (typeId) {
case Type.NONE:
throw Error("Received a Type with a typeId of NONE");
case Type.Null:
return new Null();
case Type.Int:
return sanitizeInt(typeLike);
case Type.Float:
return sanitizeFloat(typeLike);
case Type.Binary:
return new Binary();
case Type.Utf8:
return new Utf8();
case Type.Bool:
return new Bool();
case Type.Decimal:
return sanitizeDecimal(typeLike);
case Type.Date:
return sanitizeDate(typeLike);
case Type.Time:
return sanitizeTime(typeLike);
case Type.Timestamp:
return sanitizeTimestamp(typeLike);
case Type.Interval:
return sanitizeInterval(typeLike);
case Type.List:
return sanitizeList(typeLike);
case Type.Struct:
return sanitizeStruct(typeLike);
case Type.Union:
return sanitizeUnion(typeLike);
case Type.FixedSizeBinary:
return sanitizeFixedSizeBinary(typeLike);
case Type.FixedSizeList:
return sanitizeFixedSizeList(typeLike);
case Type.Map:
return sanitizeMap(typeLike);
case Type.Duration:
return sanitizeDuration(typeLike);
case Type.Dictionary:
return sanitizeDictionary(typeLike);
case Type.Int8:
return new Int8();
case Type.Int16:
return new Int16();
case Type.Int32:
return new Int32();
case Type.Int64:
return new Int64();
case Type.Uint8:
return new Uint8();
case Type.Uint16:
return new Uint16();
case Type.Uint32:
return new Uint32();
case Type.Uint64:
return new Uint64();
case Type.Float16:
return new Float16();
case Type.Float32:
return new Float32();
case Type.Float64:
return new Float64();
case Type.DateMillisecond:
return new DateMillisecond();
case Type.DateDay:
return new DateDay();
case Type.TimeNanosecond:
return new TimeNanosecond();
case Type.TimeMicrosecond:
return new TimeMicrosecond();
case Type.TimeMillisecond:
return new TimeMillisecond();
case Type.TimeSecond:
return new TimeSecond();
case Type.TimestampNanosecond:
return sanitizeTypedTimestamp(typeLike, TimestampNanosecond);
case Type.TimestampMicrosecond:
return sanitizeTypedTimestamp(typeLike, TimestampMicrosecond);
case Type.TimestampMillisecond:
return sanitizeTypedTimestamp(typeLike, TimestampMillisecond);
case Type.TimestampSecond:
return sanitizeTypedTimestamp(typeLike, TimestampSecond);
case Type.DenseUnion:
return sanitizeTypedUnion(typeLike, DenseUnion);
case Type.SparseUnion:
return sanitizeTypedUnion(typeLike, SparseUnion);
case Type.IntervalDayTime:
return new IntervalDayTime();
case Type.IntervalYearMonth:
return new IntervalYearMonth();
case Type.DurationNanosecond:
return new DurationNanosecond();
case Type.DurationMicrosecond:
return new DurationMicrosecond();
case Type.DurationMillisecond:
return new DurationMillisecond();
case Type.DurationSecond:
return new DurationSecond();
}
}
function sanitizeField(fieldLike: unknown): Field {
if (fieldLike instanceof Field) {
return fieldLike;
}
if (typeof fieldLike !== "object" || fieldLike === null) {
throw Error("Expected a Field but object was null/undefined");
}
if (
!("type" in fieldLike) ||
!("name" in fieldLike) ||
!("nullable" in fieldLike)
) {
throw Error(
"The field passed in is missing a `type`/`name`/`nullable` property",
);
}
const type = sanitizeType(fieldLike.type);
const name = fieldLike.name;
if (!(typeof name === "string")) {
throw Error("The field passed in had a non-string `name` property");
}
const nullable = fieldLike.nullable;
if (!(typeof nullable === "boolean")) {
throw Error("The field passed in had a non-boolean `nullable` property");
}
let metadata;
if ("metadata" in fieldLike) {
metadata = sanitizeMetadata(fieldLike.metadata);
}
return new Field(name, type, nullable, metadata);
}
/**
* Convert something schemaLike into a Schema instance
*
* This method is often needed even when the caller is using a Schema
* instance because they might be using a different instance of apache-arrow
* than lancedb is using.
*/
export function sanitizeSchema(schemaLike: unknown): Schema {
if (schemaLike instanceof Schema) {
return schemaLike;
}
if (typeof schemaLike !== "object" || schemaLike === null) {
throw Error("Expected a Schema but object was null/undefined");
}
if (!("fields" in schemaLike)) {
throw Error(
"The schema passed in does not appear to be a schema (no 'fields' property)",
);
}
let metadata;
if ("metadata" in schemaLike) {
metadata = sanitizeMetadata(schemaLike.metadata);
}
if (!Array.isArray(schemaLike.fields)) {
throw Error(
"The schema passed in had a 'fields' property but it was not an array",
);
}
const sanitizedFields = schemaLike.fields.map((field) =>
sanitizeField(field),
);
return new Schema(sanitizedFields, metadata);
}

View File

@@ -34,20 +34,8 @@ import {
List, List,
DataType, DataType,
Dictionary, Dictionary,
Int64, Int64
MetadataVersion
} from 'apache-arrow' } from 'apache-arrow'
import {
Dictionary as OldDictionary,
Field as OldField,
FixedSizeList as OldFixedSizeList,
Float32 as OldFloat32,
Int32 as OldInt32,
Struct as OldStruct,
Schema as OldSchema,
TimestampNanosecond as OldTimestampNanosecond,
Utf8 as OldUtf8
} from 'apache-arrow-old'
import { type EmbeddingFunction } from '../embedding/embedding_function' import { type EmbeddingFunction } from '../embedding/embedding_function'
chaiUse(chaiAsPromised) chaiUse(chaiAsPromised)
@@ -330,31 +318,3 @@ describe('makeEmptyTable', function () {
await checkTableCreation(async (_, __, schema) => makeEmptyTable(schema)) await checkTableCreation(async (_, __, schema) => makeEmptyTable(schema))
}) })
}) })
describe('when using two versions of arrow', function () {
it('can still import data', async function() {
const schema = new OldSchema([
new OldField('id', new OldInt32()),
new OldField('vector', new OldFixedSizeList(1024, new OldField("item", new OldFloat32(), true))),
new OldField('struct', new OldStruct([
new OldField('nested', new OldDictionary(new OldUtf8(), new OldInt32(), 1, true)),
new OldField('ts_with_tz', new OldTimestampNanosecond("some_tz")),
new OldField('ts_no_tz', new OldTimestampNanosecond(null))
]))
]) as any
// We use arrow version 13 to emulate a "foreign arrow" and this version doesn't have metadataVersion
// In theory, this wouldn't matter. We don't rely on that property. However, it causes deepEqual to
// fail so we patch it back in
schema.metadataVersion = MetadataVersion.V5
const table = makeArrowTable(
[],
{ schema }
)
const buf = await fromTableToBuffer(table)
assert.isAbove(buf.byteLength, 0)
const actual = tableFromIPC(buf)
const actualSchema = actual.schema
assert.deepEqual(actualSchema, schema)
})
})

View File

@@ -128,11 +128,6 @@ describe('LanceDB client', function () {
assertResults(results) assertResults(results)
results = await table.where('id % 2 = 0').execute() results = await table.where('id % 2 = 0').execute()
assertResults(results) assertResults(results)
// Should reject a bad filter
await expect(table.filter('id % 2 = 0 AND').execute()).to.be.rejectedWith(
/.*sql parser error: Expected an expression:, found: EOF.*/
)
}) })
it('uses a filter / where clause', async function () { it('uses a filter / where clause', async function () {
@@ -288,8 +283,7 @@ describe('LanceDB client', function () {
it('create a table from an Arrow Table', async function () { it('create a table from an Arrow Table', async function () {
const dir = await track().mkdir('lancejs') const dir = await track().mkdir('lancejs')
// Also test the connect function with an object const con = await lancedb.connect(dir)
const con = await lancedb.connect({ uri: dir })
const i32s = new Int32Array(new Array<number>(10)) const i32s = new Int32Array(new Array<number>(10))
const i32 = makeVector(i32s) const i32 = makeVector(i32s)
@@ -751,11 +745,11 @@ describe('LanceDB client', function () {
num_sub_vectors: 2 num_sub_vectors: 2
}) })
await expect(createIndex).to.be.rejectedWith( await expect(createIndex).to.be.rejectedWith(
"index cannot be created on the column `name` which has data type Utf8" /VectorIndex requires the column data type to be fixed size list of float32s/
) )
}) })
it('it should fail when num_partitions is invalid', async function () { it('it should fail when the column is not a vector', async function () {
const uri = await createTestDB(32, 300) const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri) const con = await lancedb.connect(uri)
const table = await con.openTable('vectors') const table = await con.openTable('vectors')

View File

@@ -1,3 +0,0 @@
**/dist/**/*
**/native.js
**/native.d.ts

22
nodejs/.eslintrc.js Normal file
View File

@@ -0,0 +1,22 @@
module.exports = {
env: {
browser: true,
es2021: true,
},
extends: [
"eslint:recommended",
"plugin:@typescript-eslint/recommended-type-checked",
"plugin:@typescript-eslint/stylistic-type-checked",
],
overrides: [],
parserOptions: {
project: "./tsconfig.json",
ecmaVersion: "latest",
sourceType: "module",
},
rules: {
"@typescript-eslint/method-signature-style": "off",
"@typescript-eslint/no-explicit-any": "off",
},
ignorePatterns: ["node_modules/", "dist/", "build/", "lancedb/native.*"],
};

View File

@@ -1 +0,0 @@
.eslintignore

View File

@@ -14,10 +14,12 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
arrow-ipc.workspace = true arrow-ipc.workspace = true
futures.workspace = true futures.workspace = true
lance-linalg.workspace = true
lance.workspace = true
lancedb = { path = "../rust/lancedb" } lancedb = { path = "../rust/lancedb" }
napi = { version = "2.15", default-features = false, features = [ napi = { version = "2.15", default-features = false, features = [
"napi7", "napi7",
"async", "async"
] } ] }
napi-derive = "2" napi-derive = "2"

View File

@@ -2,6 +2,7 @@
It will replace the NodeJS SDK when it is ready. It will replace the NodeJS SDK when it is ready.
## Development ## Development
```sh ```sh
@@ -9,35 +10,9 @@ npm run build
npm t npm t
``` ```
### Running lint / format Generating docs
LanceDb uses eslint for linting. VSCode does not need any plugins to use eslint. However, it
may need some additional configuration. Make sure that eslint.experimental.useFlatConfig is
set to true. Also, if your vscode root folder is the repo root then you will need to set
the eslint.workingDirectories to ["nodejs"]. To manually lint your code you can run:
```sh
npm run lint
``` ```
LanceDb uses prettier for formatting. If you are using VSCode you will need to install the
"Prettier - Code formatter" extension. You should then configure it to be the default formatter
for typescript and you should enable format on save. To manually check your code's format you
can run:
```sh
npm run chkformat
```
If you need to manually format your code you can run:
```sh
npx prettier --write .
```
### Generating docs
```sh
npm run docs npm run docs
cd ../docs cd ../docs

View File

@@ -12,13 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import { makeArrowTable, toBuffer } from "../lancedb/arrow";
import { import {
convertToTable, Int64,
fromTableToBuffer,
makeArrowTable,
makeEmptyTable,
} from "../dist/arrow";
import {
Field, Field,
FixedSizeList, FixedSizeList,
Float16, Float16,
@@ -27,137 +23,43 @@ import {
tableFromIPC, tableFromIPC,
Schema, Schema,
Float64, Float64,
type Table,
Binary,
Bool,
Utf8,
Struct,
List,
DataType,
Dictionary,
Int64,
Float,
Precision,
MetadataVersion,
} from "apache-arrow"; } from "apache-arrow";
import {
Dictionary as OldDictionary,
Field as OldField,
FixedSizeList as OldFixedSizeList,
Float32 as OldFloat32,
Int32 as OldInt32,
Struct as OldStruct,
Schema as OldSchema,
TimestampNanosecond as OldTimestampNanosecond,
Utf8 as OldUtf8,
} from "apache-arrow-old";
import { type EmbeddingFunction } from "../dist/embedding/embedding_function";
// eslint-disable-next-line @typescript-eslint/no-explicit-any test("customized schema", function () {
function sampleRecords(): Array<Record<string, any>> {
return [
{
binary: Buffer.alloc(5),
boolean: false,
number: 7,
string: "hello",
struct: { x: 0, y: 0 },
list: ["anime", "action", "comedy"],
},
];
}
// Helper method to verify various ways to create a table
async function checkTableCreation(
tableCreationMethod: (
records: Record<string, unknown>[],
recordsReversed: Record<string, unknown>[],
schema: Schema,
) => Promise<Table>,
infersTypes: boolean,
): Promise<void> {
const records = sampleRecords();
const recordsReversed = [
{
list: ["anime", "action", "comedy"],
struct: { x: 0, y: 0 },
string: "hello",
number: 7,
boolean: false,
binary: Buffer.alloc(5),
},
];
const schema = new Schema([ const schema = new Schema([
new Field("binary", new Binary(), false), new Field("a", new Int32(), true),
new Field("boolean", new Bool(), false), new Field("b", new Float32(), true),
new Field("number", new Float64(), false),
new Field("string", new Utf8(), false),
new Field( new Field(
"struct", "c",
new Struct([ new FixedSizeList(3, new Field("item", new Float16())),
new Field("x", new Float64(), false), true
new Field("y", new Float64(), false),
]),
), ),
new Field("list", new List(new Field("item", new Utf8(), false)), false),
]);
const table = await tableCreationMethod(records, recordsReversed, schema);
schema.fields.forEach((field, idx) => {
const actualField = table.schema.fields[idx];
// Type inference always assumes nullable=true
if (infersTypes) {
expect(actualField.nullable).toBe(true);
} else {
expect(actualField.nullable).toBe(false);
}
expect(table.getChild(field.name)?.type.toString()).toEqual(
field.type.toString(),
);
expect(table.getChildAt(idx)?.type.toString()).toEqual(
field.type.toString(),
);
});
}
describe("The function makeArrowTable", function () {
it("will use data types from a provided schema instead of inference", async function () {
const schema = new Schema([
new Field("a", new Int32()),
new Field("b", new Float32()),
new Field("c", new FixedSizeList(3, new Field("item", new Float16()))),
new Field("d", new Int64()),
]); ]);
const table = makeArrowTable( const table = makeArrowTable(
[ [
{ a: 1, b: 2, c: [1, 2, 3], d: 9 }, { a: 1, b: 2, c: [1, 2, 3] },
{ a: 4, b: 5, c: [4, 5, 6], d: 10 }, { a: 4, b: 5, c: [4, 5, 6] },
{ a: 7, b: 8, c: [7, 8, 9], d: null }, { a: 7, b: 8, c: [7, 8, 9] },
], ],
{ schema }, { schema }
); );
const buf = await fromTableToBuffer(table); expect(table.schema.toString()).toEqual(schema.toString());
const buf = toBuffer(table);
expect(buf.byteLength).toBeGreaterThan(0); expect(buf.byteLength).toBeGreaterThan(0);
const actual = tableFromIPC(buf); const actual = tableFromIPC(buf);
expect(actual.numRows).toBe(3); expect(actual.numRows).toBe(3);
const actualSchema = actual.schema; const actualSchema = actual.schema;
expect(actualSchema).toEqual(schema); expect(actualSchema.toString()).toStrictEqual(schema.toString());
}); });
it("will assume the column `vector` is FixedSizeList<Float32> by default", async function () { test("default vector column", function () {
const schema = new Schema([ const schema = new Schema([
new Field("a", new Float(Precision.DOUBLE), true), new Field("a", new Float64(), true),
new Field("b", new Float(Precision.DOUBLE), true), new Field("b", new Float64(), true),
new Field( new Field("vector", new FixedSizeList(3, new Field("item", new Float32()))),
"vector",
new FixedSizeList(
3,
new Field("item", new Float(Precision.SINGLE), true),
),
true,
),
]); ]);
const table = makeArrowTable([ const table = makeArrowTable([
{ a: 1, b: 2, vector: [1, 2, 3] }, { a: 1, b: 2, vector: [1, 2, 3] },
@@ -165,29 +67,21 @@ describe("The function makeArrowTable", function () {
{ a: 7, b: 8, vector: [7, 8, 9] }, { a: 7, b: 8, vector: [7, 8, 9] },
]); ]);
const buf = await fromTableToBuffer(table); const buf = toBuffer(table);
expect(buf.byteLength).toBeGreaterThan(0); expect(buf.byteLength).toBeGreaterThan(0);
const actual = tableFromIPC(buf); const actual = tableFromIPC(buf);
expect(actual.numRows).toBe(3); expect(actual.numRows).toBe(3);
const actualSchema = actual.schema; const actualSchema = actual.schema;
expect(actualSchema).toEqual(schema); expect(actualSchema.toString()).toEqual(actualSchema.toString());
}); });
it("can support multiple vector columns", async function () { test("2 vector columns", function () {
const schema = new Schema([ const schema = new Schema([
new Field("a", new Float(Precision.DOUBLE), true), new Field("a", new Float64()),
new Field("b", new Float(Precision.DOUBLE), true), new Field("b", new Float64()),
new Field( new Field("vec1", new FixedSizeList(3, new Field("item", new Float16()))),
"vec1", new Field("vec2", new FixedSizeList(3, new Field("item", new Float16()))),
new FixedSizeList(3, new Field("item", new Float16(), true)),
true,
),
new Field(
"vec2",
new FixedSizeList(3, new Field("item", new Float16(), true)),
true,
),
]); ]);
const table = makeArrowTable( const table = makeArrowTable(
[ [
@@ -200,271 +94,27 @@ describe("The function makeArrowTable", function () {
vec1: { type: new Float16() }, vec1: { type: new Float16() },
vec2: { type: new Float16() }, vec2: { type: new Float16() },
}, },
}, }
); );
const buf = await fromTableToBuffer(table); const buf = toBuffer(table);
expect(buf.byteLength).toBeGreaterThan(0); expect(buf.byteLength).toBeGreaterThan(0);
const actual = tableFromIPC(buf); const actual = tableFromIPC(buf);
expect(actual.numRows).toBe(3); expect(actual.numRows).toBe(3);
const actualSchema = actual.schema; const actualSchema = actual.schema;
expect(actualSchema).toEqual(schema); expect(actualSchema.toString()).toEqual(schema.toString());
}); });
it("will allow different vector column types", async function () {
const table = makeArrowTable([{ fp16: [1], fp32: [1], fp64: [1] }], {
vectorColumns: {
fp16: { type: new Float16() },
fp32: { type: new Float32() },
fp64: { type: new Float64() },
},
});
expect(table.getChild("fp16")?.type.children[0].type.toString()).toEqual(
new Float16().toString(),
);
expect(table.getChild("fp32")?.type.children[0].type.toString()).toEqual(
new Float32().toString(),
);
expect(table.getChild("fp64")?.type.children[0].type.toString()).toEqual(
new Float64().toString(),
);
});
it("will use dictionary encoded strings if asked", async function () {
const table = makeArrowTable([{ str: "hello" }]);
expect(DataType.isUtf8(table.getChild("str")?.type)).toBe(true);
const tableWithDict = makeArrowTable([{ str: "hello" }], {
dictionaryEncodeStrings: true,
});
expect(DataType.isDictionary(tableWithDict.getChild("str")?.type)).toBe(
true,
);
test("handles int64", function() {
// https://github.com/lancedb/lancedb/issues/960
const schema = new Schema([ const schema = new Schema([
new Field("str", new Dictionary(new Utf8(), new Int32())), new Field("x", new Int64(), true)
]); ]);
const table = makeArrowTable([
const tableWithDict2 = makeArrowTable([{ str: "hello" }], { schema }); { x: 1 },
expect(DataType.isDictionary(tableWithDict2.getChild("str")?.type)).toBe( { x: 2 },
true, { x: 3 }
); ], { schema });
}); expect(table.schema).toEqual(schema);
})
it("will infer data types correctly", async function () {
await checkTableCreation(async (records) => makeArrowTable(records), true);
});
it("will allow a schema to be provided", async function () {
await checkTableCreation(
async (records, _, schema) => makeArrowTable(records, { schema }),
false,
);
});
it("will use the field order of any provided schema", async function () {
await checkTableCreation(
async (_, recordsReversed, schema) =>
makeArrowTable(recordsReversed, { schema }),
false,
);
});
it("will make an empty table", async function () {
await checkTableCreation(
async (_, __, schema) => makeArrowTable([], { schema }),
false,
);
});
});
class DummyEmbedding implements EmbeddingFunction<string> {
public readonly sourceColumn = "string";
public readonly embeddingDimension = 2;
public readonly embeddingDataType = new Float16();
async embed(data: string[]): Promise<number[][]> {
return data.map(() => [0.0, 0.0]);
}
}
class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
public readonly sourceColumn = "string";
async embed(data: string[]): Promise<number[][]> {
return data.map(() => [0.0, 0.0]);
}
}
describe("convertToTable", function () {
it("will infer data types correctly", async function () {
await checkTableCreation(
async (records) => await convertToTable(records),
true,
);
});
it("will allow a schema to be provided", async function () {
await checkTableCreation(
async (records, _, schema) =>
await convertToTable(records, undefined, { schema }),
false,
);
});
it("will use the field order of any provided schema", async function () {
await checkTableCreation(
async (_, recordsReversed, schema) =>
await convertToTable(recordsReversed, undefined, { schema }),
false,
);
});
it("will make an empty table", async function () {
await checkTableCreation(
async (_, __, schema) => await convertToTable([], undefined, { schema }),
false,
);
});
it("will apply embeddings", async function () {
const records = sampleRecords();
const table = await convertToTable(records, new DummyEmbedding());
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
new Float16().toString(),
);
});
it("will fail if missing the embedding source column", async function () {
await expect(
convertToTable([{ id: 1 }], new DummyEmbedding()),
).rejects.toThrow("'string' was not present");
});
it("use embeddingDimension if embedding missing from table", async function () {
const schema = new Schema([new Field("string", new Utf8(), false)]);
// Simulate getting an empty Arrow table (minus embedding) from some other source
// In other words, we aren't starting with records
const table = makeEmptyTable(schema);
// If the embedding specifies the dimension we are fine
await fromTableToBuffer(table, new DummyEmbedding());
// We can also supply a schema and should be ok
const schemaWithEmbedding = new Schema([
new Field("string", new Utf8(), false),
new Field(
"vector",
new FixedSizeList(2, new Field("item", new Float16(), false)),
false,
),
]);
await fromTableToBuffer(
table,
new DummyEmbeddingWithNoDimension(),
schemaWithEmbedding,
);
// Otherwise we will get an error
await expect(
fromTableToBuffer(table, new DummyEmbeddingWithNoDimension()),
).rejects.toThrow("does not specify `embeddingDimension`");
});
it("will apply embeddings to an empty table", async function () {
const schema = new Schema([
new Field("string", new Utf8(), false),
new Field(
"vector",
new FixedSizeList(2, new Field("item", new Float16(), false)),
false,
),
]);
const table = await convertToTable([], new DummyEmbedding(), { schema });
expect(DataType.isFixedSizeList(table.getChild("vector")?.type)).toBe(true);
expect(table.getChild("vector")?.type.children[0].type.toString()).toEqual(
new Float16().toString(),
);
});
it("will complain if embeddings present but schema missing embedding column", async function () {
const schema = new Schema([new Field("string", new Utf8(), false)]);
await expect(
convertToTable([], new DummyEmbedding(), { schema }),
).rejects.toThrow("column vector was missing");
});
it("will provide a nice error if run twice", async function () {
const records = sampleRecords();
const table = await convertToTable(records, new DummyEmbedding());
// fromTableToBuffer will try and apply the embeddings again
await expect(
fromTableToBuffer(table, new DummyEmbedding()),
).rejects.toThrow("already existed");
});
});
describe("makeEmptyTable", function () {
it("will make an empty table", async function () {
await checkTableCreation(
async (_, __, schema) => makeEmptyTable(schema),
false,
);
});
});
describe("when using two versions of arrow", function () {
it("can still import data", async function () {
const schema = new OldSchema([
new OldField("id", new OldInt32()),
new OldField(
"vector",
new OldFixedSizeList(
1024,
new OldField("item", new OldFloat32(), true),
),
),
new OldField(
"struct",
new OldStruct([
new OldField(
"nested",
new OldDictionary(new OldUtf8(), new OldInt32(), 1, true),
),
new OldField("ts_with_tz", new OldTimestampNanosecond("some_tz")),
new OldField("ts_no_tz", new OldTimestampNanosecond(null)),
]),
),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
]) as any;
schema.metadataVersion = MetadataVersion.V5;
const table = makeArrowTable([], { schema });
const buf = await fromTableToBuffer(table);
expect(buf.byteLength).toBeGreaterThan(0);
const actual = tableFromIPC(buf);
const actualSchema = actual.schema;
expect(actualSchema.fields.length).toBe(3);
// Deep equality gets hung up on some very minor unimportant differences
// between arrow version 13 and 15 which isn't really what we're testing for
// and so we do our own comparison that just checks name/type/nullability
function compareFields(lhs: Field, rhs: Field) {
expect(lhs.name).toEqual(rhs.name);
expect(lhs.nullable).toEqual(rhs.nullable);
expect(lhs.typeId).toEqual(rhs.typeId);
if ("children" in lhs.type && lhs.type.children !== null) {
const lhsChildren = lhs.type.children as Field[];
lhsChildren.forEach((child: Field, idx) => {
compareFields(child, rhs.type.children[idx]);
});
}
}
actualSchema.fields.forEach((field, idx) => {
compareFields(field, actualSchema.fields[idx]);
});
});
});

View File

@@ -12,77 +12,23 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import * as tmp from "tmp"; import * as os from "os";
import * as path from "path";
import * as fs from "fs";
import { Connection, connect } from "../dist/index.js"; import { connect } from "../dist/index.js";
describe("when connecting", () => { describe("when working with a connection", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => (tmpDir = tmp.dirSync({ unsafeCleanup: true })));
afterEach(() => tmpDir.removeCallback());
it("should connect", async () => { const tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "test-connection"));
const db = await connect(tmpDir.name);
expect(db.display()).toBe(
`NativeDatabase(uri=${tmpDir.name}, read_consistency_interval=None)`,
);
});
it("should allow read consistency interval to be specified", async () => { it("should fail if creating table twice, unless overwrite is true", async() => {
const db = await connect(tmpDir.name, { readConsistencyInterval: 5 }); const db = await connect(tmpDir);
expect(db.display()).toBe(
`NativeDatabase(uri=${tmpDir.name}, read_consistency_interval=5s)`,
);
});
});
describe("given a connection", () => {
let tmpDir: tmp.DirResult;
let db: Connection;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
db = await connect(tmpDir.name);
});
afterEach(() => tmpDir.removeCallback());
it("should raise an error if opening a non-existent table", async () => {
await expect(db.openTable("non-existent")).rejects.toThrow("was not found");
});
it("should raise an error if any operation is tried after it is closed", async () => {
expect(db.isOpen()).toBe(true);
await db.close();
expect(db.isOpen()).toBe(false);
await expect(db.tableNames()).rejects.toThrow("Connection is closed");
});
it("should fail if creating table twice, unless overwrite is true", async () => {
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]); let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
await expect(tbl.countRows()).resolves.toBe(2); await expect(tbl.countRows()).resolves.toBe(2);
await expect( await expect(db.createTable("test", [{ id: 1 }, { id: 2 }])).rejects.toThrow();
db.createTable("test", [{ id: 1 }, { id: 2 }]),
).rejects.toThrow();
tbl = await db.createTable("test", [{ id: 3 }], { mode: "overwrite" }); tbl = await db.createTable("test", [{ id: 3 }], { mode: "overwrite" });
await expect(tbl.countRows()).resolves.toBe(1); await expect(tbl.countRows()).resolves.toBe(1);
}); })
it("should respect limit and page token when listing tables", async () => {
const db = await connect(tmpDir.name);
await db.createTable("b", [{ id: 1 }]);
await db.createTable("a", [{ id: 1 }]);
await db.createTable("c", [{ id: 1 }]);
let tables = await db.tableNames();
expect(tables).toEqual(["a", "b", "c"]);
tables = await db.tableNames({ limit: 1 });
expect(tables).toEqual(["a"]);
tables = await db.tableNames({ limit: 1, startAfter: "a" });
expect(tables).toEqual(["b"]);
tables = await db.tableNames({ startAfter: "a" });
expect(tables).toEqual(["b", "c"]);
});
}); });

View File

@@ -0,0 +1,34 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import * as os from "os";
import * as path from "path";
import * as fs from "fs";
import { Schema, Field, Float64 } from "apache-arrow";
import { connect } from "../dist/index.js";
test("open database", async () => {
const tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "test-open"));
const db = await connect(tmpDir);
let tableNames = await db.tableNames();
expect(tableNames).toStrictEqual([]);
const tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
expect(await db.tableNames()).toStrictEqual(["test"]);
const schema = await tbl.schema();
expect(schema).toEqual(new Schema([new Field("id", new Float64(), true)]));
});

View File

@@ -12,91 +12,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import * as fs from "fs"; import * as os from "os";
import * as path from "path"; import * as path from "path";
import * as tmp from "tmp"; import * as fs from "fs";
import { Table, connect } from "../dist"; import { connect } from "../dist";
import { import { Schema, Field, Float32, Int32, FixedSizeList, Int64, Float64 } from "apache-arrow";
Schema,
Field,
Float32,
Int32,
FixedSizeList,
Int64,
Float64,
} from "apache-arrow";
import { makeArrowTable } from "../dist/arrow"; import { makeArrowTable } from "../dist/arrow";
import { Index } from "../dist/indices";
describe("Given a table", () => { describe("Test creating index", () => {
let tmpDir: tmp.DirResult; let tmpDir: string;
let table: Table;
const schema = new Schema([new Field("id", new Float64(), true)]);
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name);
table = await conn.createEmptyTable("some_table", schema);
});
afterEach(() => tmpDir.removeCallback());
it("be displayable", async () => {
expect(table.display()).toMatch(
/NativeTable\(some_table, uri=.*, read_consistency_interval=None\)/,
);
table.close();
expect(table.display()).toBe("ClosedTable(some_table)");
});
it("should let me add data", async () => {
await table.add([{ id: 1 }, { id: 2 }]);
await table.add([{ id: 1 }]);
await expect(table.countRows()).resolves.toBe(3);
});
it("should overwrite data if asked", async () => {
await table.add([{ id: 1 }, { id: 2 }]);
await table.add([{ id: 1 }], { mode: "overwrite" });
await expect(table.countRows()).resolves.toBe(1);
});
it("should let me close the table", async () => {
expect(table.isOpen()).toBe(true);
table.close();
expect(table.isOpen()).toBe(false);
expect(table.countRows()).rejects.toThrow("Table some_table is closed");
});
it("should let me update values", async () => {
await table.add([{ id: 1 }]);
expect(await table.countRows("id == 1")).toBe(1);
expect(await table.countRows("id == 7")).toBe(0);
await table.update({ id: "7" });
expect(await table.countRows("id == 1")).toBe(0);
expect(await table.countRows("id == 7")).toBe(1);
await table.add([{ id: 2 }]);
// Test Map as input
await table.update(new Map(Object.entries({ id: "10" })), {
where: "id % 2 == 0",
});
expect(await table.countRows("id == 2")).toBe(0);
expect(await table.countRows("id == 7")).toBe(1);
expect(await table.countRows("id == 10")).toBe(1);
});
});
describe("When creating an index", () => {
let tmpDir: tmp.DirResult;
const schema = new Schema([ const schema = new Schema([
new Field("id", new Int32(), true), new Field("id", new Int32(), true),
new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))), new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))),
]); ]);
let tbl: Table;
let queryVec: number[];
beforeEach(async () => { beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true }); tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "index-"));
const db = await connect(tmpDir.name); });
test("create vector index with no column", async () => {
const db = await connect(tmpDir);
const data = makeArrowTable( const data = makeArrowTable(
Array(300) Array(300)
.fill(1) .fill(1)
@@ -108,90 +44,57 @@ describe("When creating an index", () => {
})), })),
{ {
schema, schema,
}, }
); );
queryVec = data.toArray()[5].vec.toJSON(); const tbl = await db.createTable("test", data);
tbl = await db.createTable("test", data); await tbl.createIndex().build();
});
afterEach(() => tmpDir.removeCallback());
it("should create a vector index on vector columns", async () => {
await tbl.createIndex("vec");
// check index directory // check index directory
const indexDir = path.join(tmpDir.name, "test.lance", "_indices"); const indexDir = path.join(tmpDir, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1); expect(fs.readdirSync(indexDir)).toHaveLength(1);
const indices = await tbl.listIndices(); // TODO: check index type.
expect(indices.length).toBe(1);
expect(indices[0]).toEqual({
indexType: "IvfPq",
columns: ["vec"],
});
// Search without specifying the column // Search without specifying the column
let rst = await tbl let query_vector = data.toArray()[5].vec.toJSON();
.query() let rst = await tbl.query().nearestTo(query_vector).limit(2).toArrow();
.limit(2)
.nearestTo(queryVec)
.distanceType("DoT")
.toArrow();
expect(rst.numRows).toBe(2);
// Search using `vectorSearch`
rst = await tbl.vectorSearch(queryVec).limit(2).toArrow();
expect(rst.numRows).toBe(2); expect(rst.numRows).toBe(2);
// Search with specifying the column // Search with specifying the column
const rst2 = await tbl let rst2 = await tbl.search(query_vector, "vec").limit(2).toArrow();
.query()
.limit(2)
.nearestTo(queryVec)
.column("vec")
.toArrow();
expect(rst2.numRows).toBe(2); expect(rst2.numRows).toBe(2);
expect(rst.toString()).toEqual(rst2.toString()); expect(rst.toString()).toEqual(rst2.toString());
}); });
it("should allow parameters to be specified", async () => { test("no vector column available", async () => {
await tbl.createIndex("vec", { const db = await connect(tmpDir);
config: Index.ivfPq({ const tbl = await db.createTable(
numPartitions: 10, "no_vec",
}), makeArrowTable([
}); { id: 1, val: 2 },
{ id: 2, val: 3 },
// TODO: Verify parameters when we can load index config as part of list indices ])
}); );
await expect(tbl.createIndex().build()).rejects.toThrow(
it("should allow me to replace (or not) an existing index", async () => { "No vector column found"
await tbl.createIndex("id");
// Default is replace=true
await tbl.createIndex("id");
await expect(tbl.createIndex("id", { replace: false })).rejects.toThrow(
"already exists",
); );
await tbl.createIndex("id", { replace: true });
});
test("should create a scalar index on scalar columns", async () => { await tbl.createIndex("val").build();
await tbl.createIndex("id"); const indexDir = path.join(tmpDir, "no_vec.lance", "_indices");
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1); expect(fs.readdirSync(indexDir)).toHaveLength(1);
for await (const r of tbl.query().where("id > 1").select(["id"])) { for await (const r of tbl.query().filter("id > 1").select(["id"])) {
expect(r.numRows).toBe(298); expect(r.numRows).toBe(1);
} }
}); });
// TODO: Move this test to the query API test (making sure we can reject queries
// when the dimension is incorrect)
test("two columns with different dimensions", async () => { test("two columns with different dimensions", async () => {
const db = await connect(tmpDir.name); const db = await connect(tmpDir);
const schema = new Schema([ const schema = new Schema([
new Field("id", new Int32(), true), new Field("id", new Int32(), true),
new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))), new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))),
new Field( new Field(
"vec2", "vec2",
new FixedSizeList(64, new Field("item", new Float32())), new FixedSizeList(64, new Field("item", new Float32()))
), ),
]); ]);
const tbl = await db.createTable( const tbl = await db.createTable(
@@ -208,71 +111,90 @@ describe("When creating an index", () => {
.fill(1) .fill(1)
.map(() => Math.random()), .map(() => Math.random()),
})), })),
{ schema }, { schema }
), )
); );
// Only build index over v1 // Only build index over v1
await tbl.createIndex("vec", { await expect(tbl.createIndex().build()).rejects.toThrow(
config: Index.ivfPq({ numPartitions: 2, numSubVectors: 2 }), /.*More than one vector columns found.*/
}); );
tbl
.createIndex("vec")
.ivf_pq({ num_partitions: 2, num_sub_vectors: 2 })
.build();
const rst = await tbl const rst = await tbl
.query() .query()
.limit(2)
.nearestTo( .nearestTo(
Array(32) Array(32)
.fill(1) .fill(1)
.map(() => Math.random()), .map(() => Math.random())
) )
.limit(2)
.toArrow(); .toArrow();
expect(rst.numRows).toBe(2); expect(rst.numRows).toBe(2);
// Search with specifying the column // Search with specifying the column
await expect( await expect(
tbl tbl
.query() .search(
.limit(2)
.nearestTo(
Array(64) Array(64)
.fill(1) .fill(1)
.map(() => Math.random()), .map(() => Math.random()),
"vec"
) )
.column("vec") .limit(2)
.toArrow(), .toArrow()
).rejects.toThrow(/.* query dim=64, expected vector dim=32.*/); ).rejects.toThrow(/.*does not match the dimension.*/);
const query64 = Array(64) const query64 = Array(64)
.fill(1) .fill(1)
.map(() => Math.random()); .map(() => Math.random());
const rst64Query = await tbl.query().limit(2).nearestTo(query64).toArrow(); const rst64_1 = await tbl.query().nearestTo(query64).limit(2).toArrow();
const rst64Search = await tbl const rst64_2 = await tbl.search(query64, "vec2").limit(2).toArrow();
.query() expect(rst64_1.toString()).toEqual(rst64_2.toString());
.limit(2) expect(rst64_1.numRows).toBe(2);
.nearestTo(query64) });
.column("vec2")
.toArrow(); test("create scalar index", async () => {
expect(rst64Query.toString()).toEqual(rst64Search.toString()); const db = await connect(tmpDir);
expect(rst64Query.numRows).toBe(2); const data = makeArrowTable(
Array(300)
.fill(1)
.map((_, i) => ({
id: i,
vec: Array(32)
.fill(1)
.map(() => Math.random()),
})),
{
schema,
}
);
const tbl = await db.createTable("test", data);
await tbl.createIndex("id").build();
// check index directory
const indexDir = path.join(tmpDir, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1);
// TODO: check index type.
}); });
}); });
describe("Read consistency interval", () => { describe("Read consistency interval", () => {
let tmpDir: tmp.DirResult; let tmpDir: string;
beforeEach(() => { beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true }); tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "read-consistency-"));
}); });
afterEach(() => tmpDir.removeCallback());
// const intervals = [undefined, 0, 0.1]; // const intervals = [undefined, 0, 0.1];
const intervals = [0]; const intervals = [0];
test.each(intervals)("read consistency interval %p", async (interval) => { test.each(intervals)("read consistency interval %p", async (interval) => {
const db = await connect(tmpDir.name); const db = await connect({ uri: tmpDir });
const table = await db.createTable("my_table", [{ id: 1 }]); const table = await db.createTable("my_table", [{ id: 1 }]);
const db2 = await connect(tmpDir.name, { const db2 = await connect({ uri: tmpDir, readConsistencyInterval: interval });
readConsistencyInterval: interval,
});
const table2 = await db2.openTable("my_table"); const table2 = await db2.openTable("my_table");
expect(await table2.countRows()).toEqual(await table.countRows()); expect(await table2.countRows()).toEqual(await table.countRows());
@@ -288,134 +210,73 @@ describe("Read consistency interval", () => {
} else { } else {
// interval == 0.1 // interval == 0.1
expect(await table2.countRows()).toEqual(1); expect(await table2.countRows()).toEqual(1);
await new Promise((r) => setTimeout(r, 100)); await new Promise(r => setTimeout(r, 100));
expect(await table2.countRows()).toEqual(2); expect(await table2.countRows()).toEqual(2);
} }
}); });
}); });
describe("schema evolution", function () {
let tmpDir: tmp.DirResult; describe('schema evolution', function () {
let tmpDir: string;
beforeEach(() => { beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true }); tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "schema-evolution-"));
});
afterEach(() => {
tmpDir.removeCallback();
}); });
// Create a new sample table // Create a new sample table
it("can add a new column to the schema", async function () { it('can add a new column to the schema', async function () {
const con = await connect(tmpDir.name); const con = await connect(tmpDir)
const table = await con.createTable("vectors", [ const table = await con.createTable('vectors', [
{ id: 1n, vector: [0.1, 0.2] }, { id: 1n, vector: [0.1, 0.2] }
]); ])
await table.addColumns([ await table.addColumns([{ name: 'price', valueSql: 'cast(10.0 as float)' }])
{ name: "price", valueSql: "cast(10.0 as float)" },
]);
const expectedSchema = new Schema([ const expectedSchema = new Schema([
new Field("id", new Int64(), true), new Field('id', new Int64(), true),
new Field( new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), true),
"vector", new Field('price', new Float32(), false)
new FixedSizeList(2, new Field("item", new Float32(), true)), ])
true, expect(await table.schema()).toEqual(expectedSchema)
),
new Field("price", new Float32(), false),
]);
expect(await table.schema()).toEqual(expectedSchema);
}); });
it("can alter the columns in the schema", async function () { it('can alter the columns in the schema', async function () {
const con = await connect(tmpDir.name); const con = await connect(tmpDir)
const schema = new Schema([ const schema = new Schema([
new Field("id", new Int64(), true), new Field('id', new Int64(), true),
new Field( new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), true),
"vector", new Field('price', new Float64(), false)
new FixedSizeList(2, new Field("item", new Float32(), true)), ])
true, const table = await con.createTable('vectors', [
), { id: 1n, vector: [0.1, 0.2] }
new Field("price", new Float64(), false), ])
]);
const table = await con.createTable("vectors", [
{ id: 1n, vector: [0.1, 0.2] },
]);
// Can create a non-nullable column only through addColumns at the moment. // Can create a non-nullable column only through addColumns at the moment.
await table.addColumns([ await table.addColumns([{ name: 'price', valueSql: 'cast(10.0 as double)' }])
{ name: "price", valueSql: "cast(10.0 as double)" }, expect(await table.schema()).toEqual(schema)
]);
expect(await table.schema()).toEqual(schema);
await table.alterColumns([ await table.alterColumns([
{ path: "id", rename: "new_id" }, { path: 'id', rename: 'new_id' },
{ path: "price", nullable: true }, { path: 'price', nullable: true }
]); ])
const expectedSchema = new Schema([ const expectedSchema = new Schema([
new Field("new_id", new Int64(), true), new Field('new_id', new Int64(), true),
new Field( new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), true),
"vector", new Field('price', new Float64(), true)
new FixedSizeList(2, new Field("item", new Float32(), true)), ])
true, expect(await table.schema()).toEqual(expectedSchema)
),
new Field("price", new Float64(), true),
]);
expect(await table.schema()).toEqual(expectedSchema);
}); });
it("can drop a column from the schema", async function () { it('can drop a column from the schema', async function () {
const con = await connect(tmpDir.name); const con = await connect(tmpDir)
const table = await con.createTable("vectors", [ const table = await con.createTable('vectors', [
{ id: 1n, vector: [0.1, 0.2] }, { id: 1n, vector: [0.1, 0.2] }
]); ])
await table.dropColumns(["vector"]); await table.dropColumns(['vector'])
const expectedSchema = new Schema([new Field("id", new Int64(), true)]); const expectedSchema = new Schema([
expect(await table.schema()).toEqual(expectedSchema); new Field('id', new Int64(), true)
}); ])
}); expect(await table.schema()).toEqual(expectedSchema)
describe("when dealing with versioning", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
});
it("can travel in time", async () => {
// Setup
const con = await connect(tmpDir.name);
const table = await con.createTable("vectors", [
{ id: 1n, vector: [0.1, 0.2] },
]);
const version = await table.version();
await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
expect(await table.countRows()).toBe(2);
// Make sure we can rewind
await table.checkout(version);
expect(await table.countRows()).toBe(1);
// Can't add data in time travel mode
await expect(table.add([{ id: 3n, vector: [0.1, 0.2] }])).rejects.toThrow(
"table cannot be modified when a specific version is checked out",
);
// Can go back to normal mode
await table.checkoutLatest();
expect(await table.countRows()).toBe(2);
// Should be able to add data again
await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
expect(await table.countRows()).toBe(3);
// Now checkout and restore
await table.checkout(version);
await table.restore();
expect(await table.countRows()).toBe(1);
// Should be able to add data
await table.add([{ id: 2n, vector: [0.1, 0.2] }]);
expect(await table.countRows()).toBe(2);
// Can't use restore if not checked out
await expect(table.restore()).rejects.toThrow(
"checkout before running restore",
);
}); });
}); });

View File

@@ -4,7 +4,12 @@
"outDir": "./dist/spec", "outDir": "./dist/spec",
"module": "commonjs", "module": "commonjs",
"target": "es2022", "target": "es2022",
"types": ["jest", "node"] "types": [
"jest",
"node"
]
}, },
"include": ["**/*"] "include": [
"**/*",
]
} }

View File

@@ -1,28 +0,0 @@
/* eslint-disable @typescript-eslint/naming-convention */
// @ts-check
const eslint = require("@eslint/js");
const tseslint = require("typescript-eslint");
const eslintConfigPrettier = require("eslint-config-prettier");
const jsdoc = require("eslint-plugin-jsdoc");
module.exports = tseslint.config(
eslint.configs.recommended,
jsdoc.configs["flat/recommended"],
eslintConfigPrettier,
...tseslint.configs.recommended,
{
rules: {
"@typescript-eslint/naming-convention": "error",
"jsdoc/require-returns": "off",
"jsdoc/require-param": "off",
"jsdoc/require-jsdoc": [
"error",
{
publicOnly: true,
},
],
},
plugins: jsdoc,
},
);

View File

@@ -1,7 +1,7 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */ /** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = { module.exports = {
preset: "ts-jest", preset: 'ts-jest',
testEnvironment: "node", testEnvironment: 'node',
moduleDirectories: ["node_modules", "./dist"], moduleDirectories: ["node_modules", "./dist"],
moduleFileExtensions: ["js", "ts"], moduleFileExtensions: ["js", "ts"],
}; };

View File

@@ -1,4 +1,4 @@
// Copyright 2023 Lance Developers. // Copyright 2024 Lance Developers.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@@ -13,35 +13,23 @@
// limitations under the License. // limitations under the License.
import { import {
Int64,
Field, Field,
makeBuilder,
RecordBatchFileWriter,
Utf8,
type Vector,
FixedSizeList, FixedSizeList,
vectorFromArray, Float,
type Schema,
Table as ArrowTable,
RecordBatchStreamWriter,
List,
RecordBatch,
makeData,
Struct,
type Float,
DataType,
Binary,
Float32, Float32,
type makeTable, Schema,
Table as ArrowTable,
Table,
Vector,
vectorFromArray,
tableToIPC,
DataType,
} from "apache-arrow"; } from "apache-arrow";
import { type EmbeddingFunction } from "./embedding/embedding_function";
import { sanitizeSchema } from "./sanitize";
/** Data type accepted by NodeJS SDK */ /** Data type accepted by NodeJS SDK */
export type Data = Record<string, unknown>[] | ArrowTable; export type Data = Record<string, unknown>[] | ArrowTable;
/*
* Options to control how a column should be converted to a vector array
*/
export class VectorColumnOptions { export class VectorColumnOptions {
/** Vector column type. */ /** Vector column type. */
type: Float = new Float32(); type: Float = new Float32();
@@ -53,50 +41,14 @@ export class VectorColumnOptions {
/** Options to control the makeArrowTable call. */ /** Options to control the makeArrowTable call. */
export class MakeArrowTableOptions { export class MakeArrowTableOptions {
/* /** Provided schema. */
* Schema of the data.
*
* If this is not provided then the data type will be inferred from the
* JS type. Integer numbers will become int64, floating point numbers
* will become float64 and arrays will become variable sized lists with
* the data type inferred from the first element in the array.
*
* The schema must be specified if there are no records (e.g. to make
* an empty table)
*/
schema?: Schema; schema?: Schema;
/* /** Vector columns */
* Mapping from vector column name to expected type
*
* Lance expects vector columns to be fixed size list arrays (i.e. tensors)
* However, `makeArrowTable` will not infer this by default (it creates
* variable size list arrays). This field can be used to indicate that a column
* should be treated as a vector column and converted to a fixed size list.
*
* The keys should be the names of the vector columns. The value specifies the
* expected data type of the vector columns.
*
* If `schema` is provided then this field is ignored.
*
* By default, the column named "vector" will be assumed to be a float32
* vector column.
*/
vectorColumns: Record<string, VectorColumnOptions> = { vectorColumns: Record<string, VectorColumnOptions> = {
vector: new VectorColumnOptions(), vector: new VectorColumnOptions(),
}; };
/**
* If true then string columns will be encoded with dictionary encoding
*
* Set this to true if your string columns tend to repeat the same values
* often. For more precise control use the `schema` property to specify the
* data type for individual columns.
*
* If `schema` is provided then this property is ignored.
*/
dictionaryEncodeStrings: boolean = false;
constructor(values?: Partial<MakeArrowTableOptions>) { constructor(values?: Partial<MakeArrowTableOptions>) {
Object.assign(this, values); Object.assign(this, values);
} }
@@ -106,30 +58,15 @@ export class MakeArrowTableOptions {
* An enhanced version of the {@link makeTable} function from Apache Arrow * An enhanced version of the {@link makeTable} function from Apache Arrow
* that supports nested fields and embeddings columns. * that supports nested fields and embeddings columns.
* *
* This function converts an array of Record<String, any> (row-major JS objects)
* to an Arrow Table (a columnar structure)
*
* Note that it currently does not support nulls. * Note that it currently does not support nulls.
* *
* If a schema is provided then it will be used to determine the resulting array * @param data input data
* types. Fields will also be reordered to fit the order defined by the schema. * @param options options to control the makeArrowTable call.
* *
* If a schema is not provided then the types will be inferred and the field order
* will be controlled by the order of properties in the first record. If a type
* is inferred it will always be nullable.
*
* If the input is empty then a schema must be provided to create an empty table.
*
* When a schema is not specified then data types will be inferred. The inference
* rules are as follows:
*
* - boolean => Bool
* - number => Float64
* - String => Utf8
* - Buffer => Binary
* - Record<String, any> => Struct
* - Array<any> => List
* @example * @example
*
* ```ts
*
* import { fromTableToBuffer, makeArrowTable } from "../arrow"; * import { fromTableToBuffer, makeArrowTable } from "../arrow";
* import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow"; * import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow";
* *
@@ -145,10 +82,8 @@ export class MakeArrowTableOptions {
* ], { schema }); * ], { schema });
* ``` * ```
* *
* By default it assumes that the column named `vector` is a vector column * It guesses the vector columns if the schema is not provided. For example,
* and it will be converted into a fixed size list array of type float32. * by default it assumes that the column named `vector` is a vector column.
* The `vectorColumns` option can be used to support other vector column
* names and data types.
* *
* ```ts * ```ts
* *
@@ -192,456 +127,62 @@ export class MakeArrowTableOptions {
* ``` * ```
*/ */
export function makeArrowTable( export function makeArrowTable(
data: Array<Record<string, unknown>>, data: Record<string, any>[],
options?: Partial<MakeArrowTableOptions>, options?: Partial<MakeArrowTableOptions>
): ArrowTable { ): Table {
if ( if (data.length === 0) {
data.length === 0 && throw new Error("At least one record needs to be provided");
(options?.schema === undefined || options?.schema === null)
) {
throw new Error("At least one record or a schema needs to be provided");
}
const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema);
} }
const opt = new MakeArrowTableOptions(options ?? {});
const columns: Record<string, Vector> = {}; const columns: Record<string, Vector> = {};
// TODO: sample dataset to find missing columns // TODO: sample dataset to find missing columns
// Prefer the field ordering of the schema, if present const columnNames = Object.keys(data[0]);
const columnNames =
opt.schema != null ? (opt.schema.names as string[]) : Object.keys(data[0]);
for (const colName of columnNames) { for (const colName of columnNames) {
if ( // eslint-disable-next-line @typescript-eslint/no-unsafe-return
data.length !== 0 &&
!Object.prototype.hasOwnProperty.call(data[0], colName)
) {
// The field is present in the schema, but not in the data, skip it
continue;
}
// Extract a single column from the records (transpose from row-major to col-major)
let values = data.map((datum) => datum[colName]); let values = data.map((datum) => datum[colName]);
let vector: Vector;
// By default (type === undefined) arrow will infer the type from the JS type
let type;
if (opt.schema !== undefined) { if (opt.schema !== undefined) {
// If there is a schema provided, then use that for the type instead // Explicit schema is provided, highest priority
type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type; const fieldType: DataType | undefined = opt.schema.fields.filter((f) => f.name === colName)[0]?.type as DataType;
if (DataType.isInt(type) && type.bitWidth === 64) { if (fieldType instanceof Int64) {
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051 // wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
values = values.map((v) => { // eslint-disable-next-line @typescript-eslint/no-unsafe-argument
if (v === null) { values = values.map((v) => BigInt(v));
return v;
}
if (typeof v === "bigint") {
return v;
}
if (typeof v === "number") {
return BigInt(v);
}
throw new Error(
`Expected BigInt or number for column ${colName}, got ${typeof v}`,
);
});
} }
vector = vectorFromArray(values, fieldType);
} else { } else {
// Otherwise, check to see if this column is one of the vector columns
// defined by opt.vectorColumns and, if so, use the fixed size list type
const vectorColumnOptions = opt.vectorColumns[colName]; const vectorColumnOptions = opt.vectorColumns[colName];
if (vectorColumnOptions !== undefined) { if (vectorColumnOptions !== undefined) {
const firstNonNullValue = values.find((v) => v !== null); const fslType = new FixedSizeList(
if (Array.isArray(firstNonNullValue)) { (values[0] as any[]).length,
type = newVectorType( new Field("item", vectorColumnOptions.type, false)
firstNonNullValue.length,
vectorColumnOptions.type,
); );
vector = vectorFromArray(values, fslType);
} else { } else {
throw new Error( // Normal case
`Column ${colName} is expected to be a vector column but first non-null value is not an array. Could not determine size of vector column`, vector = vectorFromArray(values);
);
} }
} }
columns[colName] = vector;
} }
try { return new Table(columns);
// Convert an Array of JS values to an arrow vector }
columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings);
} catch (error: unknown) {
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw Error(`Could not convert column "${colName}" to Arrow: ${error}`);
}
}
if (opt.schema != null) { /**
// `new ArrowTable(columns)` infers a schema which may sometimes have * Convert an Arrow Table to a Buffer.
// incorrect nullability (it assumes nullable=true always) *
// * @param data Arrow Table
// `new ArrowTable(schema, columns)` will also fail because it will create a * @param schema Arrow Schema, optional
// batch with an inferred schema and then complain that the batch schema * @returns Buffer node
// does not match the provided schema. */
// export function toBuffer(data: Data, schema?: Schema): Buffer {
// To work around this we first create a table with the wrong schema and let tbl: Table;
// then patch the schema of the batches so we can use if (data instanceof Table) {
// `new ArrowTable(schema, batches)` which does not do any schema inference tbl = data;
const firstTable = new ArrowTable(columns);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const batchesFixed = firstTable.batches.map(
(batch) => new RecordBatch(opt.schema!, batch.data),
);
return new ArrowTable(opt.schema, batchesFixed);
} else { } else {
return new ArrowTable(columns); tbl = makeArrowTable(data, { schema });
} }
} return Buffer.from(tableToIPC(tbl));
/**
* Create an empty Arrow table with the provided schema
*/
export function makeEmptyTable(schema: Schema): ArrowTable {
return makeArrowTable([], { schema });
}
/**
* Helper function to convert Array<Array<any>> to a variable sized list array
*/
// @ts-expect-error (Vector<unknown> is not assignable to Vector<any>)
function makeListVector(lists: unknown[][]): Vector<unknown> {
if (lists.length === 0 || lists[0].length === 0) {
throw Error("Cannot infer list vector from empty array or empty list");
}
const sampleList = lists[0];
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let inferredType: any;
try {
const sampleVector = makeVector(sampleList);
inferredType = sampleVector.type;
} catch (error: unknown) {
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`);
}
const listBuilder = makeBuilder({
type: new List(new Field("item", inferredType, true)),
});
for (const list of lists) {
listBuilder.append(list);
}
return listBuilder.finish().toVector();
}
/** Helper function to convert an Array of JS values to an Arrow Vector */
function makeVector(
values: unknown[],
type?: DataType,
stringAsDictionary?: boolean,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): Vector<any> {
if (type !== undefined) {
// No need for inference, let Arrow create it
return vectorFromArray(values, type);
}
if (values.length === 0) {
throw Error(
"makeVector requires at least one value or the type must be specfied",
);
}
const sampleValue = values.find((val) => val !== null && val !== undefined);
if (sampleValue === undefined) {
throw Error(
"makeVector cannot infer the type if all values are null or undefined",
);
}
if (Array.isArray(sampleValue)) {
// Default Arrow inference doesn't handle list types
return makeListVector(values as unknown[][]);
} else if (Buffer.isBuffer(sampleValue)) {
// Default Arrow inference doesn't handle Buffer
return vectorFromArray(values, new Binary());
} else if (
!(stringAsDictionary ?? false) &&
(typeof sampleValue === "string" || sampleValue instanceof String)
) {
// If the type is string then don't use Arrow's default inference unless dictionaries are requested
// because it will always use dictionary encoding for strings
return vectorFromArray(values, new Utf8());
} else {
// Convert a JS array of values to an arrow vector
return vectorFromArray(values);
}
}
/** Helper function to apply embeddings to an input table */
async function applyEmbeddings<T>(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<ArrowTable> {
if (embeddings == null) {
return table;
}
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}
// Convert from ArrowTable to Record<String, Vector>
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
const name = table.schema.fields[idx].name;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const vec = table.getChildAt(idx)!;
return [name, vec];
});
const newColumns = Object.fromEntries(colEntries);
const sourceColumn = newColumns[embeddings.sourceColumn];
const destColumn = embeddings.destColumn ?? "vector";
const innerDestType = embeddings.embeddingDataType ?? new Float32();
if (sourceColumn === undefined) {
throw new Error(
`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`,
);
}
if (table.numRows === 0) {
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
// We have an empty table and it already has the embedding column so no work needs to be done
// Note: we don't return an error like we did below because this is a common occurrence. For example,
// if we call convertToTable with 0 records and a schema that includes the embedding
return table;
}
if (embeddings.embeddingDimension !== undefined) {
const destType = newVectorType(
embeddings.embeddingDimension,
innerDestType,
);
newColumns[destColumn] = makeVector([], destType);
} else if (schema != null) {
const destField = schema.fields.find((f) => f.name === destColumn);
if (destField != null) {
newColumns[destColumn] = makeVector([], destField.type);
} else {
throw new Error(
`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`,
);
}
} else {
throw new Error(
"Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`",
);
}
} else {
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
throw new Error(
`Attempt to apply embeddings to table failed because column ${destColumn} already existed`,
);
}
if (table.batches.length > 1) {
throw new Error(
"Internal error: `makeArrowTable` unexpectedly created a table with more than one batch",
);
}
const values = sourceColumn.toArray();
const vectors = await embeddings.embed(values as T[]);
if (vectors.length !== values.length) {
throw new Error(
"Embedding function did not return an embedding for each input element",
);
}
const destType = newVectorType(vectors[0].length, innerDestType);
newColumns[destColumn] = makeVector(vectors, destType);
}
const newTable = new ArrowTable(newColumns);
if (schema != null) {
if (schema.fields.find((f) => f.name === destColumn) === undefined) {
throw new Error(
`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`,
);
}
return alignTable(newTable, schema);
}
return newTable;
}
/**
* Convert an Array of records into an Arrow Table, optionally applying an
* embeddings function to it.
*
* This function calls `makeArrowTable` first to create the Arrow Table.
* Any provided `makeTableOptions` (e.g. a schema) will be passed on to
* that call.
*
* The embedding function will be passed a column of values (based on the
* `sourceColumn` of the embedding function) and expects to receive back
* number[][] which will be converted into a fixed size list column. By
* default this will be a fixed size list of Float32 but that can be
* customized by the `embeddingDataType` property of the embedding function.
*
* If a schema is provided in `makeTableOptions` then it should include the
* embedding columns. If no schema is provded then embedding columns will
* be placed at the end of the table, after all of the input columns.
*/
export async function convertToTable<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
makeTableOptions?: Partial<MakeArrowTableOptions>,
): Promise<ArrowTable> {
const table = makeArrowTable(data, makeTableOptions);
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema);
}
/** Creates the Arrow Type for a Vector column with dimension `dim` */
function newVectorType<T extends Float>(
dim: number,
innerType: T,
): FixedSizeList<T> {
// in Lance we always default to have the elements nullable, so we need to set it to true
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
const children = new Field<T>("item", innerType, true);
return new FixedSizeList(dim, children);
}
/**
* Serialize an Array of records into a buffer using the Arrow IPC File serialization
*
* This function will call `convertToTable` and pass on `embeddings` and `schema`
*
* `schema` is required if data is empty
*/
export async function fromRecordsToBuffer<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}
const table = await convertToTable(data, embeddings, { schema });
const writer = RecordBatchFileWriter.writeAll(table);
return Buffer.from(await writer.toUint8Array());
}
/**
* Serialize an Array of records into a buffer using the Arrow IPC Stream serialization
*
* This function will call `convertToTable` and pass on `embeddings` and `schema`
*
* `schema` is required if data is empty
*/
export async function fromRecordsToStreamBuffer<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}
const table = await convertToTable(data, embeddings, { schema });
const writer = RecordBatchStreamWriter.writeAll(table);
return Buffer.from(await writer.toUint8Array());
}
/**
* Serialize an Arrow Table into a buffer using the Arrow IPC File serialization
*
* This function will apply `embeddings` to the table in a manner similar to
* `convertToTable`.
*
* `schema` is required if the table is empty
*/
export async function fromTableToBuffer<T>(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings);
return Buffer.from(await writer.toUint8Array());
}
/**
* Serialize an Arrow Table into a buffer using the Arrow IPC File serialization
*
* This function will apply `embeddings` to the table in a manner similar to
* `convertToTable`.
*
* `schema` is required if the table is empty
*/
export async function fromDataToBuffer<T>(
data: Data,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}
if (data instanceof ArrowTable) {
return fromTableToBuffer(data, embeddings, schema);
} else {
const table = await convertToTable(data);
return fromTableToBuffer(table, embeddings, schema);
}
}
/**
* Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
*
* This function will apply `embeddings` to the table in a manner similar to
* `convertToTable`.
*
* `schema` is required if the table is empty
*/
export async function fromTableToStreamBuffer<T>(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema,
): Promise<Buffer> {
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings);
return Buffer.from(await writer.toUint8Array());
}
/**
* Reorder the columns in `batch` so that they agree with the field order in `schema`
*/
function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch {
const alignedChildren = [];
for (const field of schema.fields) {
const indexInBatch = batch.schema.fields?.findIndex(
(f) => f.name === field.name,
);
if (indexInBatch < 0) {
throw new Error(
`The column ${field.name} was not found in the Arrow Table`,
);
}
alignedChildren.push(batch.data.children[indexInBatch]);
}
const newData = makeData({
type: new Struct(schema.fields),
length: batch.numRows,
nullCount: batch.nullCount,
children: alignedChildren,
});
return new RecordBatch(schema, newData);
}
/**
* Reorder the columns in `table` so that they agree with the field order in `schema`
*/
function alignTable(table: ArrowTable, schema: Schema): ArrowTable {
const alignedBatches = table.batches.map((batch) =>
alignBatch(batch, schema),
);
return new ArrowTable(schema, alignedBatches);
}
/**
* Create an empty table with the given schema
*/
export function createEmptyTable(schema: Schema): ArrowTable {
return new ArrowTable(sanitizeSchema(schema));
} }

View File

@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow"; import { toBuffer } from "./arrow";
import { Connection as LanceDbConnection } from "./native"; import { Connection as _NativeConnection } from "./native";
import { Table } from "./table"; import { Table } from "./table";
import { Table as ArrowTable, Schema } from "apache-arrow"; import { Table as ArrowTable } from "apache-arrow";
export interface CreateTableOptions { export interface CreateTableOptions {
/** /**
@@ -35,79 +35,28 @@ export interface CreateTableOptions {
existOk: boolean; existOk: boolean;
} }
export interface TableNamesOptions {
/**
* If present, only return names that come lexicographically after the
* supplied value.
*
* This can be combined with limit to implement pagination by setting this to
* the last table name from the previous page.
*/
startAfter?: string;
/** An optional limit to the number of results to return. */
limit?: number;
}
/** /**
* A LanceDB Connection that allows you to open tables and create new ones. * A LanceDB Connection that allows you to open tables and create new ones.
* *
* Connection could be local against filesystem or remote against a server. * Connection could be local against filesystem or remote against a server.
*
* A Connection is intended to be a long lived object and may hold open
* resources such as HTTP connection pools. This is generally fine and
* a single connection should be shared if it is going to be used many
* times. However, if you are finished with a connection, you may call
* close to eagerly free these resources. Any call to a Connection
* method after it has been closed will result in an error.
*
* Closing a connection is optional. Connections will automatically
* be closed when they are garbage collected.
*
* Any created tables are independent and will continue to work even if
* the underlying connection has been closed.
*/ */
export class Connection { export class Connection {
readonly inner: LanceDbConnection; readonly inner: _NativeConnection;
constructor(inner: LanceDbConnection) { constructor(inner: _NativeConnection) {
this.inner = inner; this.inner = inner;
} }
/** Return true if the connection has not been closed */ /** List all the table names in this database. */
isOpen(): boolean { async tableNames(): Promise<string[]> {
return this.inner.isOpen(); return this.inner.tableNames();
}
/**
* Close the connection, releasing any underlying resources.
*
* It is safe to call this method multiple times.
*
* Any attempt to use the connection after it is closed will result in an error.
*/
close(): void {
this.inner.close();
}
/** Return a brief description of the connection */
display(): string {
return this.inner.display();
}
/**
* List all the table names in this database.
*
* Tables will be returned in lexicographical order.
* @param {Partial<TableNamesOptions>} options - options to control the
* paging / start point
*/
async tableNames(options?: Partial<TableNamesOptions>): Promise<string[]> {
return this.inner.tableNames(options?.startAfter, options?.limit);
} }
/** /**
* Open a table in the database. * Open a table in the database.
* @param {string} name - The name of the table *
* @param name The name of the table.
* @param embeddings An embedding function to use on this table
*/ */
async openTable(name: string): Promise<Table> { async openTable(name: string): Promise<Table> {
const innerTable = await this.inner.openTable(name); const innerTable = await this.inner.openTable(name);
@@ -116,14 +65,14 @@ export class Connection {
/** /**
* Creates a new Table and initialize it with new data. * Creates a new Table and initialize it with new data.
*
* @param {string} name - The name of the table. * @param {string} name - The name of the table.
* @param {Record<string, unknown>[] | ArrowTable} data - Non-empty Array of Records * @param data - Non-empty Array of Records to be inserted into the table
* to be inserted into the table
*/ */
async createTable( async createTable(
name: string, name: string,
data: Record<string, unknown>[] | ArrowTable, data: Record<string, unknown>[] | ArrowTable,
options?: Partial<CreateTableOptions>, options?: Partial<CreateTableOptions>
): Promise<Table> { ): Promise<Table> {
let mode: string = options?.mode ?? "create"; let mode: string = options?.mode ?? "create";
const existOk = options?.existOk ?? false; const existOk = options?.existOk ?? false;
@@ -132,43 +81,14 @@ export class Connection {
mode = "exist_ok"; mode = "exist_ok";
} }
let table: ArrowTable; const buf = toBuffer(data);
if (data instanceof ArrowTable) {
table = data;
} else {
table = makeArrowTable(data);
}
const buf = await fromTableToBuffer(table);
const innerTable = await this.inner.createTable(name, buf, mode); const innerTable = await this.inner.createTable(name, buf, mode);
return new Table(innerTable); return new Table(innerTable);
} }
/**
* Creates a new empty Table
* @param {string} name - The name of the table.
* @param {Schema} schema - The schema of the table
*/
async createEmptyTable(
name: string,
schema: Schema,
options?: Partial<CreateTableOptions>,
): Promise<Table> {
let mode: string = options?.mode ?? "create";
const existOk = options?.existOk ?? false;
if (mode === "create" && existOk) {
mode = "exist_ok";
}
const table = makeEmptyTable(schema);
const buf = await fromTableToBuffer(table);
const innerTable = await this.inner.createEmptyTable(name, buf, mode);
return new Table(innerTable);
}
/** /**
* Drop an existing table. * Drop an existing table.
* @param {string} name The name of the table to drop. * @param name The name of the table to drop.
*/ */
async dropTable(name: string): Promise<void> { async dropTable(name: string): Promise<void> {
return this.inner.dropTable(name); return this.inner.dropTable(name);

View File

@@ -1,78 +0,0 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { type Float } from "apache-arrow";
/**
* An embedding function that automatically creates vector representation for a given column.
*/
export interface EmbeddingFunction<T> {
/**
* The name of the column that will be used as input for the Embedding Function.
*/
sourceColumn: string;
/**
* The data type of the embedding
*
* The embedding function should return `number`. This will be converted into
* an Arrow float array. By default this will be Float32 but this property can
* be used to control the conversion.
*/
embeddingDataType?: Float;
/**
* The dimension of the embedding
*
* This is optional, normally this can be determined by looking at the results of
* `embed`. If this is not specified, and there is an attempt to apply the embedding
* to an empty table, then that process will fail.
*/
embeddingDimension?: number;
/**
* The name of the column that will contain the embedding
*
* By default this is "vector"
*/
destColumn?: string;
/**
* Should the source column be excluded from the resulting table
*
* By default the source column is included. Set this to true and
* only the embedding will be stored.
*/
excludeSource?: boolean;
/**
* Creates a vector representation for the given values.
*/
embed: (data: T[]) => Promise<number[][]>;
}
/** Test if the input seems to be an embedding function */
export function isEmbeddingFunction<T>(
value: unknown,
): value is EmbeddingFunction<T> {
if (typeof value !== "object" || value === null) {
return false;
}
if (!("sourceColumn" in value) || !("embed" in value)) {
return false;
}
return (
typeof value.sourceColumn === "string" && typeof value.embed === "function"
);
}

View File

@@ -1,62 +0,0 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { type EmbeddingFunction } from "./embedding_function";
import type OpenAI from "openai";
export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
private readonly _openai: OpenAI;
private readonly _modelName: string;
constructor(
sourceColumn: string,
openAIKey: string,
modelName: string = "text-embedding-ada-002",
) {
/**
* @type {import("openai").default}
*/
// eslint-disable-next-line @typescript-eslint/naming-convention
let Openai;
try {
// eslint-disable-next-line @typescript-eslint/no-var-requires
Openai = require("openai");
} catch {
throw new Error("please install openai@^4.24.1 using npm install openai");
}
this.sourceColumn = sourceColumn;
const configuration = {
apiKey: openAIKey,
};
this._openai = new Openai(configuration);
this._modelName = modelName;
}
async embed(data: string[]): Promise<number[][]> {
const response = await this._openai.embeddings.create({
model: this._modelName,
input: data,
});
const embeddings: number[][] = [];
for (let i = 0; i < response.data.length; i++) {
embeddings.push(response.data[i].embedding);
}
return embeddings;
}
sourceColumn: string;
}

View File

@@ -13,14 +13,18 @@
// limitations under the License. // limitations under the License.
import { Connection } from "./connection"; import { Connection } from "./connection";
import { import { Connection as NativeConnection, ConnectionOptions } from "./native.js";
Connection as LanceDbConnection,
ConnectionOptions,
} from "./native.js";
export { ConnectionOptions, WriteOptions, Query } from "./native.js"; export {
export { Connection, CreateTableOptions } from "./connection"; ConnectionOptions,
export { Table, AddDataOptions } from "./table"; WriteOptions,
Query,
MetricType,
} from "./native.js";
export { Connection } from "./connection";
export { Table } from "./table";
export { Data } from "./arrow";
export { IvfPQOptions, IndexBuilder } from "./indexer";
/** /**
* Connect to a LanceDB instance at the given URI. * Connect to a LanceDB instance at the given URI.
@@ -30,15 +34,31 @@ export { Table, AddDataOptions } from "./table";
* - `/path/to/database` - local database * - `/path/to/database` - local database
* - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage * - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
* - `db://host:port` - remote database (LanceDB cloud) * - `db://host:port` - remote database (LanceDB cloud)
* @param {string} uri - The uri of the database. If the database uri starts *
* with `db://` then it connects to a remote database. * @param uri The uri of the database. If the database uri starts with `db://` then it connects to a remote database.
*
* @see {@link ConnectionOptions} for more details on the URI format. * @see {@link ConnectionOptions} for more details on the URI format.
*/ */
export async function connect(uri: string): Promise<Connection>;
export async function connect( export async function connect(
uri: string, opts: Partial<ConnectionOptions>
opts?: Partial<ConnectionOptions>, ): Promise<Connection>;
export async function connect(
args: string | Partial<ConnectionOptions>
): Promise<Connection> { ): Promise<Connection> {
opts = opts ?? {}; let opts: ConnectionOptions;
const nativeConn = await LanceDbConnection.new(uri, opts); if (typeof args === "string") {
opts = { uri: args };
} else {
opts = Object.assign(
{
uri: "",
apiKey: undefined,
hostOverride: undefined,
},
args
);
}
const nativeConn = await NativeConnection.new(opts);
return new Connection(nativeConn); return new Connection(nativeConn);
} }

102
nodejs/lancedb/indexer.ts Normal file
View File

@@ -0,0 +1,102 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import {
MetricType,
IndexBuilder as NativeBuilder,
Table as NativeTable,
} from "./native";
/** Options to create `IVF_PQ` index */
export interface IvfPQOptions {
/** Number of IVF partitions. */
num_partitions?: number;
/** Number of sub-vectors in PQ coding. */
num_sub_vectors?: number;
/** Number of bits used for each PQ code.
*/
num_bits?: number;
/** Metric type to calculate the distance between vectors.
*
* Supported metrics: `L2`, `Cosine` and `Dot`.
*/
metric_type?: MetricType;
/** Number of iterations to train K-means.
*
* Default is 50. The more iterations it usually yield better results,
* but it takes longer to train.
*/
max_iterations?: number;
sample_rate?: number;
}
/**
* Building an index on LanceDB {@link Table}
*
* @see {@link Table.createIndex} for detailed usage.
*/
export class IndexBuilder {
private inner: NativeBuilder;
constructor(tbl: NativeTable) {
this.inner = tbl.createIndex();
}
/** Instruct the builder to build an `IVF_PQ` index */
ivf_pq(options?: IvfPQOptions): IndexBuilder {
this.inner.ivfPq(
options?.metric_type,
options?.num_partitions,
options?.num_sub_vectors,
options?.num_bits,
options?.max_iterations,
options?.sample_rate
);
return this;
}
/** Instruct the builder to build a Scalar index. */
scalar(): IndexBuilder {
this.scalar();
return this;
}
/** Set the column(s) to create index on top of. */
column(col: string): IndexBuilder {
this.inner.column(col);
return this;
}
/** Set to true to replace existing index. */
replace(val: boolean): IndexBuilder {
this.inner.replace(val);
return this;
}
/** Specify the name of the index. Optional */
name(n: string): IndexBuilder {
this.inner.name(n);
return this;
}
/** Building the index. */
async build() {
await this.inner.build();
}
}

View File

@@ -1,203 +0,0 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { Index as LanceDbIndex } from "./native";
/**
* Options to create an `IVF_PQ` index
*/
export interface IvfPqOptions {
/**
* The number of IVF partitions to create.
*
* This value should generally scale with the number of rows in the dataset.
* By default the number of partitions is the square root of the number of
* rows.
*
* If this value is too large then the first part of the search (picking the
* right partition) will be slow. If this value is too small then the second
* part of the search (searching within a partition) will be slow.
*/
numPartitions?: number;
/**
* Number of sub-vectors of PQ.
*
* This value controls how much the vector is compressed during the quantization step.
* The more sub vectors there are the less the vector is compressed. The default is
* the dimension of the vector divided by 16. If the dimension is not evenly divisible
* by 16 we use the dimension divded by 8.
*
* The above two cases are highly preferred. Having 8 or 16 values per subvector allows
* us to use efficient SIMD instructions.
*
* If the dimension is not visible by 8 then we use 1 subvector. This is not ideal and
* will likely result in poor performance.
*/
numSubVectors?: number;
/**
* Distance type to use to build the index.
*
* Default value is "l2".
*
* This is used when training the index to calculate the IVF partitions
* (vectors are grouped in partitions with similar vectors according to this
* distance type) and to calculate a subvector's code during quantization.
*
* The distance type used to train an index MUST match the distance type used
* to search the index. Failure to do so will yield inaccurate results.
*
* The following distance types are available:
*
* "l2" - Euclidean distance. This is a very common distance metric that
* accounts for both magnitude and direction when determining the distance
* between vectors. L2 distance has a range of [0, ∞).
*
* "cosine" - Cosine distance. Cosine distance is a distance metric
* calculated from the cosine similarity between two vectors. Cosine
* similarity is a measure of similarity between two non-zero vectors of an
* inner product space. It is defined to equal the cosine of the angle
* between them. Unlike L2, the cosine distance is not affected by the
* magnitude of the vectors. Cosine distance has a range of [0, 2].
*
* Note: the cosine distance is undefined when one (or both) of the vectors
* are all zeros (there is no direction). These vectors are invalid and may
* never be returned from a vector search.
*
* "dot" - Dot product. Dot distance is the dot product of two vectors. Dot
* distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
* L2 norm is 1), then dot distance is equivalent to the cosine distance.
*/
distanceType?: "l2" | "cosine" | "dot";
/**
* Max iteration to train IVF kmeans.
*
* When training an IVF PQ index we use kmeans to calculate the partitions. This parameter
* controls how many iterations of kmeans to run.
*
* Increasing this might improve the quality of the index but in most cases these extra
* iterations have diminishing returns.
*
* The default value is 50.
*/
maxIterations?: number;
/**
* The number of vectors, per partition, to sample when training IVF kmeans.
*
* When an IVF PQ index is trained, we need to calculate partitions. These are groups
* of vectors that are similar to each other. To do this we use an algorithm called kmeans.
*
* Running kmeans on a large dataset can be slow. To speed this up we run kmeans on a
* random sample of the data. This parameter controls the size of the sample. The total
* number of vectors used to train the index is `sample_rate * num_partitions`.
*
* Increasing this value might improve the quality of the index but in most cases the
* default should be sufficient.
*
* The default value is 256.
*/
sampleRate?: number;
}
export class Index {
private readonly inner: LanceDbIndex;
private constructor(inner: LanceDbIndex) {
this.inner = inner;
}
/**
* Create an IvfPq index
*
* This index stores a compressed (quantized) copy of every vector. These vectors
* are grouped into partitions of similar vectors. Each partition keeps track of
* a centroid which is the average value of all vectors in the group.
*
* During a query the centroids are compared with the query vector to find the closest
* partitions. The compressed vectors in these partitions are then searched to find
* the closest vectors.
*
* The compression scheme is called product quantization. Each vector is divided into
* subvectors and then each subvector is quantized into a small number of bits. the
* parameters `num_bits` and `num_subvectors` control this process, providing a tradeoff
* between index size (and thus search speed) and index accuracy.
*
* The partitioning process is called IVF and the `num_partitions` parameter controls how
* many groups to create.
*
* Note that training an IVF PQ index on a large dataset is a slow operation and
* currently is also a memory intensive operation.
*/
static ivfPq(options?: Partial<IvfPqOptions>) {
return new Index(
LanceDbIndex.ivfPq(
options?.distanceType,
options?.numPartitions,
options?.numSubVectors,
options?.maxIterations,
options?.sampleRate,
),
);
}
/**
* Create a btree index
*
* A btree index is an index on a scalar columns. The index stores a copy of the column
* in sorted order. A header entry is created for each block of rows (currently the
* block size is fixed at 4096). These header entries are stored in a separate
* cacheable structure (a btree). To search for data the header is used to determine
* which blocks need to be read from disk.
*
* For example, a btree index in a table with 1Bi rows requires sizeof(Scalar) * 256Ki
* bytes of memory and will generally need to read sizeof(Scalar) * 4096 bytes to find
* the correct row ids.
*
* This index is good for scalar columns with mostly distinct values and does best when
* the query is highly selective.
*
* The btree index does not currently have any parameters though parameters such as the
* block size may be added in the future.
*/
static btree() {
return new Index(LanceDbIndex.btree());
}
}
export interface IndexOptions {
/**
* Advanced index configuration
*
* This option allows you to specify a specfic index to create and also
* allows you to pass in configuration for training the index.
*
* See the static methods on Index for details on the various index types.
*
* If this is not supplied then column data type(s) and column statistics
* will be used to determine the most useful kind of index to create.
*/
config?: Index;
/**
* Whether to replace the existing index
*
* If this is false, and another index already exists on the same columns
* and the same name, then an error will be returned. This is true even if
* that index is out of date.
*
* The default is true
*/
replace?: boolean;
}

View File

@@ -3,17 +3,14 @@
/* auto-generated by NAPI-RS */ /* auto-generated by NAPI-RS */
/** A description of an index currently configured on a column */ export const enum IndexType {
export interface IndexConfig { Scalar = 0,
/** The type of the index */ IvfPq = 1
indexType: string }
/** export const enum MetricType {
* The columns in the index L2 = 0,
* Cosine = 1,
* Currently this is always an array of size 1. In the future there may Dot = 2
* be more columns to represent composite indices.
*/
columns: Array<string>
} }
/** /**
* A definition of a column alteration. The alteration changes the column at * A definition of a column alteration. The alteration changes the column at
@@ -48,6 +45,7 @@ export interface AddColumnsSql {
valueSql: string valueSql: string
} }
export interface ConnectionOptions { export interface ConnectionOptions {
uri: string
apiKey?: string apiKey?: string
hostOverride?: string hostOverride?: string
/** /**
@@ -73,15 +71,12 @@ export const enum WriteMode {
export interface WriteOptions { export interface WriteOptions {
mode?: WriteMode mode?: WriteMode
} }
export function connect(uri: string, options: ConnectionOptions): Promise<Connection> export function connect(options: ConnectionOptions): Promise<Connection>
export class Connection { export class Connection {
/** Create a new Connection instance from the given URI. */ /** Create a new Connection instance from the given URI. */
static new(uri: string, options: ConnectionOptions): Promise<Connection> static new(options: ConnectionOptions): Promise<Connection>
display(): string
isOpen(): boolean
close(): void
/** List all tables in the dataset. */ /** List all tables in the dataset. */
tableNames(startAfter?: string | undefined | null, limit?: number | undefined | null): Promise<Array<string>> tableNames(): Promise<Array<string>>
/** /**
* Create table from a Apache Arrow IPC (file) buffer. * Create table from a Apache Arrow IPC (file) buffer.
* *
@@ -91,57 +86,42 @@ export class Connection {
* *
*/ */
createTable(name: string, buf: Buffer, mode: string): Promise<Table> createTable(name: string, buf: Buffer, mode: string): Promise<Table>
createEmptyTable(name: string, schemaBuf: Buffer, mode: string): Promise<Table>
openTable(name: string): Promise<Table> openTable(name: string): Promise<Table>
/** Drop table with the name. Or raise an error if the table does not exist. */ /** Drop table with the name. Or raise an error if the table does not exist. */
dropTable(name: string): Promise<void> dropTable(name: string): Promise<void>
} }
export class Index { export class IndexBuilder {
static ivfPq(distanceType?: string | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): Index replace(v: boolean): void
static btree(): Index column(c: string): void
name(name: string): void
ivfPq(metricType?: MetricType | undefined | null, numPartitions?: number | undefined | null, numSubVectors?: number | undefined | null, numBits?: number | undefined | null, maxIterations?: number | undefined | null, sampleRate?: number | undefined | null): void
scalar(): void
build(): Promise<void>
} }
/** Typescript-style Async Iterator over RecordBatches */ /** Typescript-style Async Iterator over RecordBatches */
export class RecordBatchIterator { export class RecordBatchIterator {
next(): Promise<Buffer | null> next(): Promise<Buffer | null>
} }
export class Query { export class Query {
onlyIf(predicate: string): void
select(columns: Array<[string, string]>): void
limit(limit: number): void
nearestTo(vector: Float32Array): VectorQuery
execute(): Promise<RecordBatchIterator>
}
export class VectorQuery {
column(column: string): void column(column: string): void
distanceType(distanceType: string): void filter(filter: string): void
postfilter(): void select(columns: Array<string>): void
limit(limit: number): void
prefilter(prefilter: boolean): void
nearestTo(vector: Float32Array): void
refineFactor(refineFactor: number): void refineFactor(refineFactor: number): void
nprobes(nprobe: number): void nprobes(nprobe: number): void
bypassVectorIndex(): void executeStream(): Promise<RecordBatchIterator>
onlyIf(predicate: string): void
select(columns: Array<[string, string]>): void
limit(limit: number): void
execute(): Promise<RecordBatchIterator>
} }
export class Table { export class Table {
display(): string
isOpen(): boolean
close(): void
/** Return Schema as empty Arrow IPC file. */ /** Return Schema as empty Arrow IPC file. */
schema(): Promise<Buffer> schema(): Promise<Buffer>
add(buf: Buffer, mode: string): Promise<void> add(buf: Buffer): Promise<void>
countRows(filter?: string | undefined | null): Promise<number> countRows(filter?: string | undefined | null): Promise<number>
delete(predicate: string): Promise<void> delete(predicate: string): Promise<void>
createIndex(index: Index | undefined | null, column: string, replace?: boolean | undefined | null): Promise<void> createIndex(): IndexBuilder
update(onlyIf: string | undefined | null, columns: Array<[string, string]>): Promise<void>
query(): Query query(): Query
vectorSearch(vector: Float32Array): VectorQuery
addColumns(transforms: Array<AddColumnsSql>): Promise<void> addColumns(transforms: Array<AddColumnsSql>): Promise<void>
alterColumns(alterations: Array<ColumnAlteration>): Promise<void> alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
dropColumns(columns: Array<string>): Promise<void> dropColumns(columns: Array<string>): Promise<void>
version(): Promise<number>
checkout(version: number): Promise<void>
checkoutLatest(): Promise<void>
restore(): Promise<void>
listIndices(): Promise<Array<IndexConfig>>
} }

View File

@@ -5,325 +5,304 @@
/* auto-generated by NAPI-RS */ /* auto-generated by NAPI-RS */
const { existsSync, readFileSync } = require('fs') const { existsSync, readFileSync } = require('fs')
const { join } = require("path"); const { join } = require('path')
const { platform, arch } = process; const { platform, arch } = process
let nativeBinding = null; let nativeBinding = null
let localFileExisted = false; let localFileExisted = false
let loadError = null; let loadError = null
function isMusl() { function isMusl() {
// For Node 10 // For Node 10
if (!process.report || typeof process.report.getReport !== "function") { if (!process.report || typeof process.report.getReport !== 'function') {
try { try {
const lddPath = require("child_process") const lddPath = require('child_process').execSync('which ldd').toString().trim()
.execSync("which ldd") return readFileSync(lddPath, 'utf8').includes('musl')
.toString()
.trim();
return readFileSync(lddPath, "utf8").includes("musl");
} catch (e) { } catch (e) {
return true; return true
} }
} else { } else {
const { glibcVersionRuntime } = process.report.getReport().header; const { glibcVersionRuntime } = process.report.getReport().header
return !glibcVersionRuntime; return !glibcVersionRuntime
} }
} }
switch (platform) { switch (platform) {
case "android": case 'android':
switch (arch) { switch (arch) {
case "arm64": case 'arm64':
localFileExisted = existsSync( localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm64.node'))
join(__dirname, "lancedb-nodejs.android-arm64.node"),
);
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.android-arm64.node"); nativeBinding = require('./lancedb-nodejs.android-arm64.node')
} else { } else {
nativeBinding = require("lancedb-android-arm64"); nativeBinding = require('lancedb-android-arm64')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
case "arm": case 'arm':
localFileExisted = existsSync( localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm-eabi.node'))
join(__dirname, "lancedb-nodejs.android-arm-eabi.node"),
);
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.android-arm-eabi.node"); nativeBinding = require('./lancedb-nodejs.android-arm-eabi.node')
} else { } else {
nativeBinding = require("lancedb-android-arm-eabi"); nativeBinding = require('lancedb-android-arm-eabi')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
default: default:
throw new Error(`Unsupported architecture on Android ${arch}`); throw new Error(`Unsupported architecture on Android ${arch}`)
} }
break; break
case "win32": case 'win32':
switch (arch) { switch (arch) {
case "x64": case 'x64':
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.win32-x64-msvc.node"), join(__dirname, 'lancedb-nodejs.win32-x64-msvc.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.win32-x64-msvc.node"); nativeBinding = require('./lancedb-nodejs.win32-x64-msvc.node')
} else { } else {
nativeBinding = require("lancedb-win32-x64-msvc"); nativeBinding = require('lancedb-win32-x64-msvc')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
case "ia32": case 'ia32':
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.win32-ia32-msvc.node"), join(__dirname, 'lancedb-nodejs.win32-ia32-msvc.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.win32-ia32-msvc.node"); nativeBinding = require('./lancedb-nodejs.win32-ia32-msvc.node')
} else { } else {
nativeBinding = require("lancedb-win32-ia32-msvc"); nativeBinding = require('lancedb-win32-ia32-msvc')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
case "arm64": case 'arm64':
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.win32-arm64-msvc.node"), join(__dirname, 'lancedb-nodejs.win32-arm64-msvc.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.win32-arm64-msvc.node"); nativeBinding = require('./lancedb-nodejs.win32-arm64-msvc.node')
} else { } else {
nativeBinding = require("lancedb-win32-arm64-msvc"); nativeBinding = require('lancedb-win32-arm64-msvc')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
default: default:
throw new Error(`Unsupported architecture on Windows: ${arch}`); throw new Error(`Unsupported architecture on Windows: ${arch}`)
} }
break; break
case "darwin": case 'darwin':
localFileExisted = existsSync( localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-universal.node'))
join(__dirname, "lancedb-nodejs.darwin-universal.node"),
);
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.darwin-universal.node"); nativeBinding = require('./lancedb-nodejs.darwin-universal.node')
} else { } else {
nativeBinding = require("lancedb-darwin-universal"); nativeBinding = require('lancedb-darwin-universal')
} }
break; break
} catch {} } catch {}
switch (arch) { switch (arch) {
case "x64": case 'x64':
localFileExisted = existsSync( localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-x64.node'))
join(__dirname, "lancedb-nodejs.darwin-x64.node"),
);
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.darwin-x64.node"); nativeBinding = require('./lancedb-nodejs.darwin-x64.node')
} else { } else {
nativeBinding = require("lancedb-darwin-x64"); nativeBinding = require('lancedb-darwin-x64')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
case "arm64": case 'arm64':
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.darwin-arm64.node"), join(__dirname, 'lancedb-nodejs.darwin-arm64.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.darwin-arm64.node"); nativeBinding = require('./lancedb-nodejs.darwin-arm64.node')
} else { } else {
nativeBinding = require("lancedb-darwin-arm64"); nativeBinding = require('lancedb-darwin-arm64')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
default: default:
throw new Error(`Unsupported architecture on macOS: ${arch}`); throw new Error(`Unsupported architecture on macOS: ${arch}`)
} }
break; break
case "freebsd": case 'freebsd':
if (arch !== "x64") { if (arch !== 'x64') {
throw new Error(`Unsupported architecture on FreeBSD: ${arch}`); throw new Error(`Unsupported architecture on FreeBSD: ${arch}`)
} }
localFileExisted = existsSync( localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.freebsd-x64.node'))
join(__dirname, "lancedb-nodejs.freebsd-x64.node"),
);
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.freebsd-x64.node"); nativeBinding = require('./lancedb-nodejs.freebsd-x64.node')
} else { } else {
nativeBinding = require("lancedb-freebsd-x64"); nativeBinding = require('lancedb-freebsd-x64')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
case "linux": case 'linux':
switch (arch) { switch (arch) {
case "x64": case 'x64':
if (isMusl()) { if (isMusl()) {
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.linux-x64-musl.node"), join(__dirname, 'lancedb-nodejs.linux-x64-musl.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.linux-x64-musl.node"); nativeBinding = require('./lancedb-nodejs.linux-x64-musl.node')
} else { } else {
nativeBinding = require("lancedb-linux-x64-musl"); nativeBinding = require('lancedb-linux-x64-musl')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
} else { } else {
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.linux-x64-gnu.node"), join(__dirname, 'lancedb-nodejs.linux-x64-gnu.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.linux-x64-gnu.node"); nativeBinding = require('./lancedb-nodejs.linux-x64-gnu.node')
} else { } else {
nativeBinding = require("lancedb-linux-x64-gnu"); nativeBinding = require('lancedb-linux-x64-gnu')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
} }
break; break
case "arm64": case 'arm64':
if (isMusl()) { if (isMusl()) {
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.linux-arm64-musl.node"), join(__dirname, 'lancedb-nodejs.linux-arm64-musl.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.linux-arm64-musl.node"); nativeBinding = require('./lancedb-nodejs.linux-arm64-musl.node')
} else { } else {
nativeBinding = require("lancedb-linux-arm64-musl"); nativeBinding = require('lancedb-linux-arm64-musl')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
} else { } else {
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.linux-arm64-gnu.node"), join(__dirname, 'lancedb-nodejs.linux-arm64-gnu.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.linux-arm64-gnu.node"); nativeBinding = require('./lancedb-nodejs.linux-arm64-gnu.node')
} else { } else {
nativeBinding = require("lancedb-linux-arm64-gnu"); nativeBinding = require('lancedb-linux-arm64-gnu')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
} }
break; break
case "arm": case 'arm':
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.linux-arm-gnueabihf.node"), join(__dirname, 'lancedb-nodejs.linux-arm-gnueabihf.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.linux-arm-gnueabihf.node"); nativeBinding = require('./lancedb-nodejs.linux-arm-gnueabihf.node')
} else { } else {
nativeBinding = require("lancedb-linux-arm-gnueabihf"); nativeBinding = require('lancedb-linux-arm-gnueabihf')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
case "riscv64": case 'riscv64':
if (isMusl()) { if (isMusl()) {
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.linux-riscv64-musl.node"), join(__dirname, 'lancedb-nodejs.linux-riscv64-musl.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.linux-riscv64-musl.node"); nativeBinding = require('./lancedb-nodejs.linux-riscv64-musl.node')
} else { } else {
nativeBinding = require("lancedb-linux-riscv64-musl"); nativeBinding = require('lancedb-linux-riscv64-musl')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
} else { } else {
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.linux-riscv64-gnu.node"), join(__dirname, 'lancedb-nodejs.linux-riscv64-gnu.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.linux-riscv64-gnu.node"); nativeBinding = require('./lancedb-nodejs.linux-riscv64-gnu.node')
} else { } else {
nativeBinding = require("lancedb-linux-riscv64-gnu"); nativeBinding = require('lancedb-linux-riscv64-gnu')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
} }
break; break
case "s390x": case 's390x':
localFileExisted = existsSync( localFileExisted = existsSync(
join(__dirname, "lancedb-nodejs.linux-s390x-gnu.node"), join(__dirname, 'lancedb-nodejs.linux-s390x-gnu.node')
); )
try { try {
if (localFileExisted) { if (localFileExisted) {
nativeBinding = require("./lancedb-nodejs.linux-s390x-gnu.node"); nativeBinding = require('./lancedb-nodejs.linux-s390x-gnu.node')
} else { } else {
nativeBinding = require("lancedb-linux-s390x-gnu"); nativeBinding = require('lancedb-linux-s390x-gnu')
} }
} catch (e) { } catch (e) {
loadError = e; loadError = e
} }
break; break
default: default:
throw new Error(`Unsupported architecture on Linux: ${arch}`); throw new Error(`Unsupported architecture on Linux: ${arch}`)
} }
break; break
default: default:
throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`); throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`)
} }
if (!nativeBinding) { if (!nativeBinding) {
if (loadError) { if (loadError) {
throw loadError; throw loadError
} }
throw new Error(`Failed to load native binding`); throw new Error(`Failed to load native binding`)
} }
const { const { Connection, IndexType, MetricType, IndexBuilder, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding
Connection,
Index,
RecordBatchIterator,
Query,
VectorQuery,
Table,
WriteMode,
connect,
} = nativeBinding;
module.exports.Connection = Connection; module.exports.Connection = Connection
module.exports.Index = Index; module.exports.IndexType = IndexType
module.exports.RecordBatchIterator = RecordBatchIterator; module.exports.MetricType = MetricType
module.exports.Query = Query; module.exports.IndexBuilder = IndexBuilder
module.exports.VectorQuery = VectorQuery; module.exports.RecordBatchIterator = RecordBatchIterator
module.exports.Table = Table; module.exports.Query = Query
module.exports.WriteMode = WriteMode; module.exports.Table = Table
module.exports.connect = connect; module.exports.WriteMode = WriteMode
module.exports.connect = connect

View File

@@ -17,22 +17,24 @@ import {
RecordBatchIterator as NativeBatchIterator, RecordBatchIterator as NativeBatchIterator,
Query as NativeQuery, Query as NativeQuery,
Table as NativeTable, Table as NativeTable,
VectorQuery as NativeVectorQuery,
} from "./native"; } from "./native";
import { type IvfPqOptions } from "./indices";
class RecordBatchIterator implements AsyncIterator<RecordBatch> { class RecordBatchIterator implements AsyncIterator<RecordBatch> {
private promisedInner?: Promise<NativeBatchIterator>; private promised_inner?: Promise<NativeBatchIterator>;
private inner?: NativeBatchIterator; private inner?: NativeBatchIterator;
constructor(promise?: Promise<NativeBatchIterator>) { constructor(
inner?: NativeBatchIterator,
promise?: Promise<NativeBatchIterator>
) {
// TODO: check promise reliably so we dont need to pass two arguments. // TODO: check promise reliably so we dont need to pass two arguments.
this.promisedInner = promise; this.inner = inner;
this.promised_inner = promise;
} }
// eslint-disable-next-line @typescript-eslint/no-explicit-any async next(): Promise<IteratorResult<RecordBatch<any>, any>> {
async next(): Promise<IteratorResult<RecordBatch<any>>> {
if (this.inner === undefined) { if (this.inner === undefined) {
this.inner = await this.promisedInner; this.inner = await this.promised_inner;
} }
if (this.inner === undefined) { if (this.inner === undefined) {
throw new Error("Invalid iterator state state"); throw new Error("Invalid iterator state state");
@@ -50,113 +52,82 @@ class RecordBatchIterator implements AsyncIterator<RecordBatch> {
} }
/* eslint-enable */ /* eslint-enable */
/** Common methods supported by all query types */ /** Query executor */
export class QueryBase< export class Query implements AsyncIterable<RecordBatch> {
NativeQueryType extends NativeQuery | NativeVectorQuery, private readonly inner: NativeQuery;
QueryType,
> implements AsyncIterable<RecordBatch>
{
protected constructor(protected inner: NativeQueryType) {}
/** constructor(tbl: NativeTable) {
* A filter statement to be applied to this query. this.inner = tbl.query();
}
/** Set the column to run query. */
column(column: string): Query {
this.inner.column(column);
return this;
}
/** Set the filter predicate, only returns the results that satisfy the filter.
* *
* The filter should be supplied as an SQL query string. For example:
* @example
* x > 10
* y > 0 AND y < 100
* x > 5 OR y = 'test'
*
* Filtering performance can often be improved by creating a scalar index
* on the filter column(s).
*/ */
where(predicate: string): QueryType { filter(predicate: string): Query {
this.inner.onlyIf(predicate); this.inner.filter(predicate);
return this as unknown as QueryType; return this;
} }
/** /**
* Return only the specified columns. * Select the columns to return. If not set, all columns are returned.
*
* By default a query will return all columns from the table. However, this can have
* a very significant impact on latency. LanceDb stores data in a columnar fashion. This
* means we can finely tune our I/O to select exactly the columns we need.
*
* As a best practice you should always limit queries to the columns that you need. If you
* pass in an array of column names then only those columns will be returned.
*
* You can also use this method to create new "dynamic" columns based on your existing columns.
* For example, you may not care about "a" or "b" but instead simply want "a + b". This is often
* seen in the SELECT clause of an SQL query (e.g. `SELECT a+b FROM my_table`).
*
* To create dynamic columns you can pass in a Map<string, string>. A column will be returned
* for each entry in the map. The key provides the name of the column. The value is
* an SQL string used to specify how the column is calculated.
*
* For example, an SQL query might state `SELECT a + b AS combined, c`. The equivalent
* input to this method would be:
* @example
* new Map([["combined", "a + b"], ["c", "c"]])
*
* Columns will always be returned in the order given, even if that order is different than
* the order used when adding the data.
*
* Note that you can pass in a `Record<string, string>` (e.g. an object literal). This method
* uses `Object.entries` which should preserve the insertion order of the object. However,
* object insertion order is easy to get wrong and `Map` is more foolproof.
*/ */
select( select(columns: string[]): Query {
columns: string[] | Map<string, string> | Record<string, string>, this.inner.select(columns);
): QueryType { return this;
let columnTuples: [string, string][];
if (Array.isArray(columns)) {
columnTuples = columns.map((c) => [c, c]);
} else if (columns instanceof Map) {
columnTuples = Array.from(columns.entries());
} else {
columnTuples = Object.entries(columns);
}
this.inner.select(columnTuples);
return this as unknown as QueryType;
} }
/** /**
* Set the maximum number of results to return. * Set the limit of rows to return.
*
* By default, a plain search has no limit. If this method is not
* called then every valid row from the table will be returned.
*/ */
limit(limit: number): QueryType { limit(limit: number): Query {
this.inner.limit(limit); this.inner.limit(limit);
return this as unknown as QueryType; return this;
} }
protected nativeExecute(): Promise<NativeBatchIterator> { prefilter(prefilter: boolean): Query {
return this.inner.execute(); this.inner.prefilter(prefilter);
return this;
} }
/** /**
* Execute the query and return the results as an @see {@link AsyncIterator} * Set the query vector.
* of @see {@link RecordBatch}.
*
* By default, LanceDb will use many threads to calculate results and, when
* the result set is large, multiple batches will be processed at one time.
* This readahead is limited however and backpressure will be applied if this
* stream is consumed slowly (this constrains the maximum memory used by a
* single query)
*
*/ */
protected execute(): RecordBatchIterator { nearestTo(vector: number[]): Query {
return new RecordBatchIterator(this.nativeExecute()); this.inner.nearestTo(Float32Array.from(vector));
return this;
} }
// eslint-disable-next-line @typescript-eslint/no-explicit-any /**
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> { * Set the number of IVF partitions to use for the query.
const promise = this.nativeExecute(); */
return new RecordBatchIterator(promise); nprobes(nprobes: number): Query {
this.inner.nprobes(nprobes);
return this;
} }
/** Collect the results as an Arrow @see {@link ArrowTable}. */ /**
* Set the refine factor for the query.
*/
refineFactor(refine_factor: number): Query {
this.inner.refineFactor(refine_factor);
return this;
}
/**
* Execute the query and return the results as an AsyncIterator.
*/
async executeStream(): Promise<RecordBatchIterator> {
const inner = await this.inner.executeStream();
return new RecordBatchIterator(inner);
}
/** Collect the results as an Arrow Table. */
async toArrow(): Promise<ArrowTable> { async toArrow(): Promise<ArrowTable> {
const batches = []; const batches = [];
for await (const batch of this) { for await (const batch of this) {
@@ -165,211 +136,17 @@ export class QueryBase<
return new ArrowTable(batches); return new ArrowTable(batches);
} }
/** Collect the results as an array of objects. */ /** Returns a JSON Array of All results.
async toArray(): Promise<unknown[]> { *
*/
async toArray(): Promise<any[]> {
const tbl = await this.toArrow(); const tbl = await this.toArrow();
// eslint-disable-next-line @typescript-eslint/no-unsafe-return // eslint-disable-next-line @typescript-eslint/no-unsafe-return
return tbl.toArray(); return tbl.toArray();
} }
}
/** [Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
* An interface for a query that can be executed const promise = this.inner.executeStream();
* return new RecordBatchIterator(undefined, promise);
* Supported by all query types
*/
export interface ExecutableQuery {}
/**
* A builder used to construct a vector search
*
* This builder can be reused to execute the query many times.
*/
export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
constructor(inner: NativeVectorQuery) {
super(inner);
}
/**
* Set the number of partitions to search (probe)
*
* This argument is only used when the vector column has an IVF PQ index.
* If there is no index then this value is ignored.
*
* The IVF stage of IVF PQ divides the input into partitions (clusters) of
* related values.
*
* The partition whose centroids are closest to the query vector will be
* exhaustiely searched to find matches. This parameter controls how many
* partitions should be searched.
*
* Increasing this value will increase the recall of your query but will
* also increase the latency of your query. The default value is 20. This
* default is good for many cases but the best value to use will depend on
* your data and the recall that you need to achieve.
*
* For best results we recommend tuning this parameter with a benchmark against
* your actual data to find the smallest possible value that will still give
* you the desired recall.
*/
nprobes(nprobes: number): VectorQuery {
this.inner.nprobes(nprobes);
return this;
}
/**
* Set the vector column to query
*
* This controls which column is compared to the query vector supplied in
* the call to @see {@link Query#nearestTo}
*
* This parameter must be specified if the table has more than one column
* whose data type is a fixed-size-list of floats.
*/
column(column: string): VectorQuery {
this.inner.column(column);
return this;
}
/**
* Set the distance metric to use
*
* When performing a vector search we try and find the "nearest" vectors according
* to some kind of distance metric. This parameter controls which distance metric to
* use. See @see {@link IvfPqOptions.distanceType} for more details on the different
* distance metrics available.
*
* Note: if there is a vector index then the distance type used MUST match the distance
* type used to train the vector index. If this is not done then the results will be
* invalid.
*
* By default "l2" is used.
*/
distanceType(distanceType: string): VectorQuery {
this.inner.distanceType(distanceType);
return this;
}
/**
* A multiplier to control how many additional rows are taken during the refine step
*
* This argument is only used when the vector column has an IVF PQ index.
* If there is no index then this value is ignored.
*
* An IVF PQ index stores compressed (quantized) values. They query vector is compared
* against these values and, since they are compressed, the comparison is inaccurate.
*
* This parameter can be used to refine the results. It can improve both improve recall
* and correct the ordering of the nearest results.
*
* To refine results LanceDb will first perform an ANN search to find the nearest
* `limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and
* `limit` is the default (10) then the first 30 results will be selected. LanceDb
* then fetches the full, uncompressed, values for these 30 results. The results are
* then reordered by the true distance and only the nearest 10 are kept.
*
* Note: there is a difference between calling this method with a value of 1 and never
* calling this method at all. Calling this method with any value will have an impact
* on your search latency. When you call this method with a `refine_factor` of 1 then
* LanceDb still needs to fetch the full, uncompressed, values so that it can potentially
* reorder the results.
*
* Note: if this method is NOT called then the distances returned in the _distance column
* will be approximate distances based on the comparison of the quantized query vector
* and the quantized result vectors. This can be considerably different than the true
* distance between the query vector and the actual uncompressed vector.
*/
refineFactor(refineFactor: number): VectorQuery {
this.inner.refineFactor(refineFactor);
return this;
}
/**
* If this is called then filtering will happen after the vector search instead of
* before.
*
* By default filtering will be performed before the vector search. This is how
* filtering is typically understood to work. This prefilter step does add some
* additional latency. Creating a scalar index on the filter column(s) can
* often improve this latency. However, sometimes a filter is too complex or scalar
* indices cannot be applied to the column. In these cases postfiltering can be
* used instead of prefiltering to improve latency.
*
* Post filtering applies the filter to the results of the vector search. This means
* we only run the filter on a much smaller set of data. However, it can cause the
* query to return fewer than `limit` results (or even no results) if none of the nearest
* results match the filter.
*
* Post filtering happens during the "refine stage" (described in more detail in
* @see {@link VectorQuery#refineFactor}). This means that setting a higher refine
* factor can often help restore some of the results lost by post filtering.
*/
postfilter(): VectorQuery {
this.inner.postfilter();
return this;
}
/**
* If this is called then any vector index is skipped
*
* An exhaustive (flat) search will be performed. The query vector will
* be compared to every vector in the table. At high scales this can be
* expensive. However, this is often still useful. For example, skipping
* the vector index can give you ground truth results which you can use to
* calculate your recall to select an appropriate value for nprobes.
*/
bypassVectorIndex(): VectorQuery {
this.inner.bypassVectorIndex();
return this;
}
}
/** A builder for LanceDB queries. */
export class Query extends QueryBase<NativeQuery, Query> {
constructor(tbl: NativeTable) {
super(tbl.query());
}
/**
* Find the nearest vectors to the given query vector.
*
* This converts the query from a plain query to a vector query.
*
* This method will attempt to convert the input to the query vector
* expected by the embedding model. If the input cannot be converted
* then an error will be thrown.
*
* By default, there is no embedding model, and the input should be
* an array-like object of numbers (something that can be used as input
* to Float32Array.from)
*
* If there is only one vector column (a column whose data type is a
* fixed size list of floats) then the column does not need to be specified.
* If there is more than one vector column you must use
* @see {@link VectorQuery#column} to specify which column you would like
* to compare with.
*
* If no index has been created on the vector column then a vector query
* will perform a distance comparison between the query vector and every
* vector in the database and then sort the results. This is sometimes
* called a "flat search"
*
* For small databases, with a few hundred thousand vectors or less, this can
* be reasonably fast. In larger databases you should create a vector index
* on the column. If there is a vector index then an "approximate" nearest
* neighbor search (frequently called an ANN search) will be performed. This
* search is much faster, but the results will be approximate.
*
* The query can be further parameterized using the returned builder. There
* are various ANN search parameters that will let you fine tune your recall
* accuracy vs search latency.
*
* Vector searches always have a `limit`. If `limit` has not been called then
* a default `limit` of 10 will be used. @see {@link Query#limit}
*/
nearestTo(vector: unknown): VectorQuery {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
return new VectorQuery(vectorQuery);
} }
} }

View File

@@ -1,516 +0,0 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The utilities in this file help sanitize data from the user's arrow
// library into the types expected by vectordb's arrow library. Node
// generally allows for mulitple versions of the same library (and sometimes
// even multiple copies of the same version) to be installed at the same
// time. However, arrow-js uses instanceof which expected that the input
// comes from the exact same library instance. This is not always the case
// and so we must sanitize the input to ensure that it is compatible.
import {
Field,
Utf8,
FixedSizeBinary,
FixedSizeList,
Schema,
List,
Struct,
Float,
Bool,
Date_,
Decimal,
DataType,
Dictionary,
Binary,
Float32,
Interval,
Map_,
Duration,
Union,
Time,
Timestamp,
Type,
Null,
Int,
type Precision,
type DateUnit,
Int8,
Int16,
Int32,
Int64,
Uint8,
Uint16,
Uint32,
Uint64,
Float16,
Float64,
DateDay,
DateMillisecond,
DenseUnion,
SparseUnion,
TimeNanosecond,
TimeMicrosecond,
TimeMillisecond,
TimeSecond,
TimestampNanosecond,
TimestampMicrosecond,
TimestampMillisecond,
TimestampSecond,
IntervalDayTime,
IntervalYearMonth,
DurationNanosecond,
DurationMicrosecond,
DurationMillisecond,
DurationSecond,
} from "apache-arrow";
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
function sanitizeMetadata(
metadataLike?: unknown,
): Map<string, string> | undefined {
if (metadataLike === undefined || metadataLike === null) {
return undefined;
}
if (!(metadataLike instanceof Map)) {
throw Error("Expected metadata, if present, to be a Map<string, string>");
}
for (const item of metadataLike) {
if (!(typeof item[0] === "string" || !(typeof item[1] === "string"))) {
throw Error(
"Expected metadata, if present, to be a Map<string, string> but it had non-string keys or values",
);
}
}
return metadataLike as Map<string, string>;
}
function sanitizeInt(typeLike: object) {
if (
!("bitWidth" in typeLike) ||
typeof typeLike.bitWidth !== "number" ||
!("isSigned" in typeLike) ||
typeof typeLike.isSigned !== "boolean"
) {
throw Error(
"Expected an Int Type to have a `bitWidth` and `isSigned` property",
);
}
return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
}
function sanitizeFloat(typeLike: object) {
if (!("precision" in typeLike) || typeof typeLike.precision !== "number") {
throw Error("Expected a Float Type to have a `precision` property");
}
return new Float(typeLike.precision as Precision);
}
function sanitizeDecimal(typeLike: object) {
if (
!("scale" in typeLike) ||
typeof typeLike.scale !== "number" ||
!("precision" in typeLike) ||
typeof typeLike.precision !== "number" ||
!("bitWidth" in typeLike) ||
typeof typeLike.bitWidth !== "number"
) {
throw Error(
"Expected a Decimal Type to have `scale`, `precision`, and `bitWidth` properties",
);
}
return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
}
function sanitizeDate(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Date type to have a `unit` property");
}
return new Date_(typeLike.unit as DateUnit);
}
function sanitizeTime(typeLike: object) {
if (
!("unit" in typeLike) ||
typeof typeLike.unit !== "number" ||
!("bitWidth" in typeLike) ||
typeof typeLike.bitWidth !== "number"
) {
throw Error(
"Expected a Time type to have `unit` and `bitWidth` properties",
);
}
return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
}
function sanitizeTimestamp(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Timestamp type to have a `unit` property");
}
let timezone = null;
if ("timezone" in typeLike && typeof typeLike.timezone === "string") {
timezone = typeLike.timezone;
}
return new Timestamp(typeLike.unit, timezone);
}
function sanitizeTypedTimestamp(
typeLike: object,
// eslint-disable-next-line @typescript-eslint/naming-convention
Datatype:
| typeof TimestampNanosecond
| typeof TimestampMicrosecond
| typeof TimestampMillisecond
| typeof TimestampSecond,
) {
let timezone = null;
if ("timezone" in typeLike && typeof typeLike.timezone === "string") {
timezone = typeLike.timezone;
}
return new Datatype(timezone);
}
function sanitizeInterval(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected an Interval type to have a `unit` property");
}
return new Interval(typeLike.unit);
}
function sanitizeList(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a List type to have an array-like `children` property",
);
}
if (typeLike.children.length !== 1) {
throw Error("Expected a List type to have exactly one child");
}
return new List(sanitizeField(typeLike.children[0]));
}
function sanitizeStruct(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Struct type to have an array-like `children` property",
);
}
return new Struct(typeLike.children.map((child) => sanitizeField(child)));
}
function sanitizeUnion(typeLike: object) {
if (
!("typeIds" in typeLike) ||
!("mode" in typeLike) ||
typeof typeLike.mode !== "number"
) {
throw Error(
"Expected a Union type to have `typeIds` and `mode` properties",
);
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Union type to have an array-like `children` property",
);
}
return new Union(
typeLike.mode,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
typeLike.typeIds as any,
typeLike.children.map((child) => sanitizeField(child)),
);
}
function sanitizeTypedUnion(
typeLike: object,
// eslint-disable-next-line @typescript-eslint/naming-convention
UnionType: typeof DenseUnion | typeof SparseUnion,
) {
if (!("typeIds" in typeLike)) {
throw Error(
"Expected a DenseUnion/SparseUnion type to have a `typeIds` property",
);
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a DenseUnion/SparseUnion type to have an array-like `children` property",
);
}
return new UnionType(
typeLike.typeIds as Int32Array | number[],
typeLike.children.map((child) => sanitizeField(child)),
);
}
function sanitizeFixedSizeBinary(typeLike: object) {
if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
throw Error(
"Expected a FixedSizeBinary type to have a `byteWidth` property",
);
}
return new FixedSizeBinary(typeLike.byteWidth);
}
function sanitizeFixedSizeList(typeLike: object) {
if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") {
throw Error("Expected a FixedSizeList type to have a `listSize` property");
}
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a FixedSizeList type to have an array-like `children` property",
);
}
if (typeLike.children.length !== 1) {
throw Error("Expected a FixedSizeList type to have exactly one child");
}
return new FixedSizeList(
typeLike.listSize,
sanitizeField(typeLike.children[0]),
);
}
function sanitizeMap(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Map type to have an array-like `children` property",
);
}
if (!("keysSorted" in typeLike) || typeof typeLike.keysSorted !== "boolean") {
throw Error("Expected a Map type to have a `keysSorted` property");
}
return new Map_(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
typeLike.children.map((field) => sanitizeField(field)) as any,
typeLike.keysSorted,
);
}
function sanitizeDuration(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Duration type to have a `unit` property");
}
return new Duration(typeLike.unit);
}
function sanitizeDictionary(typeLike: object) {
if (!("id" in typeLike) || typeof typeLike.id !== "number") {
throw Error("Expected a Dictionary type to have an `id` property");
}
if (!("indices" in typeLike) || typeof typeLike.indices !== "object") {
throw Error("Expected a Dictionary type to have an `indices` property");
}
if (!("dictionary" in typeLike) || typeof typeLike.dictionary !== "object") {
throw Error("Expected a Dictionary type to have an `dictionary` property");
}
if (!("isOrdered" in typeLike) || typeof typeLike.isOrdered !== "boolean") {
throw Error("Expected a Dictionary type to have an `isOrdered` property");
}
return new Dictionary(
sanitizeType(typeLike.dictionary),
sanitizeType(typeLike.indices) as TKeys,
typeLike.id,
typeLike.isOrdered,
);
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function sanitizeType(typeLike: unknown): DataType<any> {
if (typeof typeLike !== "object" || typeLike === null) {
throw Error("Expected a Type but object was null/undefined");
}
if (!("typeId" in typeLike) || !(typeof typeLike.typeId !== "function")) {
throw Error("Expected a Type to have a typeId function");
}
let typeId: Type;
if (typeof typeLike.typeId === "function") {
typeId = (typeLike.typeId as () => unknown)() as Type;
} else if (typeof typeLike.typeId === "number") {
typeId = typeLike.typeId as Type;
} else {
throw Error("Type's typeId property was not a function or number");
}
switch (typeId) {
case Type.NONE:
throw Error("Received a Type with a typeId of NONE");
case Type.Null:
return new Null();
case Type.Int:
return sanitizeInt(typeLike);
case Type.Float:
return sanitizeFloat(typeLike);
case Type.Binary:
return new Binary();
case Type.Utf8:
return new Utf8();
case Type.Bool:
return new Bool();
case Type.Decimal:
return sanitizeDecimal(typeLike);
case Type.Date:
return sanitizeDate(typeLike);
case Type.Time:
return sanitizeTime(typeLike);
case Type.Timestamp:
return sanitizeTimestamp(typeLike);
case Type.Interval:
return sanitizeInterval(typeLike);
case Type.List:
return sanitizeList(typeLike);
case Type.Struct:
return sanitizeStruct(typeLike);
case Type.Union:
return sanitizeUnion(typeLike);
case Type.FixedSizeBinary:
return sanitizeFixedSizeBinary(typeLike);
case Type.FixedSizeList:
return sanitizeFixedSizeList(typeLike);
case Type.Map:
return sanitizeMap(typeLike);
case Type.Duration:
return sanitizeDuration(typeLike);
case Type.Dictionary:
return sanitizeDictionary(typeLike);
case Type.Int8:
return new Int8();
case Type.Int16:
return new Int16();
case Type.Int32:
return new Int32();
case Type.Int64:
return new Int64();
case Type.Uint8:
return new Uint8();
case Type.Uint16:
return new Uint16();
case Type.Uint32:
return new Uint32();
case Type.Uint64:
return new Uint64();
case Type.Float16:
return new Float16();
case Type.Float32:
return new Float32();
case Type.Float64:
return new Float64();
case Type.DateMillisecond:
return new DateMillisecond();
case Type.DateDay:
return new DateDay();
case Type.TimeNanosecond:
return new TimeNanosecond();
case Type.TimeMicrosecond:
return new TimeMicrosecond();
case Type.TimeMillisecond:
return new TimeMillisecond();
case Type.TimeSecond:
return new TimeSecond();
case Type.TimestampNanosecond:
return sanitizeTypedTimestamp(typeLike, TimestampNanosecond);
case Type.TimestampMicrosecond:
return sanitizeTypedTimestamp(typeLike, TimestampMicrosecond);
case Type.TimestampMillisecond:
return sanitizeTypedTimestamp(typeLike, TimestampMillisecond);
case Type.TimestampSecond:
return sanitizeTypedTimestamp(typeLike, TimestampSecond);
case Type.DenseUnion:
return sanitizeTypedUnion(typeLike, DenseUnion);
case Type.SparseUnion:
return sanitizeTypedUnion(typeLike, SparseUnion);
case Type.IntervalDayTime:
return new IntervalDayTime();
case Type.IntervalYearMonth:
return new IntervalYearMonth();
case Type.DurationNanosecond:
return new DurationNanosecond();
case Type.DurationMicrosecond:
return new DurationMicrosecond();
case Type.DurationMillisecond:
return new DurationMillisecond();
case Type.DurationSecond:
return new DurationSecond();
default:
throw new Error("Unrecoginized type id in schema: " + typeId);
}
}
function sanitizeField(fieldLike: unknown): Field {
if (fieldLike instanceof Field) {
return fieldLike;
}
if (typeof fieldLike !== "object" || fieldLike === null) {
throw Error("Expected a Field but object was null/undefined");
}
if (
!("type" in fieldLike) ||
!("name" in fieldLike) ||
!("nullable" in fieldLike)
) {
throw Error(
"The field passed in is missing a `type`/`name`/`nullable` property",
);
}
const type = sanitizeType(fieldLike.type);
const name = fieldLike.name;
if (!(typeof name === "string")) {
throw Error("The field passed in had a non-string `name` property");
}
const nullable = fieldLike.nullable;
if (!(typeof nullable === "boolean")) {
throw Error("The field passed in had a non-boolean `nullable` property");
}
let metadata;
if ("metadata" in fieldLike) {
metadata = sanitizeMetadata(fieldLike.metadata);
}
return new Field(name, type, nullable, metadata);
}
/**
* Convert something schemaLike into a Schema instance
*
* This method is often needed even when the caller is using a Schema
* instance because they might be using a different instance of apache-arrow
* than lancedb is using.
*/
export function sanitizeSchema(schemaLike: unknown): Schema {
if (schemaLike instanceof Schema) {
return schemaLike;
}
if (typeof schemaLike !== "object" || schemaLike === null) {
throw Error("Expected a Schema but object was null/undefined");
}
if (!("fields" in schemaLike)) {
throw Error(
"The schema passed in does not appear to be a schema (no 'fields' property)",
);
}
let metadata;
if ("metadata" in schemaLike) {
metadata = sanitizeMetadata(schemaLike.metadata);
}
if (!Array.isArray(schemaLike.fields)) {
throw Error(
"The schema passed in had a 'fields' property but it was not an array",
);
}
const sanitizedFields = schemaLike.fields.map((field) =>
sanitizeField(field),
);
return new Schema(sanitizedFields, metadata);
}

View File

@@ -13,54 +13,15 @@
// limitations under the License. // limitations under the License.
import { Schema, tableFromIPC } from "apache-arrow"; import { Schema, tableFromIPC } from "apache-arrow";
import { import { AddColumnsSql, ColumnAlteration, Table as _NativeTable } from "./native";
AddColumnsSql, import { toBuffer, Data } from "./arrow";
ColumnAlteration, import { Query } from "./query";
IndexConfig, import { IndexBuilder } from "./indexer";
Table as _NativeTable,
} from "./native";
import { Query, VectorQuery } from "./query";
import { IndexOptions } from "./indices";
import { Data, fromDataToBuffer } from "./arrow";
export { IndexConfig } from "./native";
/**
* Options for adding data to a table.
*/
export interface AddDataOptions {
/**
* If "append" (the default) then the new data will be added to the table
*
* If "overwrite" then the new data will replace the existing data in the table.
*/
mode: "append" | "overwrite";
}
export interface UpdateOptions {
/**
* A filter that limits the scope of the update.
*
* This should be an SQL filter expression.
*
* Only rows that satisfy the expression will be updated.
*
* For example, this could be 'my_col == 0' to replace all instances
* of 0 in a column with some other default value.
*/
where: string;
}
/** /**
* A Table is a collection of Records in a LanceDB Database. * A LanceDB Table is the collection of Records.
* *
* A Table object is expected to be long lived and reused for multiple operations. * Each Record has one or more vector fields.
* Table objects will cache a certain amount of index data in memory. This cache
* will be freed when the Table is garbage collected. To eagerly free the cache you
* can call the `close` method. Once the Table is closed, it cannot be used for any
* further operations.
*
* Closing a table is optional. It not closed, it will be closed when it is garbage
* collected.
*/ */
export class Table { export class Table {
private readonly inner: _NativeTable; private readonly inner: _NativeTable;
@@ -70,27 +31,6 @@ export class Table {
this.inner = inner; this.inner = inner;
} }
/** Return true if the table has not been closed */
isOpen(): boolean {
return this.inner.isOpen();
}
/**
* Close the table, releasing any underlying resources.
*
* It is safe to call this method multiple times.
*
* Any attempt to use the table after it is closed will result in an error.
*/
close(): void {
this.inner.close();
}
/** Return a brief description of the table */
display(): string {
return this.inner.display();
}
/** Get the schema of the table. */ /** Get the schema of the table. */
async schema(): Promise<Schema> { async schema(): Promise<Schema> {
const schemaBuf = await this.inner.schema(); const schemaBuf = await this.inner.schema();
@@ -100,52 +40,13 @@ export class Table {
/** /**
* Insert records into this Table. * Insert records into this Table.
*
* @param {Data} data Records to be inserted into the Table * @param {Data} data Records to be inserted into the Table
* @return The number of rows added to the table
*/ */
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> { async add(data: Data): Promise<void> {
const mode = options?.mode ?? "append"; const buffer = toBuffer(data);
await this.inner.add(buffer);
const buffer = await fromDataToBuffer(data);
await this.inner.add(buffer, mode);
}
/**
* Update existing records in the Table
*
* An update operation can be used to adjust existing values. Use the
* returned builder to specify which columns to update. The new value
* can be a literal value (e.g. replacing nulls with some default value)
* or an expression applied to the old value (e.g. incrementing a value)
*
* An optional condition can be specified (e.g. "only update if the old
* value is 0")
*
* Note: if your condition is something like "some_id_column == 7" and
* you are updating many rows (with different ids) then you will get
* better performance with a single [`merge_insert`] call instead of
* repeatedly calilng this method.
* @param {Map<string, string> | Record<string, string>} updates - the
* columns to update
*
* Keys in the map should specify the name of the column to update.
* Values in the map provide the new value of the column. These can
* be SQL literal strings (e.g. "7" or "'foo'") or they can be expressions
* based on the row being updated (e.g. "my_col + 1")
* @param {Partial<UpdateOptions>} options - additional options to control
* the update behavior
*/
async update(
updates: Map<string, string> | Record<string, string>,
options?: Partial<UpdateOptions>,
) {
const onlyIf = options?.where;
let columns: [string, string][];
if (updates instanceof Map) {
columns = Array.from(updates.entries());
} else {
columns = Object.entries(updates);
}
await this.inner.update(onlyIf, columns);
} }
/** Count the total number of rows in the dataset. */ /** Count the total number of rows in the dataset. */
@@ -158,105 +59,106 @@ export class Table {
await this.inner.delete(predicate); await this.inner.delete(predicate);
} }
/** /** Create an index over the columns.
* Create an index to speed up queries. *
* @param {string} column The column to create the index on. If not specified,
* it will create an index on vector field.
* *
* Indices can be created on vector columns or scalar columns.
* Indices on vector columns will speed up vector searches.
* Indices on scalar columns will speed up filtering (in both
* vector and non-vector searches)
* @example * @example
* // If the column has a vector (fixed size list) data type then *
* // an IvfPq vector index will be created. * By default, it creates vector idnex on one vector column.
*
* ```typescript
* const table = await conn.openTable("my_table"); * const table = await conn.openTable("my_table");
* await table.createIndex(["vector"]); * await table.createIndex().build();
* @example * ```
* // For advanced control over vector index creation you can specify *
* // the index type and options. * You can specify `IVF_PQ` parameters via `ivf_pq({})` call.
* ```typescript
* const table = await conn.openTable("my_table"); * const table = await conn.openTable("my_table");
* await table.createIndex(["vector"], I) * await table.createIndex("my_vec_col")
* .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 }) * .ivf_pq({ num_partitions: 128, num_sub_vectors: 16 })
* .build(); * .build();
* @example * ```
* // Or create a Scalar index *
* Or create a Scalar index
*
* ```typescript
* await table.createIndex("my_float_col").build(); * await table.createIndex("my_float_col").build();
* ```
*/ */
async createIndex(column: string, options?: Partial<IndexOptions>) { createIndex(column?: string): IndexBuilder {
// Bit of a hack to get around the fact that TS has no package-scope. let builder = new IndexBuilder(this.inner);
// eslint-disable-next-line @typescript-eslint/no-explicit-any if (column !== undefined) {
const nativeIndex = (options?.config as any)?.inner; builder = builder.column(column);
await this.inner.createIndex(nativeIndex, column, options?.replace); }
return builder;
} }
/** /**
* Create a {@link Query} Builder. * Create a generic {@link Query} Builder.
*
* Queries allow you to search your existing data. By default the query will
* return all the data in the table in no particular order. The builder
* returned by this method can be used to control the query using filtering,
* vector similarity, sorting, and more.
*
* Note: By default, all columns are returned. For best performance, you should
* only fetch the columns you need. See [`Query::select_with_projection`] for
* more details.
* *
* When appropriate, various indices and statistics based pruning will be used to * When appropriate, various indices and statistics based pruning will be used to
* accelerate the query. * accelerate the query.
*
* @example * @example
* // SQL-style filtering *
* // * ### Run a SQL-style query
* // This query will return up to 1000 rows whose value in the `id` column * ```typescript
* // is greater than 5. LanceDb supports a broad set of filtering functions.
* for await (const batch of table.query() * for await (const batch of table.query()
* .filter("id > 1").select(["id"]).limit(20)) { * .filter("id > 1").select(["id"]).limit(20)) {
* console.log(batch); * console.log(batch);
* } * }
* @example * ```
* // Vector Similarity Search *
* // * ### Run Top-10 vector similarity search
* // This example will find the 10 rows whose value in the "vector" column are * ```typescript
* // closest to the query vector [1.0, 2.0, 3.0]. If an index has been created
* // on the "vector" column then this will perform an ANN search.
* //
* // The `refine_factor` and `nprobes` methods are used to control the recall /
* // latency tradeoff of the search.
* for await (const batch of table.query() * for await (const batch of table.query()
* .nearestTo([1, 2, 3]) * .nearestTo([1, 2, 3])
* .refineFactor(5).nprobe(10) * .refineFactor(5).nprobe(10)
* .limit(10)) { * .limit(10)) {
* console.log(batch); * console.log(batch);
* } * }
* @example *```
* // Scan the full dataset *
* // * ### Scan the full dataset
* // This query will return everything in the table in no particular order. * ```typescript
* for await (const batch of table.query()) { * for await (const batch of table.query()) {
* console.log(batch); * console.log(batch);
* } * }
* @returns {Query} A builder that can be used to parameterize the query *
* ### Return the full dataset as Arrow Table
* ```typescript
* let arrowTbl = await table.query().nearestTo([1.0, 2.0, 0.5, 6.7]).toArrow();
* ```
*
* @returns {@link Query}
*/ */
query(): Query { query(): Query {
return new Query(this.inner); return new Query(this.inner);
} }
/** /** Search the table with a given query vector.
* Search the table with a given query vector.
* *
* This is a convenience method for preparing a vector query and * This is a convenience method for preparing an ANN {@link Query}.
* is the same thing as calling `nearestTo` on the builder returned
* by `query`. @see {@link Query#nearestTo} for more details.
*/ */
vectorSearch(vector: unknown): VectorQuery { search(vector: number[], column?: string): Query {
return this.query().nearestTo(vector); const q = this.query();
q.nearestTo(vector);
if (column !== undefined) {
q.column(column);
}
return q;
} }
// TODO: Support BatchUDF // TODO: Support BatchUDF
/** /**
* Add new columns with defined values. * Add new columns with defined values.
* @param {AddColumnsSql[]} newColumnTransforms pairs of column names and *
* the SQL expression to use to calculate the value of the new column. These * @param newColumnTransforms pairs of column names and the SQL expression to use
* expressions will be evaluated for each row in the table, and can * to calculate the value of the new column. These
* reference existing columns in the table. * expressions will be evaluated for each row in the
* table, and can reference existing columns in the table.
*/ */
async addColumns(newColumnTransforms: AddColumnsSql[]): Promise<void> { async addColumns(newColumnTransforms: AddColumnsSql[]): Promise<void> {
await this.inner.addColumns(newColumnTransforms); await this.inner.addColumns(newColumnTransforms);
@@ -264,8 +166,8 @@ export class Table {
/** /**
* Alter the name or nullability of columns. * Alter the name or nullability of columns.
* @param {ColumnAlteration[]} columnAlterations One or more alterations to *
* apply to columns. * @param columnAlterations One or more alterations to apply to columns.
*/ */
async alterColumns(columnAlterations: ColumnAlteration[]): Promise<void> { async alterColumns(columnAlterations: ColumnAlteration[]): Promise<void> {
await this.inner.alterColumns(columnAlterations); await this.inner.alterColumns(columnAlterations);
@@ -278,76 +180,12 @@ export class Table {
* underlying storage. In order to remove the data, you must subsequently * underlying storage. In order to remove the data, you must subsequently
* call ``compact_files`` to rewrite the data without the removed columns and * call ``compact_files`` to rewrite the data without the removed columns and
* then call ``cleanup_files`` to remove the old files. * then call ``cleanup_files`` to remove the old files.
* @param {string[]} columnNames The names of the columns to drop. These can *
* be nested column references (e.g. "a.b.c") or top-level column names * @param columnNames The names of the columns to drop. These can be nested
* (e.g. "a"). * column references (e.g. "a.b.c") or top-level column
* names (e.g. "a").
*/ */
async dropColumns(columnNames: string[]): Promise<void> { async dropColumns(columnNames: string[]): Promise<void> {
await this.inner.dropColumns(columnNames); await this.inner.dropColumns(columnNames);
} }
/**
* Retrieve the version of the table
*
* LanceDb supports versioning. Every operation that modifies the table increases
* version. As long as a version hasn't been deleted you can `[Self::checkout]` that
* version to view the data at that point. In addition, you can `[Self::restore]` the
* version to replace the current table with a previous version.
*/
async version(): Promise<number> {
return await this.inner.version();
}
/**
* Checks out a specific version of the Table
*
* Any read operation on the table will now access the data at the checked out version.
* As a consequence, calling this method will disable any read consistency interval
* that was previously set.
*
* This is a read-only operation that turns the table into a sort of "view"
* or "detached head". Other table instances will not be affected. To make the change
* permanent you can use the `[Self::restore]` method.
*
* Any operation that modifies the table will fail while the table is in a checked
* out state.
*
* To return the table to a normal state use `[Self::checkout_latest]`
*/
async checkout(version: number): Promise<void> {
await this.inner.checkout(version);
}
/**
* Ensures the table is pointing at the latest version
*
* This can be used to manually update a table when the read_consistency_interval is None
* It can also be used to undo a `[Self::checkout]` operation
*/
async checkoutLatest(): Promise<void> {
await this.inner.checkoutLatest();
}
/**
* Restore the table to the currently checked out version
*
* This operation will fail if checkout has not been called previously
*
* This operation will overwrite the latest version of the table with a
* previous version. Any changes made since the checked out version will
* no longer be visible.
*
* Once the operation concludes the table will no longer be in a checked
* out state and the read_consistency_interval, if any, will apply.
*/
async restore(): Promise<void> {
await this.inner.restore();
}
/**
* List all indices that have been created with Self::create_index
*/
async listIndices(): Promise<IndexConfig[]> {
return await this.inner.listIndices();
}
} }

979
nodejs/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -19,21 +19,14 @@
"devDependencies": { "devDependencies": {
"@napi-rs/cli": "^2.18.0", "@napi-rs/cli": "^2.18.0",
"@types/jest": "^29.1.2", "@types/jest": "^29.1.2",
"@types/tmp": "^0.2.6",
"@typescript-eslint/eslint-plugin": "^6.19.0", "@typescript-eslint/eslint-plugin": "^6.19.0",
"@typescript-eslint/parser": "^6.19.0", "@typescript-eslint/parser": "^6.19.0",
"apache-arrow-old": "npm:apache-arrow@13.0.0", "eslint": "^8.56.0",
"eslint": "^8.57.0",
"eslint-config-prettier": "^9.1.0",
"eslint-plugin-jsdoc": "^48.2.1",
"jest": "^29.7.0", "jest": "^29.7.0",
"prettier": "^3.1.0",
"tmp": "^0.2.3",
"ts-jest": "^29.1.2", "ts-jest": "^29.1.2",
"typedoc": "^0.25.7", "typedoc": "^0.25.7",
"typedoc-plugin-markdown": "^3.17.1", "typedoc-plugin-markdown": "^3.17.1",
"typescript": "^5.3.3", "typescript": "^5.3.3"
"typescript-eslint": "^7.1.0"
}, },
"ava": { "ava": {
"timeout": "3m" "timeout": "3m"
@@ -55,9 +48,8 @@
"build:native": "napi build --platform --release --js lancedb/native.js --dts lancedb/native.d.ts dist/", "build:native": "napi build --platform --release --js lancedb/native.js --dts lancedb/native.d.ts dist/",
"build:debug": "napi build --platform --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/", "build:debug": "napi build --platform --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
"build": "npm run build:debug && tsc -b", "build": "npm run build:debug && tsc -b",
"chkformat": "prettier . --check",
"docs": "typedoc --plugin typedoc-plugin-markdown lancedb/index.ts", "docs": "typedoc --plugin typedoc-plugin-markdown lancedb/index.ts",
"lint": "eslint lancedb && eslint __test__", "lint": "eslint lancedb --ext .js,.ts",
"prepublishOnly": "napi prepublish -t npm", "prepublishOnly": "napi prepublish -t npm",
"test": "npm run build && jest --verbose", "test": "npm run build && jest --verbose",
"universal": "napi universal", "universal": "napi universal",
@@ -67,8 +59,7 @@
"lancedb-darwin-arm64": "0.4.3", "lancedb-darwin-arm64": "0.4.3",
"lancedb-darwin-x64": "0.4.3", "lancedb-darwin-x64": "0.4.3",
"lancedb-linux-arm64-gnu": "0.4.3", "lancedb-linux-arm64-gnu": "0.4.3",
"lancedb-linux-x64-gnu": "0.4.3", "lancedb-linux-x64-gnu": "0.4.3"
"openai": "^4.28.4"
}, },
"peerDependencies": { "peerDependencies": {
"apache-arrow": "^15.0.0" "apache-arrow": "^15.0.0"

View File

@@ -18,23 +18,11 @@ use napi_derive::*;
use crate::table::Table; use crate::table::Table;
use crate::ConnectionOptions; use crate::ConnectionOptions;
use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection, CreateTableMode}; use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection, CreateTableMode};
use lancedb::ipc::{ipc_file_to_batches, ipc_file_to_schema}; use lancedb::ipc::ipc_file_to_batches;
#[napi] #[napi]
pub struct Connection { pub struct Connection {
inner: Option<LanceDBConnection>, conn: LanceDBConnection,
}
impl Connection {
pub(crate) fn inner_new(inner: LanceDBConnection) -> Self {
Self { inner: Some(inner) }
}
fn get_inner(&self) -> napi::Result<&LanceDBConnection> {
self.inner
.as_ref()
.ok_or_else(|| napi::Error::from_reason("Connection is closed"))
}
} }
impl Connection { impl Connection {
@@ -52,8 +40,8 @@ impl Connection {
impl Connection { impl Connection {
/// Create a new Connection instance from the given URI. /// Create a new Connection instance from the given URI.
#[napi(factory)] #[napi(factory)]
pub async fn new(uri: String, options: ConnectionOptions) -> napi::Result<Self> { pub async fn new(options: ConnectionOptions) -> napi::Result<Self> {
let mut builder = ConnectBuilder::new(&uri); let mut builder = ConnectBuilder::new(&options.uri);
if let Some(api_key) = options.api_key { if let Some(api_key) = options.api_key {
builder = builder.api_key(&api_key); builder = builder.api_key(&api_key);
} }
@@ -64,44 +52,19 @@ impl Connection {
builder = builder =
builder.read_consistency_interval(std::time::Duration::from_secs_f64(interval)); builder.read_consistency_interval(std::time::Duration::from_secs_f64(interval));
} }
Ok(Self::inner_new( Ok(Self {
builder conn: builder
.execute() .execute()
.await .await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?, .map_err(|e| napi::Error::from_reason(format!("{}", e)))?,
)) })
}
#[napi]
pub fn display(&self) -> napi::Result<String> {
Ok(self.get_inner()?.to_string())
}
#[napi]
pub fn is_open(&self) -> bool {
self.inner.is_some()
}
#[napi]
pub fn close(&mut self) {
self.inner.take();
} }
/// List all tables in the dataset. /// List all tables in the dataset.
#[napi] #[napi]
pub async fn table_names( pub async fn table_names(&self) -> napi::Result<Vec<String>> {
&self, self.conn
start_after: Option<String>, .table_names()
limit: Option<u32>,
) -> napi::Result<Vec<String>> {
let mut op = self.get_inner()?.table_names();
if let Some(start_after) = start_after {
op = op.start_after(start_after);
}
if let Some(limit) = limit {
op = op.limit(limit);
}
op.execute()
.await .await
.map_err(|e| napi::Error::from_reason(format!("{}", e))) .map_err(|e| napi::Error::from_reason(format!("{}", e)))
} }
@@ -123,29 +86,8 @@ impl Connection {
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?; .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let mode = Self::parse_create_mode_str(&mode)?; let mode = Self::parse_create_mode_str(&mode)?;
let tbl = self let tbl = self
.get_inner()? .conn
.create_table(&name, batches) .create_table(&name, Box::new(batches))
.mode(mode)
.execute()
.await
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
Ok(Table::new(tbl))
}
#[napi]
pub async fn create_empty_table(
&self,
name: String,
schema_buf: Buffer,
mode: String,
) -> napi::Result<Table> {
let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e))
})?;
let mode = Self::parse_create_mode_str(&mode)?;
let tbl = self
.get_inner()?
.create_empty_table(&name, schema)
.mode(mode) .mode(mode)
.execute() .execute()
.await .await
@@ -156,7 +98,7 @@ impl Connection {
#[napi] #[napi]
pub async fn open_table(&self, name: String) -> napi::Result<Table> { pub async fn open_table(&self, name: String) -> napi::Result<Table> {
let tbl = self let tbl = self
.get_inner()? .conn
.open_table(&name) .open_table(&name)
.execute() .execute()
.await .await
@@ -167,7 +109,7 @@ impl Connection {
/// Drop table with the name. Or raise an error if the table does not exist. /// Drop table with the name. Or raise an error if the table does not exist.
#[napi] #[napi]
pub async fn drop_table(&self, name: String) -> napi::Result<()> { pub async fn drop_table(&self, name: String) -> napi::Result<()> {
self.get_inner()? self.conn
.drop_table(&name) .drop_table(&name)
.await .await
.map_err(|e| napi::Error::from_reason(format!("{}", e))) .map_err(|e| napi::Error::from_reason(format!("{}", e)))

View File

@@ -1,12 +0,0 @@
pub type Result<T> = napi::Result<T>;
pub trait NapiErrorExt<T> {
/// Convert to a napi error using from_reason(err.to_string())
fn default_error(self) -> Result<T>;
}
impl<T> NapiErrorExt<T> for std::result::Result<T, lancedb::Error> {
fn default_error(self) -> Result<T> {
self.map_err(|err| napi::Error::from_reason(err.to_string()))
}
}

View File

@@ -12,68 +12,89 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Mutex; use lance_linalg::distance::MetricType as LanceMetricType;
use lancedb::index::scalar::BTreeIndexBuilder;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::Index as LanceDbIndex;
use napi_derive::napi; use napi_derive::napi;
use crate::util::parse_distance_type;
#[napi] #[napi]
pub struct Index { pub enum IndexType {
inner: Mutex<Option<LanceDbIndex>>, Scalar,
IvfPq,
} }
impl Index { #[napi]
pub fn consume(&self) -> napi::Result<LanceDbIndex> { pub enum MetricType {
self.inner L2,
.lock() Cosine,
.unwrap() Dot,
.take() }
.ok_or(napi::Error::from_reason(
"attempt to use an index more than once", impl From<MetricType> for LanceMetricType {
)) fn from(metric: MetricType) -> Self {
match metric {
MetricType::L2 => Self::L2,
MetricType::Cosine => Self::Cosine,
MetricType::Dot => Self::Dot,
}
} }
} }
#[napi] #[napi]
impl Index { pub struct IndexBuilder {
#[napi(factory)] inner: lancedb::index::IndexBuilder,
pub fn ivf_pq( }
distance_type: Option<String>,
#[napi]
impl IndexBuilder {
pub fn new(tbl: &dyn lancedb::Table) -> Self {
let inner = tbl.create_index(&[]);
Self { inner }
}
#[napi]
pub unsafe fn replace(&mut self, v: bool) {
self.inner.replace(v);
}
#[napi]
pub unsafe fn column(&mut self, c: String) {
self.inner.columns(&[c.as_str()]);
}
#[napi]
pub unsafe fn name(&mut self, name: String) {
self.inner.name(name.as_str());
}
#[napi]
pub unsafe fn ivf_pq(
&mut self,
metric_type: Option<MetricType>,
num_partitions: Option<u32>, num_partitions: Option<u32>,
num_sub_vectors: Option<u32>, num_sub_vectors: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>, max_iterations: Option<u32>,
sample_rate: Option<u32>, sample_rate: Option<u32>,
) -> napi::Result<Self> { ) {
let mut ivf_pq_builder = IvfPqIndexBuilder::default(); self.inner.ivf_pq();
if let Some(distance_type) = distance_type { metric_type.map(|m| self.inner.metric_type(m.into()));
let distance_type = parse_distance_type(distance_type)?; num_partitions.map(|p| self.inner.num_partitions(p));
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type); num_sub_vectors.map(|s| self.inner.num_sub_vectors(s));
} num_bits.map(|b| self.inner.num_bits(b));
if let Some(num_partitions) = num_partitions { max_iterations.map(|i| self.inner.max_iterations(i));
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions); sample_rate.map(|s| self.inner.sample_rate(s));
}
if let Some(num_sub_vectors) = num_sub_vectors {
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
})
} }
#[napi(factory)] #[napi]
pub fn btree() -> Self { pub unsafe fn scalar(&mut self) {
Self { self.inner.scalar();
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
} }
#[napi]
pub async fn build(&self) -> napi::Result<()> {
self.inner
.build()
.await
.map_err(|e| napi::Error::from_reason(format!("Failed to build index: {}", e)))?;
Ok(())
} }
} }

View File

@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
use futures::StreamExt; use futures::StreamExt;
use lancedb::arrow::SendableRecordBatchStream; use lance::io::RecordBatchStream;
use lancedb::ipc::batches_to_ipc_file; use lancedb::ipc::batches_to_ipc_file;
use napi::bindgen_prelude::*; use napi::bindgen_prelude::*;
use napi_derive::napi; use napi_derive::napi;
@@ -21,12 +21,12 @@ use napi_derive::napi;
/** Typescript-style Async Iterator over RecordBatches */ /** Typescript-style Async Iterator over RecordBatches */
#[napi] #[napi]
pub struct RecordBatchIterator { pub struct RecordBatchIterator {
inner: SendableRecordBatchStream, inner: Box<dyn RecordBatchStream + Unpin>,
} }
#[napi] #[napi]
impl RecordBatchIterator { impl RecordBatchIterator {
pub(crate) fn new(inner: SendableRecordBatchStream) -> Self { pub(crate) fn new(inner: Box<dyn RecordBatchStream + Unpin>) -> Self {
Self { inner } Self { inner }
} }

View File

@@ -16,16 +16,15 @@ use connection::Connection;
use napi_derive::*; use napi_derive::*;
mod connection; mod connection;
mod error;
mod index; mod index;
mod iterator; mod iterator;
mod query; mod query;
mod table; mod table;
mod util;
#[napi(object)] #[napi(object)]
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectionOptions { pub struct ConnectionOptions {
pub uri: String,
pub api_key: Option<String>, pub api_key: Option<String>,
pub host_override: Option<String>, pub host_override: Option<String>,
/// (For LanceDB OSS only): The interval, in seconds, at which to check for /// (For LanceDB OSS only): The interval, in seconds, at which to check for
@@ -55,6 +54,6 @@ pub struct WriteOptions {
} }
#[napi] #[napi]
pub async fn connect(uri: String, options: ConnectionOptions) -> napi::Result<Connection> { pub async fn connect(options: ConnectionOptions) -> napi::Result<Connection> {
Connection::new(uri, options).await Connection::new(options).await
} }

View File

@@ -12,38 +12,38 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use lancedb::query::ExecutableQuery; use lancedb::query::Query as LanceDBQuery;
use lancedb::query::Query as LanceDbQuery;
use lancedb::query::QueryBase;
use lancedb::query::Select;
use lancedb::query::VectorQuery as LanceDbVectorQuery;
use napi::bindgen_prelude::*; use napi::bindgen_prelude::*;
use napi_derive::napi; use napi_derive::napi;
use crate::error::NapiErrorExt; use crate::{iterator::RecordBatchIterator, table::Table};
use crate::iterator::RecordBatchIterator;
use crate::util::parse_distance_type;
#[napi] #[napi]
pub struct Query { pub struct Query {
inner: LanceDbQuery, inner: LanceDBQuery,
} }
#[napi] #[napi]
impl Query { impl Query {
pub fn new(query: LanceDbQuery) -> Self { pub fn new(table: &Table) -> Self {
Self { inner: query } Self {
inner: table.table.query(),
} }
// We cannot call this r#where because NAPI gets confused by the r#
#[napi]
pub fn only_if(&mut self, predicate: String) {
self.inner = self.inner.clone().only_if(predicate);
} }
#[napi] #[napi]
pub fn select(&mut self, columns: Vec<(String, String)>) { pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().select(Select::dynamic(&columns)); self.inner = self.inner.clone().column(&column);
}
#[napi]
pub fn filter(&mut self, filter: String) {
self.inner = self.inner.clone().filter(filter);
}
#[napi]
pub fn select(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(&columns);
} }
#[napi] #[napi]
@@ -52,46 +52,13 @@ impl Query {
} }
#[napi] #[napi]
pub fn nearest_to(&mut self, vector: Float32Array) -> Result<VectorQuery> { pub fn prefilter(&mut self, prefilter: bool) {
let inner = self self.inner = self.inner.clone().prefilter(prefilter);
.inner
.clone()
.nearest_to(vector.as_ref())
.default_error()?;
Ok(VectorQuery { inner })
} }
#[napi] #[napi]
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> { pub fn nearest_to(&mut self, vector: Float32Array) {
let inner_stream = self.inner.execute().await.map_err(|e| { self.inner = self.inner.clone().nearest_to(&vector);
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream))
}
}
#[napi]
pub struct VectorQuery {
inner: LanceDbVectorQuery,
}
#[napi]
impl VectorQuery {
#[napi]
pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column);
}
#[napi]
pub fn distance_type(&mut self, distance_type: String) -> napi::Result<()> {
let distance_type = parse_distance_type(distance_type)?;
self.inner = self.inner.clone().distance_type(distance_type);
Ok(())
}
#[napi]
pub fn postfilter(&mut self) {
self.inner = self.inner.clone().postfilter();
} }
#[napi] #[napi]
@@ -105,30 +72,10 @@ impl VectorQuery {
} }
#[napi] #[napi]
pub fn bypass_vector_index(&mut self) { pub async fn execute_stream(&self) -> napi::Result<RecordBatchIterator> {
self.inner = self.inner.clone().bypass_vector_index() let inner_stream = self.inner.execute_stream().await.map_err(|e| {
}
#[napi]
pub fn only_if(&mut self, predicate: String) {
self.inner = self.inner.clone().only_if(predicate);
}
#[napi]
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
#[napi]
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
#[napi]
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e)) napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?; })?;
Ok(RecordBatchIterator::new(inner_stream)) Ok(RecordBatchIterator::new(Box::new(inner_stream)))
} }
} }

View File

@@ -13,66 +13,33 @@
// limitations under the License. // limitations under the License.
use arrow_ipc::writer::FileWriter; use arrow_ipc::writer::FileWriter;
use lancedb::ipc::ipc_file_to_batches; use lance::dataset::ColumnAlteration as LanceColumnAlteration;
use lancedb::table::{ use lancedb::{
AddDataMode, ColumnAlteration as LanceColumnAlteration, NewColumnTransform, ipc::ipc_file_to_batches,
Table as LanceDbTable, table::{AddDataOptions, TableRef},
}; };
use napi::bindgen_prelude::*; use napi::bindgen_prelude::*;
use napi_derive::napi; use napi_derive::napi;
use crate::error::NapiErrorExt; use crate::index::IndexBuilder;
use crate::index::Index; use crate::query::Query;
use crate::query::{Query, VectorQuery};
#[napi] #[napi]
pub struct Table { pub struct Table {
// We keep a duplicate of the table name so we can use it for error pub(crate) table: TableRef,
// messages even if the table has been closed
name: String,
pub(crate) inner: Option<LanceDbTable>,
}
impl Table {
fn inner_ref(&self) -> napi::Result<&LanceDbTable> {
self.inner
.as_ref()
.ok_or_else(|| napi::Error::from_reason(format!("Table {} is closed", self.name)))
}
} }
#[napi] #[napi]
impl Table { impl Table {
pub(crate) fn new(table: LanceDbTable) -> Self { pub(crate) fn new(table: TableRef) -> Self {
Self { Self { table }
name: table.name().to_string(),
inner: Some(table),
}
}
#[napi]
pub fn display(&self) -> String {
match &self.inner {
None => format!("ClosedTable({})", self.name),
Some(inner) => inner.to_string(),
}
}
#[napi]
pub fn is_open(&self) -> bool {
self.inner.is_some()
}
#[napi]
pub fn close(&mut self) {
self.inner.take();
} }
/// Return Schema as empty Arrow IPC file. /// Return Schema as empty Arrow IPC file.
#[napi] #[napi]
pub async fn schema(&self) -> napi::Result<Buffer> { pub async fn schema(&self) -> napi::Result<Buffer> {
let schema = let schema =
self.inner_ref()?.schema().await.map_err(|e| { self.table.schema().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to create IPC file: {}", e)) napi::Error::from_reason(format!("Failed to create IPC file: {}", e))
})?; })?;
let mut writer = FileWriter::try_new(vec![], &schema) let mut writer = FileWriter::try_new(vec![], &schema)
@@ -86,94 +53,52 @@ impl Table {
} }
#[napi] #[napi]
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> { pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
let batches = ipc_file_to_batches(buf.to_vec()) let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?; .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let mut op = self.inner_ref()?.add(batches); self.table
.add(Box::new(batches), AddDataOptions::default())
op = if mode == "append" { .await
op.mode(AddDataMode::Append) .map_err(|e| {
} else if mode == "overwrite" {
op.mode(AddDataMode::Overwrite)
} else {
return Err(napi::Error::from_reason(format!("Invalid mode: {}", mode)));
};
op.execute().await.map_err(|e| {
napi::Error::from_reason(format!( napi::Error::from_reason(format!(
"Failed to add batches to table {}: {}", "Failed to add batches to table {}: {}",
self.name, e self.table, e
)) ))
}) })
} }
#[napi] #[napi]
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<i64> { pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<i64> {
self.inner_ref()? self.table
.count_rows(filter) .count_rows(filter)
.await .await
.map(|val| val as i64) .map(|val| val as i64)
.map_err(|e| { .map_err(|e| {
napi::Error::from_reason(format!( napi::Error::from_reason(format!(
"Failed to count rows in table {}: {}", "Failed to count rows in table {}: {}",
self.name, e self.table, e
)) ))
}) })
} }
#[napi] #[napi]
pub async fn delete(&self, predicate: String) -> napi::Result<()> { pub async fn delete(&self, predicate: String) -> napi::Result<()> {
self.inner_ref()?.delete(&predicate).await.map_err(|e| { self.table.delete(&predicate).await.map_err(|e| {
napi::Error::from_reason(format!( napi::Error::from_reason(format!(
"Failed to delete rows in table {}: predicate={}", "Failed to delete rows in table {}: predicate={}",
self.name, e self.table, e
)) ))
}) })
} }
#[napi] #[napi]
pub async fn create_index( pub fn create_index(&self) -> IndexBuilder {
&self, IndexBuilder::new(self.table.as_ref())
index: Option<&Index>,
column: String,
replace: Option<bool>,
) -> napi::Result<()> {
let lancedb_index = if let Some(index) = index {
index.consume()?
} else {
lancedb::index::Index::Auto
};
let mut builder = self.inner_ref()?.create_index(&[column], lancedb_index);
if let Some(replace) = replace {
builder = builder.replace(replace);
}
builder.execute().await.default_error()
} }
#[napi] #[napi]
pub async fn update( pub fn query(&self) -> Query {
&self, Query::new(self)
only_if: Option<String>,
columns: Vec<(String, String)>,
) -> napi::Result<()> {
let mut op = self.inner_ref()?.update();
if let Some(only_if) = only_if {
op = op.only_if(only_if);
}
for (column_name, value) in columns {
op = op.column(column_name, value);
}
op.execute().await.default_error()
}
#[napi]
pub fn query(&self) -> napi::Result<Query> {
Ok(Query::new(self.inner_ref()?.query()))
}
#[napi]
pub fn vector_search(&self, vector: Float32Array) -> napi::Result<VectorQuery> {
self.query()?.nearest_to(vector)
} }
#[napi] #[napi]
@@ -182,14 +107,14 @@ impl Table {
.into_iter() .into_iter()
.map(|sql| (sql.name, sql.value_sql)) .map(|sql| (sql.name, sql.value_sql))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let transforms = NewColumnTransform::SqlExpressions(transforms); let transforms = lance::dataset::NewColumnTransform::SqlExpressions(transforms);
self.inner_ref()? self.table
.add_columns(transforms, None) .add_columns(transforms, None)
.await .await
.map_err(|err| { .map_err(|err| {
napi::Error::from_reason(format!( napi::Error::from_reason(format!(
"Failed to add columns to table {}: {}", "Failed to add columns to table {}: {}",
self.name, err self.table, err
)) ))
})?; })?;
Ok(()) Ok(())
@@ -209,13 +134,13 @@ impl Table {
.map(LanceColumnAlteration::from) .map(LanceColumnAlteration::from)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.inner_ref()? self.table
.alter_columns(&alterations) .alter_columns(&alterations)
.await .await
.map_err(|err| { .map_err(|err| {
napi::Error::from_reason(format!( napi::Error::from_reason(format!(
"Failed to alter columns in table {}: {}", "Failed to alter columns in table {}: {}",
self.name, err self.table, err
)) ))
})?; })?;
Ok(()) Ok(())
@@ -224,78 +149,14 @@ impl Table {
#[napi] #[napi]
pub async fn drop_columns(&self, columns: Vec<String>) -> napi::Result<()> { pub async fn drop_columns(&self, columns: Vec<String>) -> napi::Result<()> {
let col_refs = columns.iter().map(String::as_str).collect::<Vec<_>>(); let col_refs = columns.iter().map(String::as_str).collect::<Vec<_>>();
self.inner_ref()? self.table.drop_columns(&col_refs).await.map_err(|err| {
.drop_columns(&col_refs)
.await
.map_err(|err| {
napi::Error::from_reason(format!( napi::Error::from_reason(format!(
"Failed to drop columns from table {}: {}", "Failed to drop columns from table {}: {}",
self.name, err self.table, err
)) ))
})?; })?;
Ok(()) Ok(())
} }
#[napi]
pub async fn version(&self) -> napi::Result<i64> {
self.inner_ref()?
.version()
.await
.map(|val| val as i64)
.default_error()
}
#[napi]
pub async fn checkout(&self, version: i64) -> napi::Result<()> {
self.inner_ref()?
.checkout(version as u64)
.await
.default_error()
}
#[napi]
pub async fn checkout_latest(&self) -> napi::Result<()> {
self.inner_ref()?.checkout_latest().await.default_error()
}
#[napi]
pub async fn restore(&self) -> napi::Result<()> {
self.inner_ref()?.restore().await.default_error()
}
#[napi]
pub async fn list_indices(&self) -> napi::Result<Vec<IndexConfig>> {
Ok(self
.inner_ref()?
.list_indices()
.await
.default_error()?
.into_iter()
.map(IndexConfig::from)
.collect::<Vec<_>>())
}
}
#[napi(object)]
/// A description of an index currently configured on a column
pub struct IndexConfig {
/// The type of the index
pub index_type: String,
/// The columns in the index
///
/// Currently this is always an array of size 1. In the future there may
/// be more columns to represent composite indices.
pub columns: Vec<String>,
}
impl From<lancedb::index::IndexConfig> for IndexConfig {
fn from(value: lancedb::index::IndexConfig) -> Self {
let index_type = format!("{:?}", value.index_type);
Self {
index_type,
columns: value.columns,
}
}
} }
/// A definition of a column alteration. The alteration changes the column at /// A definition of a column alteration. The alteration changes the column at

View File

@@ -1,13 +0,0 @@
use lancedb::DistanceType;
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> napi::Result<DistanceType> {
match distance_type.as_ref().to_lowercase().as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(napi::Error::from_reason(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type.as_ref()
))),
}
}

View File

@@ -1,5 +1,9 @@
{ {
"include": ["lancedb/*.ts", "lancedb/**/*.ts", "lancedb/*.js"], "include": [
"lancedb/*.ts",
"lancedb/**/*.ts",
"lancedb/*.js",
],
"compilerOptions": { "compilerOptions": {
"target": "es2022", "target": "es2022",
"module": "commonjs", "module": "commonjs",
@@ -7,17 +11,21 @@
"outDir": "./dist", "outDir": "./dist",
"strict": true, "strict": true,
"allowJs": true, "allowJs": true,
"resolveJsonModule": true "resolveJsonModule": true,
}, },
"exclude": ["./dist/*"], "exclude": [
"./dist/*",
],
"typedocOptions": { "typedocOptions": {
"entryPoints": ["lancedb/index.ts"], "entryPoints": [
"lancedb/index.ts"
],
"out": "../docs/src/javascript/", "out": "../docs/src/javascript/",
"visibilityFilters": { "visibilityFilters": {
"protected": false, "protected": false,
"private": false, "private": false,
"inherited": true, "inherited": true,
"external": false "external": false,
} }
} }
} }

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.6.5 current_version = 0.6.1
commit = True commit = True
message = [python] Bump version: {current_version} → {new_version} message = [python] Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -15,24 +15,10 @@ need to use `await` to call these functions.
## Connection ## Connection
* The connection now has a `close` method. You can call this when No changes yet.
you are done with the connection to eagerly free resources. Currently
this is limited to freeing/closing the HTTP connection for remote
connections. In the future we may add caching or other resources to
native connections so this is probably a good practice even if you aren't using remote connections.
In addition, the connection can be used as a context manager which may
be a more convenient way to ensure the connection is closed.
It is not mandatory to call the `close` method. If you don't call it
the connection will be closed when the object is garbage collected.
## Table ## Table
* The table now has a `close` method, similar to the connection. This
can be used to eagerly free the cache used by a Table object. Similar
to the connection, it can be used as a context manager and it is not
mandatory to call the `close` method.
* Previously `Table.schema` was a property. Now it is an async method. * Previously `Table.schema` was a property. Now it is an async method.
* The method `Table.__len__` was removed and `len(table)` will no longer * The method `Table.__len__` was removed and `len(table)` will no longer
work. Use `Table.count_rows` instead. work. Use `Table.count_rows` instead.

View File

@@ -7,7 +7,7 @@ license.workspace = true
repository.workspace = true repository.workspace = true
keywords.workspace = true keywords.workspace = true
categories.workspace = true categories.workspace = true
rust-version = "1.75.0"
[lib] [lib]
name = "_lancedb" name = "_lancedb"
@@ -22,9 +22,6 @@ pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
# Prevent dynamic linking of lzma, which comes from datafusion # Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] } lzma-sys = { version = "*", features = ["static"] }
pin-project = "1.1.5"
futures.workspace = true
tokio = { version = "1.36.0", features = ["sync"] }
[build-dependencies] [build-dependencies]
pyo3-build-config = { version = "0.20.3", features = [ pyo3-build-config = { version = "0.20.3", features = [

View File

@@ -1,9 +1,9 @@
[project] [project]
name = "lancedb" name = "lancedb"
version = "0.6.5" version = "0.6.1"
dependencies = [ dependencies = [
"deprecation", "deprecation",
"pylance==0.10.5", "pylance==0.10.1",
"ratelimiter~=1.0", "ratelimiter~=1.0",
"retry>=0.9.2", "retry>=0.9.2",
"tqdm>=4.27.0", "tqdm>=4.27.0",
@@ -81,7 +81,6 @@ embeddings = [
"awscli>=1.29.57", "awscli>=1.29.57",
"botocore>=1.31.57", "botocore>=1.31.57",
] ]
azure = ["adlfs>=2024.2.0"]
[tool.maturin] [tool.maturin]
python-source = "python" python-source = "python"
@@ -94,11 +93,13 @@ lancedb = "lancedb.cli.cli:cli"
requires = ["maturin>=1.4"] requires = ["maturin>=1.4"]
build-backend = "maturin" build-backend = "maturin"
[tool.ruff.lint] [tool.ruff.lint]
select = ["F", "E", "W", "I", "G", "TCH", "PERF"] select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py" addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
markers = [ markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')", "slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio", "asyncio",

View File

@@ -21,11 +21,10 @@ __version__ = importlib.metadata.version("lancedb")
from ._lancedb import connect as lancedb_connect from ._lancedb import connect as lancedb_connect
from .common import URI, sanitize_uri from .common import URI, sanitize_uri
from .db import AsyncConnection, DBConnection, LanceDBConnection from .db import AsyncConnection, AsyncLanceDBConnection, DBConnection, LanceDBConnection
from .remote.db import RemoteDBConnection from .remote.db import RemoteDBConnection
from .schema import vector from .schema import vector # noqa: F401
from .table import AsyncTable from .utils import sentry_log # noqa: F401
from .utils import sentry_log
def connect( def connect(
@@ -36,7 +35,6 @@ def connect(
host_override: Optional[str] = None, host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None, read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
**kwargs,
) -> DBConnection: ) -> DBConnection:
"""Connect to a LanceDB database. """Connect to a LanceDB database.
@@ -101,12 +99,7 @@ def connect(
if isinstance(request_thread_pool, int): if isinstance(request_thread_pool, int):
request_thread_pool = ThreadPoolExecutor(request_thread_pool) request_thread_pool = ThreadPoolExecutor(request_thread_pool)
return RemoteDBConnection( return RemoteDBConnection(
uri, uri, api_key, region, host_override, request_thread_pool=request_thread_pool
api_key,
region,
host_override,
request_thread_pool=request_thread_pool,
**kwargs,
) )
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
@@ -175,33 +168,8 @@ async def connect_async(
conn : DBConnection conn : DBConnection
A connection to a LanceDB database. A connection to a LanceDB database.
""" """
if read_consistency_interval is not None: return AsyncLanceDBConnection(
read_consistency_interval_secs = read_consistency_interval.total_seconds()
else:
read_consistency_interval_secs = None
return AsyncConnection(
await lancedb_connect( await lancedb_connect(
sanitize_uri(uri), sanitize_uri(uri), api_key, region, host_override, read_consistency_interval
api_key,
region,
host_override,
read_consistency_interval_secs,
) )
) )
__all__ = [
"connect",
"connect_async",
"AsyncConnection",
"AsyncTable",
"URI",
"sanitize_uri",
"sentry_log",
"vector",
"DBConnection",
"LanceDBConnection",
"RemoteDBConnection",
"__version__",
]

View File

@@ -1,23 +1,9 @@
from typing import Dict, List, Optional, Tuple from typing import Optional
import pyarrow as pa import pyarrow as pa
class Index:
@staticmethod
def ivf_pq(
distance_type: Optional[str],
num_partitions: Optional[int],
num_sub_vectors: Optional[int],
max_iterations: Optional[int],
sample_rate: Optional[int],
) -> Index: ...
@staticmethod
def btree() -> Index: ...
class Connection(object): class Connection(object):
async def table_names( async def table_names(self) -> list[str]: ...
self, start_after: Optional[str], limit: Optional[int]
) -> list[str]: ...
async def create_table( async def create_table(
self, name: str, mode: str, data: pa.RecordBatchReader self, name: str, mode: str, data: pa.RecordBatchReader
) -> Table: ... ) -> Table: ...
@@ -25,27 +11,9 @@ class Connection(object):
self, name: str, mode: str, schema: pa.Schema self, name: str, mode: str, schema: pa.Schema
) -> Table: ... ) -> Table: ...
class Table: class Table(object):
def name(self) -> str: ... def name(self) -> str: ...
def __repr__(self) -> str: ...
async def schema(self) -> pa.Schema: ... async def schema(self) -> pa.Schema: ...
async def add(self, data: pa.RecordBatchReader, mode: str) -> None: ...
async def update(self, updates: Dict[str, str], where: Optional[str]) -> None: ...
async def count_rows(self, filter: Optional[str]) -> int: ...
async def create_index(
self, column: str, config: Optional[Index], replace: Optional[bool]
): ...
async def version(self) -> int: ...
async def checkout(self, version): ...
async def checkout_latest(self): ...
async def restore(self): ...
async def list_indices(self) -> List[IndexConfig]: ...
def query(self) -> Query: ...
def vector_search(self) -> VectorQuery: ...
class IndexConfig:
index_type: str
columns: List[str]
async def connect( async def connect(
uri: str, uri: str,
@@ -54,27 +22,3 @@ async def connect(
host_override: Optional[str], host_override: Optional[str],
read_consistency_interval: Optional[float], read_consistency_interval: Optional[float],
) -> Connection: ... ) -> Connection: ...
class RecordBatchStream:
def schema(self) -> pa.Schema: ...
async def next(self) -> Optional[pa.RecordBatch]: ...
class Query:
def where(self, filter: str): ...
def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
async def execute(self) -> RecordBatchStream: ...
class VectorQuery:
async def execute(self) -> RecordBatchStream: ...
def where(self, filter: str): ...
def select(self, columns: List[str]): ...
def select_with_projection(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ...
def column(self, column: str): ...
def distance_type(self, distance_type: str): ...
def postfilter(self): ...
def refine_factor(self, refine_factor: int): ...
def nprobes(self, nprobes: int): ...
def bypass_vector_index(self): ...

View File

@@ -1,44 +0,0 @@
from typing import List
import pyarrow as pa
from ._lancedb import RecordBatchStream
class AsyncRecordBatchReader:
"""
An async iterator over a stream of RecordBatches.
Also allows access to the schema of the stream
"""
def __init__(self, inner: RecordBatchStream):
self.inner_ = inner
@property
def schema(self) -> pa.Schema:
"""
Get the schema of the batches produced by the stream
Accessing the schema does not consume any data from the stream
"""
return self.inner_.schema()
async def read_all(self) -> List[pa.RecordBatch]:
"""
Read all the record batches from the stream
This consumes the entire stream and returns a list of record batches
If there are a lot of results this may consume a lot of memory
"""
return [batch async for batch in self]
def __aiter__(self):
return self
async def __anext__(self) -> pa.RecordBatch:
next = await self.inner_.next()
if next is None:
raise StopAsyncIteration
return next

View File

@@ -13,12 +13,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import inspect import inspect
import os import os
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Union from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import pyarrow as pa import pyarrow as pa
from overrides import EnforceOverrides, override from overrides import EnforceOverrides, override
@@ -28,9 +27,8 @@ from lancedb.common import data_to_reader, validate_schema
from lancedb.embeddings.registry import EmbeddingFunctionRegistry from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from lancedb.utils.events import register_event from lancedb.utils.events import register_event
from ._lancedb import connect as lancedb_connect
from .pydantic import LanceModel from .pydantic import LanceModel
from .table import AsyncTable, LanceTable, Table, _sanitize_data from .table import AsyncLanceTable, LanceTable, Table, _sanitize_data
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -319,10 +317,6 @@ class LanceDBConnection(DBConnection):
def uri(self) -> str: def uri(self) -> str:
return self._uri return self._uri
async def _async_get_table_names(self, start_after: Optional[str], limit: int):
conn = AsyncConnection(await lancedb_connect(self.uri))
return await conn.table_names(start_after=start_after, limit=limit)
@override @override
def table_names( def table_names(
self, page_token: Optional[str] = None, limit: int = 10 self, page_token: Optional[str] = None, limit: int = 10
@@ -334,10 +328,6 @@ class LanceDBConnection(DBConnection):
Iterator of str. Iterator of str.
A list of table names. A list of table names.
""" """
try:
asyncio.get_running_loop()
# User application is async. Soon we will just tell them to use the
# async version. Until then fallback to the old sync implementation.
try: try:
filesystem = fs_from_uri(self.uri)[0] filesystem = fs_from_uri(self.uri)[0]
except pa.ArrowInvalid: except pa.ArrowInvalid:
@@ -356,10 +346,6 @@ class LanceDBConnection(DBConnection):
] ]
tables.sort() tables.sort()
return tables return tables
except RuntimeError:
# User application is sync. It is safe to use the async implementation
# under the hood.
return asyncio.run(self._async_get_table_names(page_token, limit))
def __len__(self) -> int: def __len__(self) -> int:
return len(self.table_names()) return len(self.table_names())
@@ -441,95 +427,43 @@ class LanceDBConnection(DBConnection):
filesystem.delete_dir(path) filesystem.delete_dir(path)
class AsyncConnection(object): class AsyncConnection(EnforceOverrides):
"""An active LanceDB connection """An active LanceDB connection interface."""
To obtain a connection you can use the [connect] function.
This could be a native connection (using lance) or a remote connection (e.g. for
connecting to LanceDb Cloud)
Local connections do not currently hold any open resources but they may do so in the
future (for example, for shared cache or connections to catalog services) Remote
connections represent an open connection to the remote server. The [close] method
can be used to release any underlying resources eagerly. The connection can also
be used as a context manager:
Connections can be shared on multiple threads and are expected to be long lived.
Connections can also be used as a context manager, however, in many cases a single
connection can be used for the lifetime of the application and so this is often
not needed. Closing a connection is optional. If it is not closed then it will
be automatically closed when the connection object is deleted.
Examples
--------
>>> import asyncio
>>> import lancedb
>>> async def my_connect():
... with await lancedb.connect("/tmp/my_dataset") as conn:
... # do something with the connection
... pass
... # conn is closed here
"""
def __init__(self, connection: LanceDbConnection):
self._inner = connection
def __repr__(self):
return self._inner.__repr__()
def __enter__(self):
self
def __exit__(self, *_):
self.close()
def is_open(self):
"""Return True if the connection is open."""
return self._inner.is_open()
def close(self):
"""Close the connection, releasing any underlying resources.
It is safe to call this method multiple times.
Any attempt to use the connection after it is closed will result in an error."""
self._inner.close()
@abstractmethod
async def table_names( async def table_names(
self, *, start_after: Optional[str] = None, limit: Optional[int] = None self, *, page_token: Optional[str] = None, limit: int = 10
) -> Iterable[str]: ) -> Iterable[str]:
"""List all tables in this database, in sorted order """List all tables in this database, in sorted order
Parameters Parameters
---------- ----------
start_after: str, optional page_token: str, optional
If present, only return names that come lexicographically after the supplied The token to use for pagination. If not present, start from the beginning.
value. Typically, this token is last table name from the previous page.
Only supported by LanceDb Cloud.
This can be combined with limit to implement pagination by setting this to
the last table name from the previous page.
limit: int, default 10 limit: int, default 10
The number of results to return. The size of the page to return.
Only supported by LanceDb Cloud.
Returns Returns
------- -------
Iterable of str Iterable of str
""" """
return await self._inner.table_names(start_after=start_after, limit=limit) pass
@abstractmethod
async def create_table( async def create_table(
self, self,
name: str, name: str,
data: Optional[DATA] = None, data: Optional[DATA] = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None,
mode: Optional[Literal["create", "overwrite"]] = None, mode: str = "create",
exist_ok: Optional[bool] = None, exist_ok: bool = False,
on_bad_vectors: Optional[str] = None, on_bad_vectors: str = "error",
fill_value: Optional[float] = None, fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> AsyncTable: ) -> Table:
"""Create a [Table][lancedb.table.Table] in the database. """Create a [Table][lancedb.table.Table] in the database.
Parameters Parameters
@@ -551,7 +485,7 @@ class AsyncConnection(object):
- pyarrow.Schema - pyarrow.Schema
- [LanceModel][lancedb.pydantic.LanceModel] - [LanceModel][lancedb.pydantic.LanceModel]
mode: Literal["create", "overwrite"]; default "create" mode: str; default "create"
The mode to use when creating the table. The mode to use when creating the table.
Can be either "create" or "overwrite". Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised. By default, if the table already exists, an exception is raised.
@@ -667,6 +601,72 @@ class AsyncConnection(object):
LanceTable(connection=..., name="table4") LanceTable(connection=..., name="table4")
""" """
raise NotImplementedError
async def open_table(self, name: str) -> Table:
"""Open a Lance Table in the database.
Parameters
----------
name: str
The name of the table.
Returns
-------
A LanceTable object representing the table.
"""
raise NotImplementedError
async def drop_table(self, name: str):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
"""
raise NotImplementedError
async def drop_database(self):
"""
Drop database
This is the same thing as dropping all the tables
"""
raise NotImplementedError
class AsyncLanceDBConnection(AsyncConnection):
def __init__(self, connection: LanceDbConnection):
self._inner = connection
async def __repr__(self) -> str:
pass
@override
async def table_names(
self,
*,
page_token=None,
limit=None,
) -> Iterable[str]:
# TODO: hook in page_token and limit
return await self._inner.table_names()
@override
async def create_table(
self,
name: str,
data: Optional[DATA] = None,
schema: Optional[Union[pa.Schema, LanceModel]] = None,
mode: str = "create",
exist_ok: bool = False,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
if mode.lower() not in ["create", "overwrite"]:
raise ValueError("mode must be either 'create' or 'overwrite'")
if inspect.isclass(schema) and issubclass(schema, LanceModel): if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema # convert LanceModel to pyarrow schema
# note that it's possible this contains # note that it's possible this contains
@@ -681,14 +681,6 @@ class AsyncConnection(object):
registry = EmbeddingFunctionRegistry.get_instance() registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions) metadata = registry.get_table_metadata(embedding_functions)
# Defining defaults here and not in function prototype. In the future
# these defaults will move into rust so better to keep them as None.
if on_bad_vectors is None:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
if data is not None: if data is not None:
data = _sanitize_data( data = _sanitize_data(
data, data,
@@ -716,10 +708,6 @@ class AsyncConnection(object):
schema = schema.with_metadata(metadata) schema = schema.with_metadata(metadata)
validate_schema(schema) validate_schema(schema)
if exist_ok is None:
exist_ok = False
if mode is None:
mode = "create"
if mode == "create" and exist_ok: if mode == "create" and exist_ok:
mode = "exist_ok" mode = "exist_ok"
@@ -734,37 +722,16 @@ class AsyncConnection(object):
) )
register_event("create_table") register_event("create_table")
return AsyncTable(new_table) return AsyncLanceTable(new_table)
async def open_table(self, name: str) -> Table: @override
"""Open a Lance Table in the database. async def open_table(self, name: str) -> LanceTable:
Parameters
----------
name: str
The name of the table.
Returns
-------
A LanceTable object representing the table.
"""
table = await self._inner.open_table(name)
register_event("open_table")
return AsyncTable(table)
async def drop_table(self, name: str):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
"""
raise NotImplementedError raise NotImplementedError
@override
async def drop_table(self, name: str, ignore_missing: bool = False):
raise NotImplementedError
@override
async def drop_database(self): async def drop_database(self):
"""
Drop database
This is the same thing as dropping all the tables
"""
raise NotImplementedError raise NotImplementedError

View File

@@ -31,7 +31,7 @@ class ImageBindEmbeddings(EmbeddingFunction):
six different modalities: images, text, audio, depth, thermal, and IMU data six different modalities: images, text, audio, depth, thermal, and IMU data
to download package, run : to download package, run :
`pip install imagebind-packaged==0.1.2` `pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
""" """
name: str = "imagebind_huge" name: str = "imagebind_huge"

View File

@@ -10,15 +10,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, List, Optional, Union from typing import List, Optional, Union
import numpy as np
from ..util import attempt_import_or_raise from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import api_key_not_found_help
if TYPE_CHECKING:
import numpy as np
@register("openai") @register("openai")
@@ -27,46 +28,14 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
An embedding function that uses the OpenAI API An embedding function that uses the OpenAI API
https://platform.openai.com/docs/guides/embeddings https://platform.openai.com/docs/guides/embeddings
This can also be used for open source models that
are compatible with the OpenAI API.
Notes
-----
If you're running an Ollama server locally,
you can just override the `base_url` parameter
and provide the Ollama embedding model you want
to use (https://ollama.com/library):
```python
from lancedb.embeddings import get_registry
openai = get_registry().get("openai")
embedding_function = openai.create(
name="<ollama-embedding-model-name>",
base_url="http://localhost:11434",
)
```
""" """
name: str = "text-embedding-ada-002" name: str = "text-embedding-ada-002"
dim: Optional[int] = None dim: Optional[int] = None
base_url: Optional[str] = None
default_headers: Optional[dict] = None
organization: Optional[str] = None
api_key: Optional[str] = None
def ndims(self): def ndims(self):
return self._ndims return self._ndims
@staticmethod
def model_names():
return [
"text-embedding-ada-002",
"text-embedding-3-large",
"text-embedding-3-small",
]
@cached_property @cached_property
def _ndims(self): def _ndims(self):
if self.name == "text-embedding-ada-002": if self.name == "text-embedding-ada-002":
@@ -79,8 +48,8 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
raise ValueError(f"Unknown model name {self.name}") raise ValueError(f"Unknown model name {self.name}")
def generate_embeddings( def generate_embeddings(
self, texts: Union[List[str], "np.ndarray"] self, texts: Union[List[str], np.ndarray]
) -> List["np.array"]: ) -> List[np.array]:
""" """
Get the embeddings for the given texts Get the embeddings for the given texts
@@ -93,25 +62,15 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
if self.name == "text-embedding-ada-002": if self.name == "text-embedding-ada-002":
rs = self._openai_client.embeddings.create(input=texts, model=self.name) rs = self._openai_client.embeddings.create(input=texts, model=self.name)
else: else:
kwargs = { rs = self._openai_client.embeddings.create(
"input": texts, input=texts, model=self.name, dimensions=self.ndims()
"model": self.name, )
}
if self.dim:
kwargs["dimensions"] = self.dim
rs = self._openai_client.embeddings.create(**kwargs)
return [v.embedding for v in rs.data] return [v.embedding for v in rs.data]
@cached_property @cached_property
def _openai_client(self): def _openai_client(self):
openai = attempt_import_or_raise("openai") openai = attempt_import_or_raise("openai")
kwargs = {}
if self.base_url: if not os.environ.get("OPENAI_API_KEY"):
kwargs["base_url"] = self.base_url api_key_not_found_help("openai")
if self.default_headers: return openai.OpenAI()
kwargs["default_headers"] = self.default_headers
if self.organization:
kwargs["organization"] = self.organization
if self.api_key:
kwargs["api_key"] = self.api_key
return openai.OpenAI(**kwargs)

View File

@@ -22,15 +22,13 @@ try:
import tantivy import tantivy
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Please install tantivy-py `pip install tantivy` to use the full text search feature." # noqa: E501 "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501
) )
from .table import LanceTable from .table import LanceTable
def create_index( def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
index_path: str, text_fields: List[str], ordering_fields: List[str] = None
) -> tantivy.Index:
""" """
Create a new Index (not populated) Create a new Index (not populated)
@@ -40,16 +38,12 @@ def create_index(
Path to the index directory Path to the index directory
text_fields : List[str] text_fields : List[str]
List of text fields to index List of text fields to index
ordering_fields: List[str]
List of unsigned type fields to order by at search time
Returns Returns
------- -------
index : tantivy.Index index : tantivy.Index
The index object (not yet populated) The index object (not yet populated)
""" """
if ordering_fields is None:
ordering_fields = []
# Declaring our schema. # Declaring our schema.
schema_builder = tantivy.SchemaBuilder() schema_builder = tantivy.SchemaBuilder()
# special field that we'll populate with row_id # special field that we'll populate with row_id
@@ -57,9 +51,6 @@ def create_index(
# data fields # data fields
for name in text_fields: for name in text_fields:
schema_builder.add_text_field(name, stored=True) schema_builder.add_text_field(name, stored=True)
if ordering_fields:
for name in ordering_fields:
schema_builder.add_unsigned_field(name, fast=True)
schema = schema_builder.build() schema = schema_builder.build()
os.makedirs(index_path, exist_ok=True) os.makedirs(index_path, exist_ok=True)
index = tantivy.Index(schema, path=index_path) index = tantivy.Index(schema, path=index_path)
@@ -71,7 +62,6 @@ def populate_index(
table: LanceTable, table: LanceTable,
fields: List[str], fields: List[str],
writer_heap_size: int = 1024 * 1024 * 1024, writer_heap_size: int = 1024 * 1024 * 1024,
ordering_fields: List[str] = None,
) -> int: ) -> int:
""" """
Populate an index with data from a LanceTable Populate an index with data from a LanceTable
@@ -92,11 +82,8 @@ def populate_index(
int int
The number of rows indexed The number of rows indexed
""" """
if ordering_fields is None:
ordering_fields = []
# first check the fields exist and are string or large string type # first check the fields exist and are string or large string type
nested = [] nested = []
for name in fields: for name in fields:
try: try:
f = table.schema.field(name) # raises KeyError if not found f = table.schema.field(name) # raises KeyError if not found
@@ -117,7 +104,7 @@ def populate_index(
if len(nested) > 0: if len(nested) > 0:
max_nested_level = max([len(name.split(".")) for name in nested]) max_nested_level = max([len(name.split(".")) for name in nested])
for b in dataset.to_batches(columns=fields + ordering_fields): for b in dataset.to_batches(columns=fields):
if max_nested_level > 0: if max_nested_level > 0:
b = pa.Table.from_batches([b]) b = pa.Table.from_batches([b])
for _ in range(max_nested_level - 1): for _ in range(max_nested_level - 1):
@@ -128,10 +115,6 @@ def populate_index(
value = b[name][i].as_py() value = b[name][i].as_py()
if value is not None: if value is not None:
doc.add_text(name, value) doc.add_text(name, value)
for name in ordering_fields:
value = b[name][i].as_py()
if value is not None:
doc.add_unsigned(name, value)
if not doc.is_empty: if not doc.is_empty:
doc.add_integer("doc_id", row_id) doc.add_integer("doc_id", row_id)
writer.add_document(doc) writer.add_document(doc)
@@ -166,7 +149,7 @@ def resolve_path(schema, field_name: str) -> pa.Field:
def search_index( def search_index(
index: tantivy.Index, query: str, limit: int = 10, ordering_field=None index: tantivy.Index, query: str, limit: int = 10
) -> Tuple[Tuple[int], Tuple[float]]: ) -> Tuple[Tuple[int], Tuple[float]]:
""" """
Search an index for a query Search an index for a query
@@ -189,9 +172,6 @@ def search_index(
searcher = index.searcher() searcher = index.searcher()
query = index.parse_query(query) query = index.parse_query(query)
# get top results # get top results
if ordering_field:
results = searcher.search(query, limit, order_by_field=ordering_field)
else:
results = searcher.search(query, limit) results = searcher.search(query, limit)
if results.count == 0: if results.count == 0:
return tuple(), tuple() return tuple(), tuple()

View File

@@ -1,163 +0,0 @@
from typing import Optional
from ._lancedb import (
Index as LanceDbIndex,
)
from ._lancedb import (
IndexConfig,
)
class BTree(object):
"""Describes a btree index configuration
A btree index is an index on scalar columns. The index stores a copy of the
column in sorted order. A header entry is created for each block of rows
(currently the block size is fixed at 4096). These header entries are stored
in a separate cacheable structure (a btree). To search for data the header is
used to determine which blocks need to be read from disk.
For example, a btree index in a table with 1Bi rows requires
sizeof(Scalar) * 256Ki bytes of memory and will generally need to read
sizeof(Scalar) * 4096 bytes to find the correct row ids.
This index is good for scalar columns with mostly distinct values and does best
when the query is highly selective.
The btree index does not currently have any parameters though parameters such as
the block size may be added in the future.
"""
def __init__(self):
self._inner = LanceDbIndex.btree()
class IvfPq(object):
"""Describes an IVF PQ Index
This index stores a compressed (quantized) copy of every vector. These vectors
are grouped into partitions of similar vectors. Each partition keeps track of
a centroid which is the average value of all vectors in the group.
During a query the centroids are compared with the query vector to find the
closest partitions. The compressed vectors in these partitions are then
searched to find the closest vectors.
The compression scheme is called product quantization. Each vector is divide
into subvectors and then each subvector is quantized into a small number of
bits. the parameters `num_bits` and `num_subvectors` control this process,
providing a tradeoff between index size (and thus search speed) and index
accuracy.
The partitioning process is called IVF and the `num_partitions` parameter
controls how many groups to create.
Note that training an IVF PQ index on a large dataset is a slow operation and
currently is also a memory intensive operation.
"""
def __init__(
self,
*,
distance_type: Optional[str] = None,
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
max_iterations: Optional[int] = None,
sample_rate: Optional[int] = None,
):
"""
Create an IVF PQ index config
Parameters
----------
distance_type: str, default "L2"
The distance metric used to train the index
This is used when training the index to calculate the IVF partitions
(vectors are grouped in partitions with similar vectors according to this
distance type) and to calculate a subvector's code during quantization.
The distance type used to train an index MUST match the distance type used
to search the index. Failure to do so will yield inaccurate results.
The following distance types are available:
"l2" - Euclidean distance. This is a very common distance metric that
accounts for both magnitude and direction when determining the distance
between vectors. L2 distance has a range of [0, ∞).
"cosine" - Cosine distance. Cosine distance is a distance metric
calculated from the cosine similarity between two vectors. Cosine
similarity is a measure of similarity between two non-zero vectors of an
inner product space. It is defined to equal the cosine of the angle
between them. Unlike L2, the cosine distance is not affected by the
magnitude of the vectors. Cosine distance has a range of [0, 2].
Note: the cosine distance is undefined when one (or both) of the vectors
are all zeros (there is no direction). These vectors are invalid and may
never be returned from a vector search.
"dot" - Dot product. Dot distance is the dot product of two vectors. Dot
distance has a range of (-∞, ∞). If the vectors are normalized (i.e. their
L2 norm is 1), then dot distance is equivalent to the cosine distance.
num_partitions: int, default sqrt(num_rows)
The number of IVF partitions to create.
This value should generally scale with the number of rows in the dataset.
By default the number of partitions is the square root of the number of
rows.
If this value is too large then the first part of the search (picking the
right partition) will be slow. If this value is too small then the second
part of the search (searching within a partition) will be slow.
num_sub_vectors: int, default is vector dimension / 16
Number of sub-vectors of PQ.
This value controls how much the vector is compressed during the
quantization step. The more sub vectors there are the less the vector is
compressed. The default is the dimension of the vector divided by 16. If
the dimension is not evenly divisible by 16 we use the dimension divded by
8.
The above two cases are highly preferred. Having 8 or 16 values per
subvector allows us to use efficient SIMD instructions.
If the dimension is not visible by 8 then we use 1 subvector. This is not
ideal and will likely result in poor performance.
max_iterations: int, default 50
Max iteration to train kmeans.
When training an IVF PQ index we use kmeans to calculate the partitions.
This parameter controls how many iterations of kmeans to run.
Increasing this might improve the quality of the index but in most cases
these extra iterations have diminishing returns.
The default value is 50.
sample_rate: int, default 256
The rate used to calculate the number of training vectors for kmeans.
When an IVF PQ index is trained, we need to calculate partitions. These
are groups of vectors that are similar to each other. To do this we use an
algorithm called kmeans.
Running kmeans on a large dataset can be slow. To speed this up we run
kmeans on a random sample of the data. This parameter controls the size of
the sample. The total number of vectors used to train the index is
`sample_rate * num_partitions`.
Increasing this value might improve the quality of the index but in most
cases the default should be sufficient.
The default value is 256.
"""
self._inner = LanceDbIndex.ivf_pq(
distance_type=distance_type,
num_partitions=num_partitions,
num_sub_vectors=num_sub_vectors,
max_iterations=max_iterations,
sample_rate=sample_rate,
)
__all__ = ["BTree", "IvfPq", "IndexConfig"]

View File

@@ -16,16 +16,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from typing import ( from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
TYPE_CHECKING,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
Union,
)
import deprecation import deprecation
import numpy as np import numpy as np
@@ -33,7 +24,6 @@ import pyarrow as pa
import pydantic import pydantic
from . import __version__ from . import __version__
from .arrow import AsyncRecordBatchReader
from .common import VEC from .common import VEC
from .rerankers.base import Reranker from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker from .rerankers.linear_combination import LinearCombinationReranker
@@ -43,8 +33,6 @@ if TYPE_CHECKING:
import PIL import PIL
import polars as pl import polars as pl
from ._lancedb import Query as LanceQuery
from ._lancedb import VectorQuery as LanceVectorQuery
from .pydantic import LanceModel from .pydantic import LanceModel
from .table import Table from .table import Table
@@ -118,8 +106,8 @@ class Query(pydantic.BaseModel):
class LanceQueryBuilder(ABC): class LanceQueryBuilder(ABC):
"""An abstract query builder. Subclasses are defined for vector search, """Build LanceDB query based on specific query type:
full text search, hybrid, and plain SQL filtering. vector or full text search.
""" """
@classmethod @classmethod
@@ -129,48 +117,39 @@ class LanceQueryBuilder(ABC):
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str, query_type: str,
vector_column_name: str, vector_column_name: str,
ordering_field_name: str = None, vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
""" if query is None and vector is None and text is None:
Create a query builder based on the given query and query type.
Parameters
----------
table: Table
The table to query.
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]]
The query to use. If None, an empty query builder is returned
which performs simple SQL filtering.
query_type: str
The type of query to perform. One of "vector", "fts", "hybrid", or "auto".
If "auto", the query type is inferred based on the query.
vector_column_name: str
The name of the vector column to use for vector search.
"""
if query is None:
return LanceEmptyQueryBuilder(table) return LanceEmptyQueryBuilder(table)
if query_type == "hybrid": if query_type == "hybrid":
# hybrid fts and vector query # hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name) return LanceHybridQueryBuilder(
table, query, vector_column_name, vector, text
)
# remember the string query for reranking purpose # Resolve hybrid query with explicit vector and text params here to avoid
str_query = query if isinstance(query, str) else None # adding them as params in the BaseQueryBuilder class
if vector is not None or text is not None:
if query_type not in ["hybrid", "auto"]:
raise ValueError(
"If `vector` and `text` are provided, then `query_type`\
must be 'hybrid' or 'auto'"
)
return LanceHybridQueryBuilder(
table, query, vector_column_name, vector, text
)
# convert "auto" query_type to "vector", "fts" # convert "auto" query_type to "vector" or "fts"
# or "hybrid" and convert the query to vector if needed # and convert the query to vector if needed
query, query_type = cls._resolve_query( query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name table, query, query_type, vector_column_name
) )
if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name)
if isinstance(query, str): if isinstance(query, str):
# fts # fts
return LanceFtsQueryBuilder( return LanceFtsQueryBuilder(table, query)
table, query, ordering_field_name=ordering_field_name
)
if isinstance(query, list): if isinstance(query, list):
query = np.array(query, dtype=np.float32) query = np.array(query, dtype=np.float32)
@@ -179,7 +158,7 @@ class LanceQueryBuilder(ABC):
else: else:
raise TypeError(f"Unsupported query type: {type(query)}") raise TypeError(f"Unsupported query type: {type(query)}")
return LanceVectorQueryBuilder(table, query, vector_column_name, str_query) return LanceVectorQueryBuilder(table, query, vector_column_name)
@classmethod @classmethod
def _resolve_query(cls, table, query, query_type, vector_column_name): def _resolve_query(cls, table, query, query_type, vector_column_name):
@@ -195,8 +174,6 @@ class LanceQueryBuilder(ABC):
elif query_type == "auto": elif query_type == "auto":
if isinstance(query, (list, np.ndarray)): if isinstance(query, (list, np.ndarray)):
return query, "vector" return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else: else:
conf = table.embedding_functions.get(vector_column_name) conf = table.embedding_functions.get(vector_column_name)
if conf is not None: if conf is not None:
@@ -443,7 +420,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
table: "Table", table: "Table",
query: Union[np.ndarray, list, "PIL.Image.Image"], query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str, vector_column: str,
str_query: Optional[str] = None,
): ):
super().__init__(table) super().__init__(table)
self._query = query self._query = query
@@ -452,8 +428,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._refine_factor = None self._refine_factor = None
self._vector_column = vector_column self._vector_column = vector_column
self._prefilter = False self._prefilter = False
self._reranker = None
self._str_query = str_query
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use. """Set the distance metric to use.
@@ -524,21 +498,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vectors. vector and the returned vectors.
""" """
return self.to_batches().read_all()
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
"""
Execute the query and return the result as a RecordBatchReader object.
Parameters
----------
batch_size: int
The maximum number of selected records in a RecordBatch object.
Returns
-------
pa.RecordBatchReader
"""
vector = self._query if isinstance(self._query, list) else self._query.tolist() vector = self._query if isinstance(self._query, list) else self._query.tolist()
if isinstance(vector[0], np.ndarray): if isinstance(vector[0], np.ndarray):
vector = [v.tolist() for v in vector] vector = [v.tolist() for v in vector]
@@ -554,16 +513,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
vector_column=self._vector_column, vector_column=self._vector_column,
with_row_id=self._with_row_id, with_row_id=self._with_row_id,
) )
result_set = self._table._execute_query(query, batch_size) return self._table._execute_query(query)
if self._reranker is not None:
rs_table = result_set.read_all()
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
# convert result_set back to RecordBatchReader
result_set = pa.RecordBatchReader.from_batches(
result_set.schema, result_set.to_batches()
)
return result_set
def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder: def where(self, where: str, prefilter: bool = False) -> LanceVectorQueryBuilder:
"""Set the where clause. """Set the where clause.
@@ -589,52 +539,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._prefilter = prefilter self._prefilter = prefilter
return self return self
def rerank(
self, reranker: Reranker, query_string: Optional[str] = None
) -> LanceVectorQueryBuilder:
"""Rerank the results using the specified reranker.
Parameters
----------
reranker: Reranker
The reranker to use.
query_string: Optional[str]
The query to use for reranking. This needs to be specified explicitly here
as the query used for vector search may already be vectorized and the
reranker requires a string query.
This is only required if the query used for vector search is not a string.
Note: This doesn't yet support the case where the query is multimodal or a
list of vectors.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._reranker = reranker
if self._str_query is None and query_string is None:
raise ValueError(
"""
The query used for vector search is not a string.
In this case, the reranker query needs to be specified explicitly.
"""
)
if query_string is not None and not isinstance(query_string, str):
raise ValueError("Reranking currently only supports string queries")
self._str_query = query_string if query_string is not None else self._str_query
return self
class LanceFtsQueryBuilder(LanceQueryBuilder): class LanceFtsQueryBuilder(LanceQueryBuilder):
"""A builder for full text search for LanceDB.""" """A builder for full text search for LanceDB."""
def __init__(self, table: "Table", query: str, ordering_field_name: str = None): def __init__(self, table: "Table", query: str):
super().__init__(table) super().__init__(table)
self._query = query self._query = query
self._phrase_query = False self._phrase_query = False
self.ordering_field_name = ordering_field_name
self._reranker = None
def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder: def phrase_query(self, phrase_query: bool = True) -> LanceFtsQueryBuilder:
"""Set whether to use phrase query. """Set whether to use phrase query.
@@ -658,7 +570,7 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
import tantivy import tantivy
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Please install tantivy-py `pip install tantivy` to use the full text search feature." # noqa: E501 "Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature." # noqa: E501
) )
from .fts import search_index from .fts import search_index
@@ -679,35 +591,26 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
if self._phrase_query: if self._phrase_query:
query = query.replace('"', "'") query = query.replace('"', "'")
query = f'"{query}"' query = f'"{query}"'
row_ids, scores = search_index( row_ids, scores = search_index(index, query, self._limit)
index, query, self._limit, ordering_field=self.ordering_field_name
)
if len(row_ids) == 0: if len(row_ids) == 0:
empty_schema = pa.schema([pa.field("score", pa.float32())]) empty_schema = pa.schema([pa.field("score", pa.float32())])
return pa.Table.from_pylist([], schema=empty_schema) return pa.Table.from_pylist([], schema=empty_schema)
scores = pa.array(scores) scores = pa.array(scores)
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns) output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores) output_tbl = output_tbl.append_column("score", scores)
# this needs to match vector search results which are uint64
row_ids = pa.array(row_ids, type=pa.uint64())
if self._where is not None: if self._where is not None:
tmp_name = "__lancedb__duckdb__indexer__"
output_tbl = output_tbl.append_column(
tmp_name, pa.array(range(len(output_tbl)))
)
try: try:
# TODO would be great to have Substrait generate pyarrow compute # TODO would be great to have Substrait generate pyarrow compute
# expressions or conversely have pyarrow support SQL expressions # expressions or conversely have pyarrow support SQL expressions
# using Substrait # using Substrait
import duckdb import duckdb
indexer = duckdb.sql( output_tbl = (
f"SELECT {tmp_name} FROM output_tbl WHERE {self._where}" duckdb.sql("SELECT * FROM output_tbl")
).to_arrow_table()[tmp_name] .filter(self._where)
output_tbl = output_tbl.take(indexer).drop([tmp_name]) .to_arrow_table()
row_ids = row_ids.take(indexer) )
except ImportError: except ImportError:
import tempfile import tempfile
@@ -717,33 +620,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
ds = lance.write_dataset(output_tbl, tmp) ds = lance.write_dataset(output_tbl, tmp)
output_tbl = ds.to_table(filter=self._where) output_tbl = ds.to_table(filter=self._where)
indexer = output_tbl[tmp_name]
row_ids = row_ids.take(indexer)
output_tbl = output_tbl.drop([tmp_name])
if self._with_row_id: if self._with_row_id:
# Need to set this to uint explicitly as vector results are in uint64
row_ids = pa.array(row_ids, type=pa.uint64())
output_tbl = output_tbl.append_column("_rowid", row_ids) output_tbl = output_tbl.append_column("_rowid", row_ids)
if self._reranker is not None:
output_tbl = self._reranker.rerank_fts(self._query, output_tbl)
return output_tbl return output_tbl
def rerank(self, reranker: Reranker) -> LanceFtsQueryBuilder:
"""Rerank the results using the specified reranker.
Parameters
----------
reranker: Reranker
The reranker to use.
Returns
-------
LanceFtsQueryBuilder
The LanceQueryBuilder object.
"""
self._reranker = reranker
return self
class LanceEmptyQueryBuilder(LanceQueryBuilder): class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
@@ -756,22 +639,20 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
class LanceHybridQueryBuilder(LanceQueryBuilder): class LanceHybridQueryBuilder(LanceQueryBuilder):
""" def __init__(
A query builder that performs hybrid vector and full text search. self,
Results are combined and reranked based on the specified reranker. table: "Table",
By default, the results are reranked using the LinearCombinationReranker. query: str,
vector_column: str,
To make the vector and fts results comparable, the scores are normalized. vector: Optional[VEC] = None,
Instead of normalizing scores, the `normalize` parameter can be set to "rank" text: Optional[str] = None,
in the `rerank` method to convert the scores to ranks and then normalize them. ):
"""
def __init__(self, table: "Table", query: str, vector_column: str):
super().__init__(table) super().__init__(table)
self._validate_fts_index() self._validate_fts_index()
vector_query, fts_query = self._validate_query(query) vector_query, fts_query = self._validate_query(
query, vector_column, vector, text
)
self._fts_query = LanceFtsQueryBuilder(table, fts_query) self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column) self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._norm = "score" self._norm = "score"
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0) self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
@@ -782,23 +663,31 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
"Please create a full-text search index " "to perform hybrid search." "Please create a full-text search index " "to perform hybrid search."
) )
def _validate_query(self, query): def _validate_query(self, query, vector_column, vector, text):
# Temp hack to support vectorized queries for hybrid search if query is not None:
if isinstance(query, str): if vector is not None or text is not None:
return query, query
elif isinstance(query, tuple):
if len(query) != 2:
raise ValueError( raise ValueError(
"The query must be a tuple of (vector_query, fts_query)." "Either pass `query` or `vector` and `text` separately, not both."
) )
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)): else:
if vector is None or text is None:
raise ValueError(
"Either pass `query` or `vector` and `text` separately, not both."
)
if vector is not None and text is not None:
if not isinstance(vector, (list, np.ndarray, pa.Array, pa.ChunkedArray)):
raise ValueError(f"The vector query must be one of {VEC}.") raise ValueError(f"The vector query must be one of {VEC}.")
if not isinstance(query[1], str): if not isinstance(text, str):
raise ValueError("The fts query must be a string.") raise ValueError("The fts query must be a string.")
return query[0], query[1] return vector, text
if isinstance(query, str):
vector = self._query_to_vector(self._table, query, vector_column)
return vector, query
else: else:
raise ValueError( raise ValueError(
"The query must be either a string or a tuple of (vector, string)." f"For hybrid search `query` must be a string or `vector` and `text` \
must be provided explicitly of types {VEC} and str respectively."
) )
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
@@ -1025,334 +914,3 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
""" """
self._vector_query.refine_factor(refine_factor) self._vector_query.refine_factor(refine_factor)
return self return self
class AsyncQueryBase(object):
def __init__(self, inner: Union[LanceQuery | LanceVectorQuery]):
"""
Construct an AsyncQueryBase
This method is not intended to be called directly. Instead, use the
[Table.query][] method to create a query.
"""
self._inner = inner
def where(self, predicate: str) -> AsyncQuery:
"""
Only return rows matching the given predicate
The predicate should be supplied as an SQL query string. For example:
>>> predicate = "x > 10"
>>> predicate = "y > 0 AND y < 100"
>>> predicate = "x > 5 OR y = 'test'"
Filtering performance can often be improved by creating a scalar index
on the filter column(s).
"""
self._inner.where(predicate)
return self
def select(self, columns: Union[List[str], dict[str, str]]) -> AsyncQuery:
"""
Return only the specified columns.
By default a query will return all columns from the table. However, this can
have a very significant impact on latency. LanceDb stores data in a columnar
fashion. This
means we can finely tune our I/O to select exactly the columns we need.
As a best practice you should always limit queries to the columns that you need.
If you pass in a list of column names then only those columns will be
returned.
You can also use this method to create new "dynamic" columns based on your
existing columns. For example, you may not care about "a" or "b" but instead
simply want "a + b". This is often seen in the SELECT clause of an SQL query
(e.g. `SELECT a+b FROM my_table`).
To create dynamic columns you can pass in a dict[str, str]. A column will be
returned for each entry in the map. The key provides the name of the column.
The value is an SQL string used to specify how the column is calculated.
For example, an SQL query might state `SELECT a + b AS combined, c`. The
equivalent input to this method would be `{"combined": "a + b", "c": "c"}`.
Columns will always be returned in the order given, even if that order is
different than the order used when adding the data.
"""
if isinstance(columns, dict):
column_tuples = list(columns.items())
else:
try:
column_tuples = [(c, c) for c in columns]
except TypeError:
raise TypeError("columns must be a list of column names or a dict")
self._inner.select(column_tuples)
return self
def limit(self, limit: int) -> AsyncQuery:
"""
Set the maximum number of results to return.
By default, a plain search has no limit. If this method is not
called then every valid row from the table will be returned.
"""
self._inner.limit(limit)
return self
async def to_batches(self) -> AsyncRecordBatchReader:
"""
Execute the query and return the results as an Apache Arrow RecordBatchReader.
"""
return AsyncRecordBatchReader(await self._inner.execute())
async def to_arrow(self) -> pa.Table:
"""
Execute the query and collect the results into an Apache Arrow Table.
This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use [to_batches][]
"""
batch_iter = await self.to_batches()
return pa.Table.from_batches(
await batch_iter.read_all(), schema=batch_iter.schema
)
async def to_pandas(self) -> "pd.DataFrame":
"""
Execute the query and collect the results into a pandas DataFrame.
This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use [to_batches][]
and convert each batch to pandas separately.
Example
-------
>>> import asyncio
>>> from lancedb import connect_async
>>> async def doctest_example():
... conn = await connect_async("./.lancedb")
... table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}])
... async for batch in await table.query().to_batches():
... batch_df = batch.to_pandas()
>>> asyncio.run(doctest_example())
"""
return (await self.to_arrow()).to_pandas()
class AsyncQuery(AsyncQueryBase):
def __init__(self, inner: LanceQuery):
"""
Construct an AsyncQuery
This method is not intended to be called directly. Instead, use the
[Table.query][] method to create a query.
"""
super().__init__(inner)
self._inner = inner
@classmethod
def _query_vec_to_array(self, vec: Union[VEC, Tuple]):
if isinstance(vec, list):
return pa.array(vec)
if isinstance(vec, np.ndarray):
return pa.array(vec)
if isinstance(vec, pa.Array):
return vec
if isinstance(vec, pa.ChunkedArray):
return vec.combine_chunks()
if isinstance(vec, tuple):
return pa.array(vec)
# We've checked everything we formally support in our typings
# but, as a fallback, let pyarrow try and convert it anyway.
# This can allow for some more exotic things like iterables
return pa.array(vec)
def nearest_to(
self, query_vector: Optional[Union[VEC, Tuple]] = None
) -> AsyncVectorQuery:
"""
Find the nearest vectors to the given query vector.
This converts the query from a plain query to a vector query.
This method will attempt to convert the input to the query vector
expected by the embedding model. If the input cannot be converted
then an error will be thrown.
By default, there is no embedding model, and the input should be
something that can be converted to a pyarrow array of floats. This
includes lists, numpy arrays, and tuples.
If there is only one vector column (a column whose data type is a
fixed size list of floats) then the column does not need to be specified.
If there is more than one vector column you must use
[AsyncVectorQuery::column][] to specify which column you would like to
compare with.
If no index has been created on the vector column then a vector query
will perform a distance comparison between the query vector and every
vector in the database and then sort the results. This is sometimes
called a "flat search"
For small databases, with tens of thousands of vectors or less, this can
be reasonably fast. In larger databases you should create a vector index
on the column. If there is a vector index then an "approximate" nearest
neighbor search (frequently called an ANN search) will be performed. This
search is much faster, but the results will be approximate.
The query can be further parameterized using the returned builder. There
are various ANN search parameters that will let you fine tune your recall
accuracy vs search latency.
Vector searches always have a [limit][]. If `limit` has not been called then
a default `limit` of 10 will be used.
"""
return AsyncVectorQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
class AsyncVectorQuery(AsyncQueryBase):
def __init__(self, inner: LanceVectorQuery):
"""
Construct an AsyncVectorQuery
This method is not intended to be called directly. Instead, create
a query first with [Table.query][] and then use [AsyncQuery.nearest_to][]
to convert to a vector query.
"""
super().__init__(inner)
self._inner = inner
def column(self, column: str) -> AsyncVectorQuery:
"""
Set the vector column to query
This controls which column is compared to the query vector supplied in
the call to [Query.nearest_to][].
This parameter must be specified if the table has more than one column
whose data type is a fixed-size-list of floats.
"""
self._inner.column(column)
return self
def nprobes(self, nprobes: int) -> AsyncVectorQuery:
"""
Set the number of partitions to search (probe)
This argument is only used when the vector column has an IVF PQ index.
If there is no index then this value is ignored.
The IVF stage of IVF PQ divides the input into partitions (clusters) of
related values.
The partition whose centroids are closest to the query vector will be
exhaustiely searched to find matches. This parameter controls how many
partitions should be searched.
Increasing this value will increase the recall of your query but will
also increase the latency of your query. The default value is 20. This
default is good for many cases but the best value to use will depend on
your data and the recall that you need to achieve.
For best results we recommend tuning this parameter with a benchmark against
your actual data to find the smallest possible value that will still give
you the desired recall.
"""
self._inner.nprobes(nprobes)
return self
def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
"""
A multiplier to control how many additional rows are taken during the refine
step
This argument is only used when the vector column has an IVF PQ index.
If there is no index then this value is ignored.
An IVF PQ index stores compressed (quantized) values. They query vector is
compared against these values and, since they are compressed, the comparison is
inaccurate.
This parameter can be used to refine the results. It can improve both improve
recall and correct the ordering of the nearest results.
To refine results LanceDb will first perform an ANN search to find the nearest
`limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and
`limit` is the default (10) then the first 30 results will be selected. LanceDb
then fetches the full, uncompressed, values for these 30 results. The results
are then reordered by the true distance and only the nearest 10 are kept.
Note: there is a difference between calling this method with a value of 1 and
never calling this method at all. Calling this method with any value will have
an impact on your search latency. When you call this method with a
`refine_factor` of 1 then LanceDb still needs to fetch the full, uncompressed,
values so that it can potentially reorder the results.
Note: if this method is NOT called then the distances returned in the _distance
column will be approximate distances based on the comparison of the quantized
query vector and the quantized result vectors. This can be considerably
different than the true distance between the query vector and the actual
uncompressed vector.
"""
self._inner.refine_factor(refine_factor)
return self
def distance_type(self, distance_type: str) -> AsyncVectorQuery:
"""
Set the distance metric to use
When performing a vector search we try and find the "nearest" vectors according
to some kind of distance metric. This parameter controls which distance metric
to use. See @see {@link IvfPqOptions.distanceType} for more details on the
different distance metrics available.
Note: if there is a vector index then the distance type used MUST match the
distance type used to train the vector index. If this is not done then the
results will be invalid.
By default "l2" is used.
"""
self._inner.distance_type(distance_type)
return self
def postfilter(self) -> AsyncVectorQuery:
"""
If this is called then filtering will happen after the vector search instead of
before.
By default filtering will be performed before the vector search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the vector search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
@see {@link VectorQuery#refineFactor}). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
def bypass_vector_index(self) -> AsyncVectorQuery:
"""
If this is called then any vector index is skipped
An exhaustive (flat) search will be performed. The query vector will
be compared to every vector in the table. At high scales this can be
expensive. However, this is often still useful. For example, skipping
the vector index can give you ground truth results which you can use to
calculate your recall to select an appropriate value for nprobes.
"""
self._inner.bypass_vector_index()
return self

View File

@@ -58,9 +58,6 @@ class RestfulLanceDBClient:
closed: bool = attrs.field(default=False, init=False) closed: bool = attrs.field(default=False, init=False)
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
read_timeout: float = attrs.field(default=300.0, kw_only=True)
@functools.cached_property @functools.cached_property
def session(self) -> requests.Session: def session(self) -> requests.Session:
sess = requests.Session() sess = requests.Session()
@@ -120,7 +117,7 @@ class RestfulLanceDBClient:
urljoin(self.url, uri), urljoin(self.url, uri),
params=params, params=params,
headers=self.headers, headers=self.headers,
timeout=(self.connection_timeout, self.read_timeout), timeout=(120.0, 300.0),
) as resp: ) as resp:
self._check_status(resp) self._check_status(resp)
return resp.json() return resp.json()
@@ -162,7 +159,7 @@ class RestfulLanceDBClient:
urljoin(self.url, uri), urljoin(self.url, uri),
headers=headers, headers=headers,
params=params, params=params,
timeout=(self.connection_timeout, self.read_timeout), timeout=(120.0, 300.0),
**req_kwargs, **req_kwargs,
) as resp: ) as resp:
self._check_status(resp) self._check_status(resp)

View File

@@ -41,8 +41,6 @@ class RemoteDBConnection(DBConnection):
region: str, region: str,
host_override: Optional[str] = None, host_override: Optional[str] = None,
request_thread_pool: Optional[ThreadPoolExecutor] = None, request_thread_pool: Optional[ThreadPoolExecutor] = None,
connection_timeout: float = 120.0,
read_timeout: float = 300.0,
): ):
"""Connect to a remote LanceDB database.""" """Connect to a remote LanceDB database."""
parsed = urlparse(db_url) parsed = urlparse(db_url)
@@ -51,12 +49,7 @@ class RemoteDBConnection(DBConnection):
self.db_name = parsed.netloc self.db_name = parsed.netloc
self.api_key = api_key self.api_key = api_key
self._client = RestfulLanceDBClient( self._client = RestfulLanceDBClient(
self.db_name, self.db_name, region, api_key, host_override
region,
api_key,
host_override,
connection_timeout=connection_timeout,
read_timeout=read_timeout,
) )
self._request_thread_pool = request_thread_pool self._request_thread_pool = request_thread_pool

View File

@@ -68,16 +68,10 @@ class RemoteTable(Table):
def list_indices(self): def list_indices(self):
"""List all the indices on the table""" """List all the indices on the table"""
print(self._name)
resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/") resp = self._conn._client.post(f"/v1/table/{self._name}/index/list/")
return resp return resp
def index_stats(self, index_uuid: str):
"""List all the indices on the table"""
resp = self._conn._client.post(
f"/v1/table/{self._name}/index/{index_uuid}/stats/"
)
return resp
def create_scalar_index( def create_scalar_index(
self, self,
column: str, column: str,
@@ -295,9 +289,7 @@ class RemoteTable(Table):
vector_column_name = inf_vector_column_query(self.schema) vector_column_name = inf_vector_column_query(self.schema)
return LanceVectorQueryBuilder(self, query, vector_column_name) return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query( def _execute_query(self, query: Query) -> pa.Table:
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
if ( if (
query.vector is not None query.vector is not None
and len(query.vector) > 0 and len(query.vector) > 0
@@ -323,12 +315,13 @@ class RemoteTable(Table):
q = query.copy() q = query.copy()
q.vector = v q.vector = v
results.append(submit(self._name, q)) results.append(submit(self._name, q))
return pa.concat_tables( return pa.concat_tables(
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)] [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
).to_reader() )
else: else:
result = self._conn._client.query(self._name, query) result = self._conn._client.query(self._name, query)
return result.to_arrow().to_reader() return result.to_arrow()
def _do_merge( def _do_merge(
self, self,

View File

@@ -24,59 +24,8 @@ class Reranker(ABC):
raise ValueError("score must be either 'relevance' or 'all'") raise ValueError("score must be either 'relevance' or 'all'")
self.score = return_score self.score = return_score
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
"""
Rerank function receives the result from the vector search.
This isn't mandatory to implement
Parameters
----------
query : str
The input query
vector_results : pa.Table
The results from the vector search
Returns
-------
pa.Table
The reranked results
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement rerank_vector"
)
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
"""
Rerank function receives the result from the FTS search.
This isn't mandatory to implement
Parameters
----------
query : str
The input query
fts_results : pa.Table
The results from the FTS search
Returns
-------
pa.Table
The reranked results
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement rerank_fts"
)
@abstractmethod @abstractmethod
def rerank_hybrid( def rerank_hybrid(
self,
query: str, query: str,
vector_results: pa.Table, vector_results: pa.Table,
fts_results: pa.Table, fts_results: pa.Table,
@@ -94,11 +43,6 @@ class Reranker(ABC):
The results from the vector search The results from the vector search
fts_results : pa.Table fts_results : pa.Table
The results from the FTS search The results from the FTS search
Returns
-------
pa.Table
The reranked results
""" """
pass pass

View File

@@ -49,8 +49,14 @@ class CohereReranker(Reranker):
) )
return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key) return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)
def _rerank(self, result_set: pa.Table, query: str): def rerank_hybrid(
docs = result_set[self.column].to_pylist() self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
results = self._client.rerank( results = self._client.rerank(
query=query, query=query,
documents=docs, documents=docs,
@@ -60,22 +66,12 @@ class CohereReranker(Reranker):
indices, scores = list( indices, scores = list(
zip(*[(result.index, result.relevance_score) for result in results]) zip(*[(result.index, result.relevance_score) for result in results])
) # tuples ) # tuples
result_set = result_set.take(list(indices)) combined_results = combined_results.take(list(indices))
# add the scores # add the scores
result_set = result_set.append_column( combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32()) "_relevance_score", pa.array(scores, type=pa.float32())
) )
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance": if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"]) combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all": elif self.score == "all":
@@ -83,25 +79,3 @@ class CohereReranker(Reranker):
"return_score='all' not implemented for cohere reranker" "return_score='all' not implemented for cohere reranker"
) )
return combined_results return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["score"])
return result_set

View File

@@ -33,8 +33,14 @@ class ColbertReranker(Reranker):
"torch" "torch"
) # import here for faster ops later ) # import here for faster ops later
def _rerank(self, result_set: pa.Table, query: str): def rerank_hybrid(
docs = result_set[self.column].to_pylist() self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
tokenizer, model = self._model tokenizer, model = self._model
@@ -53,25 +59,14 @@ class ColbertReranker(Reranker):
scores.append(score.item()) scores.append(score.item())
# replace the self.column column with the docs # replace the self.column column with the docs
result_set = result_set.drop(self.column) combined_results = combined_results.drop(self.column)
result_set = result_set.append_column( combined_results = combined_results.append_column(
self.column, pa.array(docs, type=pa.string()) self.column, pa.array(docs, type=pa.string())
) )
# add the scores # add the scores
result_set = result_set.append_column( combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32()) "_relevance_score", pa.array(scores, type=pa.float32())
) )
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance": if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"]) combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all": elif self.score == "all":
@@ -85,32 +80,6 @@ class ColbertReranker(Reranker):
return combined_results return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
result_set = result_set.sort_by([("_relevance_score", "descending")])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["score"])
result_set = result_set.sort_by([("_relevance_score", "descending")])
return result_set
@cached_property @cached_property
def _model(self): def _model(self):
transformers = attempt_import_or_raise("transformers") transformers = attempt_import_or_raise("transformers")

View File

@@ -46,16 +46,6 @@ class CrossEncoderReranker(Reranker):
return cross_encoder return cross_encoder
def _rerank(self, result_set: pa.Table, query: str):
passages = result_set[self.column].to_pylist()
cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp)
result_set = result_set.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())
)
return result_set
def rerank_hybrid( def rerank_hybrid(
self, self,
query: str, query: str,
@@ -63,7 +53,13 @@ class CrossEncoderReranker(Reranker):
fts_results: pa.Table, fts_results: pa.Table,
): ):
combined_results = self.merge_results(vector_results, fts_results) combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query) passages = combined_results[self.column].to_pylist()
cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp)
combined_results = combined_results.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())
)
# sort the results by _score # sort the results by _score
if self.score == "relevance": if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"]) combined_results = combined_results.drop_columns(["score", "_distance"])
@@ -76,27 +72,3 @@ class CrossEncoderReranker(Reranker):
) )
return combined_results return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
vector_results = self._rerank(vector_results, query)
if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results

View File

@@ -39,8 +39,14 @@ class OpenaiReranker(Reranker):
self.column = column self.column = column
self.api_key = api_key self.api_key = api_key
def _rerank(self, result_set: pa.Table, query: str): def rerank_hybrid(
docs = result_set[self.column].to_pylist() self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
response = self._client.chat.completions.create( response = self._client.chat.completions.create(
model=self.model_name, model=self.model_name,
response_format={"type": "json_object"}, response_format={"type": "json_object"},
@@ -64,25 +70,14 @@ class OpenaiReranker(Reranker):
zip(*[(result["content"], result["relevance_score"]) for result in results]) zip(*[(result["content"], result["relevance_score"]) for result in results])
) # tuples ) # tuples
# replace the self.column column with the docs # replace the self.column column with the docs
result_set = result_set.drop(self.column) combined_results = combined_results.drop(self.column)
result_set = result_set.append_column( combined_results = combined_results.append_column(
self.column, pa.array(docs, type=pa.string()) self.column, pa.array(docs, type=pa.string())
) )
# add the scores # add the scores
result_set = result_set.append_column( combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32()) "_relevance_score", pa.array(scores, type=pa.float32())
) )
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance": if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"]) combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all": elif self.score == "all":
@@ -96,24 +91,6 @@ class OpenaiReranker(Reranker):
return combined_results return combined_results
def rerank_vector(self, query: str, vector_results: pa.Table):
vector_results = self._rerank(vector_results, query)
if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results
def rerank_fts(self, query: str, fts_results: pa.Table):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results
@cached_property @cached_property
def _client(self): def _client(self):
openai = attempt_import_or_raise( openai = attempt_import_or_raise(

View File

@@ -19,17 +19,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from functools import cached_property from functools import cached_property
from typing import ( from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Union,
)
import lance import lance
import numpy as np import numpy as np
@@ -37,14 +27,14 @@ import pyarrow as pa
import pyarrow.compute as pc import pyarrow.compute as pc
import pyarrow.fs as pa_fs import pyarrow.fs as pa_fs
from lance import LanceDataset from lance import LanceDataset
from lance.dependencies import _check_for_hugging_face
from lance.vector import vec_to_table from lance.vector import vec_to_table
from overrides import override
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .merge import LanceMergeInsertBuilder from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict from .pydantic import LanceModel, model_to_dict
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query from .query import LanceQueryBuilder, Query
from .util import ( from .util import (
fs_from_uri, fs_from_uri,
inf_vector_column_query, inf_vector_column_query,
@@ -61,7 +51,6 @@ if TYPE_CHECKING:
from ._lancedb import Table as LanceDBTable from ._lancedb import Table as LanceDBTable
from .db import LanceDBConnection from .db import LanceDBConnection
from .index import BTree, IndexConfig, IvfPq
pd = safe_import_pandas() pd = safe_import_pandas()
@@ -75,27 +64,6 @@ def _sanitize_data(
on_bad_vectors: str, on_bad_vectors: str,
fill_value: Any, fill_value: Any,
): ):
if _check_for_hugging_face(data):
# Huggingface datasets
from lance.dependencies import datasets
if isinstance(data, datasets.dataset_dict.DatasetDict):
if schema is None:
schema = _schema_from_hf(data, schema)
data = _to_record_batch_generator(
_to_batches_with_split(data),
schema,
metadata,
on_bad_vectors,
fill_value,
)
elif isinstance(data, datasets.Dataset):
if schema is None:
schema = data.features.arrow_schema
data = _to_record_batch_generator(
data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value
)
if isinstance(data, list): if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels # convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel): if isinstance(data[0], LanceModel):
@@ -132,37 +100,6 @@ def _sanitize_data(
return data return data
def _schema_from_hf(data, schema):
"""
Extract pyarrow schema from HuggingFace DatasetDict
and validate that they're all the same schema between
splits
"""
for dataset in data.values():
if schema is None:
schema = dataset.features.arrow_schema
elif schema != dataset.features.arrow_schema:
msg = "All datasets in a HuggingFace DatasetDict must have the same schema"
raise TypeError(msg)
return schema
def _to_batches_with_split(data):
"""
Return a generator of RecordBatches from a HuggingFace DatasetDict
with an extra `split` column
"""
for key, dataset in data.items():
for batch in dataset.data.to_batches():
table = pa.Table.from_batches([batch])
if "split" not in table.column_names:
table = table.append_column(
"split", pa.array([key] * batch.num_rows, pa.string())
)
for b in table.to_batches():
yield b
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]): def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
""" """
Use the embedding function to automatically embed the source column and add the Use the embedding function to automatically embed the source column and add the
@@ -171,8 +108,7 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata) functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_column, conf in functions.items(): for vector_column, conf in functions.items():
func = conf.function func = conf.function
no_vector_column = vector_column not in data.column_names if vector_column not in data.column_names:
if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py():
col_data = func.compute_source_embeddings_with_retry( col_data = func.compute_source_embeddings_with_retry(
data[conf.source_column] data[conf.source_column]
) )
@@ -180,16 +116,9 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
dtype = schema.field(vector_column).type dtype = schema.field(vector_column).type
else: else:
dtype = pa.list_(pa.float32(), len(col_data[0])) dtype = pa.list_(pa.float32(), len(col_data[0]))
if no_vector_column:
data = data.append_column( data = data.append_column(
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype) pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
) )
else:
data = data.set_column(
data.column_names.index(vector_column),
pa.field(vector_column, type=dtype),
pa.array(col_data, type=dtype),
)
return data return data
@@ -197,13 +126,12 @@ def _to_record_batch_generator(
data: Iterable, schema, metadata, on_bad_vectors, fill_value data: Iterable, schema, metadata, on_bad_vectors, fill_value
): ):
for batch in data: for batch in data:
# always convert to table because we need to sanitize the data if not isinstance(batch, pa.RecordBatch):
# and do things like add the vector column etc table = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value)
if isinstance(batch, pa.RecordBatch): for batch in table.to_batches():
batch = pa.Table.from_batches([batch]) yield batch
batch = _sanitize_data(batch, schema, metadata, on_bad_vectors, fill_value) else:
for b in batch.to_batches(): yield batch
yield b
class Table(ABC): class Table(ABC):
@@ -490,6 +418,8 @@ class Table(ABC):
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
query_type: str = "auto", query_type: str = "auto",
vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search] of the given query vector. We currently support [vector search][search]
@@ -568,9 +498,7 @@ class Table(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def _execute_query( def _execute_query(self, query: Query) -> pa.Table:
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
pass pass
@abstractmethod @abstractmethod
@@ -1161,7 +1089,6 @@ class LanceTable(Table):
def create_fts_index( def create_fts_index(
self, self,
field_names: Union[str, List[str]], field_names: Union[str, List[str]],
ordering_field_names: Union[str, List[str]] = None,
*, *,
replace: bool = False, replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024, writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
@@ -1180,18 +1107,12 @@ class LanceTable(Table):
not yet an atomic operation; the index will be temporarily not yet an atomic operation; the index will be temporarily
unavailable while the new index is being created. unavailable while the new index is being created.
writer_heap_size: int, default 1GB writer_heap_size: int, default 1GB
ordering_field_names:
A list of unsigned type fields to index to optionally order
results on at search time
""" """
from .fts import create_index, populate_index from .fts import create_index, populate_index
if isinstance(field_names, str): if isinstance(field_names, str):
field_names = [field_names] field_names = [field_names]
if isinstance(ordering_field_names, str):
ordering_field_names = [ordering_field_names]
fs, path = fs_from_uri(self._get_fts_index_path()) fs, path = fs_from_uri(self._get_fts_index_path())
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists: if index_exists:
@@ -1199,18 +1120,8 @@ class LanceTable(Table):
raise ValueError("Index already exists. Use replace=True to overwrite.") raise ValueError("Index already exists. Use replace=True to overwrite.")
fs.delete_dir(path) fs.delete_dir(path)
index = create_index( index = create_index(self._get_fts_index_path(), field_names)
self._get_fts_index_path(), populate_index(index, self, field_names, writer_heap_size=writer_heap_size)
field_names,
ordering_fields=ordering_field_names,
)
populate_index(
index,
self,
field_names,
ordering_fields=ordering_field_names,
writer_heap_size=writer_heap_size,
)
register_event("create_fts_index") register_event("create_fts_index")
def _get_fts_index_path(self): def _get_fts_index_path(self):
@@ -1344,7 +1255,8 @@ class LanceTable(Table):
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
query_type: str = "auto", query_type: str = "auto",
ordering_field_name: Optional[str] = None, vector: Optional[VEC] = None,
text: Optional[str] = None,
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search] of the given query vector. We currently support [vector search][search]
@@ -1399,6 +1311,10 @@ class LanceTable(Table):
or raise an error if no corresponding embedding function is found. or raise an error if no corresponding embedding function is found.
If the `query` is a string, then the query type is "vector" if the If the `query` is a string, then the query type is "vector" if the
table has embedding functions, else the query type is "fts" table has embedding functions, else the query type is "fts"
vector: list/np.ndarray, default None
vector query for hybrid search
text: str, default None
text query for hybrid search
Returns Returns
------- -------
@@ -1408,7 +1324,8 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
if vector_column_name is None and query is not None: is_query_defined = query is not None or vector is not None or text is not None
if vector_column_name is None and is_query_defined:
vector_column_name = inf_vector_column_query(self.schema) vector_column_name = inf_vector_column_query(self.schema)
register_event("search_table") register_event("search_table")
return LanceQueryBuilder.create( return LanceQueryBuilder.create(
@@ -1416,7 +1333,8 @@ class LanceTable(Table):
query, query,
query_type, query_type,
vector_column_name=vector_column_name, vector_column_name=vector_column_name,
ordering_field_name=ordering_field_name, vector=vector,
text=text,
) )
@classmethod @classmethod
@@ -1598,11 +1516,10 @@ class LanceTable(Table):
self._dataset_mut.update(values_sql, where) self._dataset_mut.update(values_sql, where)
register_event("update") register_event("update")
def _execute_query( def _execute_query(self, query: Query) -> pa.Table:
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader:
ds = self.to_lance() ds = self.to_lance()
return ds.scanner(
return ds.to_table(
columns=query.columns, columns=query.columns,
filter=query.filter, filter=query.filter,
prefilter=query.prefilter, prefilter=query.prefilter,
@@ -1615,8 +1532,7 @@ class LanceTable(Table):
"refine_factor": query.refine_factor, "refine_factor": query.refine_factor,
}, },
with_row_id=query.with_row_id, with_row_id=query.with_row_id,
batch_size=batch_size, )
).to_reader()
def _do_merge( def _do_merge(
self, self,
@@ -1882,23 +1798,9 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
return data return data
class AsyncTable: class AsyncTable(ABC):
""" """
An AsyncTable is a collection of Records in a LanceDB Database. A Table is a collection of Records in a LanceDB Database.
An AsyncTable can be obtained from the
[AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and
[AsyncConnection.open_table][lancedb.AsyncConnection.open_table] methods.
An AsyncTable object is expected to be long lived and reused for multiple
operations. AsyncTable objects will cache a certain amount of index data in memory.
This cache will be freed when the Table is garbage collected. To eagerly free the
cache you can call the [close][AsyncTable.close] method. Once the AsyncTable is
closed, it cannot be used for any further operations.
An AsyncTable can also be used as a context manager, and will automatically close
when the context is exited. Closing a table is optional. If you do not close the
table, it will be closed when the AsyncTable object is garbage collected.
Examples Examples
-------- --------
@@ -1933,49 +1835,21 @@ class AsyncTable:
[Table.create_index][lancedb.table.Table.create_index]. [Table.create_index][lancedb.table.Table.create_index].
""" """
def __init__(self, table: LanceDBTable):
"""Create a new Table object.
You should not create Table objects directly.
Use [AsyncConnection.create_table][lancedb.AsyncConnection.create_table] and
[AsyncConnection.open_table][lancedb.AsyncConnection.open_table] to obtain
Table objects."""
self._inner = table
def __repr__(self):
return self._inner.__repr__()
def __enter__(self):
return self
def __exit__(self, *_):
self.close()
def is_open(self) -> bool:
"""Return True if the table is closed."""
return self._inner.is_open()
def close(self):
"""Close the table and free any resources associated with it.
It is safe to call this method multiple times.
Any attempt to use the table after it has been closed will raise an error."""
return self._inner.close()
@property @property
@abstractmethod
def name(self) -> str: def name(self) -> str:
"""The name of the table.""" """The name of the table."""
return self._inner.name() raise NotImplementedError
@abstractmethod
async def schema(self) -> pa.Schema: async def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
of this Table of this Table
""" """
return await self._inner.schema() raise NotImplementedError
@abstractmethod
async def count_rows(self, filter: Optional[str] = None) -> int: async def count_rows(self, filter: Optional[str] = None) -> int:
""" """
Count the number of rows in the table. Count the number of rows in the table.
@@ -1985,10 +1859,7 @@ class AsyncTable:
filter: str, optional filter: str, optional
A SQL where clause to filter the rows to count. A SQL where clause to filter the rows to count.
""" """
return await self._inner.count_rows(filter) raise NotImplementedError
def query(self) -> AsyncQuery:
return AsyncQuery(self._inner.query())
async def to_pandas(self) -> "pd.DataFrame": async def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame. """Return the table as a pandas DataFrame.
@@ -1997,8 +1868,9 @@ class AsyncTable:
------- -------
pd.DataFrame pd.DataFrame
""" """
return (await self.to_arrow()).to_pandas() return self.to_arrow().to_pandas()
@abstractmethod
async def to_arrow(self) -> pa.Table: async def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table. """Return the table as a pyarrow Table.
@@ -2006,59 +1878,124 @@ class AsyncTable:
------- -------
pa.Table pa.Table
""" """
return await self.query().to_arrow() raise NotImplementedError
async def create_index( async def create_index(
self, self,
column: str, metric="L2",
*, num_partitions=256,
replace: Optional[bool] = None, num_sub_vectors=96,
config: Optional[Union[IvfPq, BTree]] = None, vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None,
): ):
"""Create an index to speed up queries """Create an index on the table.
Indices can be created on vector columns or scalar columns.
Indices on vector columns will speed up vector searches.
Indices on scalar columns will speed up filtering (in both
vector and non-vector searches)
Parameters Parameters
---------- ----------
index: Index metric: str, default "L2"
The index to create. The distance metric to use when creating the index.
Valid values are "L2", "cosine", or "dot".
LanceDb supports multiple types of indices. See the static methods on L2 is euclidean distance.
the Index class for more details. num_partitions: int, default 256
column: str, default None The number of IVF partitions to use when creating the index.
The column to index. Default is 256.
num_sub_vectors: int, default 96
When building a scalar index this must be set. The number of PQ sub-vectors to use when creating the index.
Default is 96.
When building a vector index, this is optional. The default will look vector_column_name: str, default "vector"
for any columns of type fixed-size-list with floating point values. If The vector column name to create the index.
there is only one column of this type then it will be used. Otherwise
an error will be returned.
replace: bool, default True replace: bool, default True
Whether to replace the existing index - If True, replace the existing index if it exists.
If this is false, and another index already exists on the same columns - If False, raise an error if duplicate index exists.
and the same name, then an error will be returned. This is true even if accelerator: str, default None
that index is out of date. If set, use the given accelerator to create the index.
Only support "cuda" for now.
The default is True index_cache_size : int, optional
The size of the index cache in number of entries. Default value is 256.
""" """
index = None raise NotImplementedError
if config is not None:
index = config._inner
await self._inner.create_index(column, index=index, replace=replace)
@abstractmethod
async def create_scalar_index(
self,
column: str,
*,
replace: bool = True,
):
"""Create a scalar index on a column.
Scalar indices, like vector indices, can be used to speed up scans. A scalar
index can speed up scans that contain filter expressions on the indexed column.
For example, the following scan will be faster if the column ``my_col`` has
a scalar index:
.. code-block:: python
import lancedb
db = lancedb.connect("/data/lance")
img_table = db.open_table("images")
my_df = img_table.search().where("my_col = 7", prefilter=True).to_pandas()
Scalar indices can also speed up scans containing a vector search and a
prefilter:
.. code-block::python
import lancedb
db = lancedb.connect("/data/lance")
img_table = db.open_table("images")
img_table.search([1, 2, 3, 4], vector_column_name="vector")
.where("my_col != 7", prefilter=True)
.to_pandas()
Scalar indices can only speed up scans for basic filters using
equality, comparison, range (e.g. ``my_col BETWEEN 0 AND 100``), and set
membership (e.g. `my_col IN (0, 1, 2)`)
Scalar indices can be used if the filter contains multiple indexed columns and
the filter criteria are AND'd or OR'd together
(e.g. ``my_col < 0 AND other_col> 100``)
Scalar indices may be used if the filter contains non-indexed columns but,
depending on the structure of the filter, they may not be usable. For example,
if the column ``not_indexed`` does not have a scalar index then the filter
``my_col = 0 OR not_indexed = 1`` will not be able to use any scalar index on
``my_col``.
**Experimental API**
Parameters
----------
column : str
The column to be indexed. Must be a boolean, integer, float,
or string column.
replace : bool, default True
Replace the existing index if it exists.
Examples
--------
.. code-block:: python
import lance
dataset = lance.dataset("./images.lance")
dataset.create_scalar_index("category")
"""
raise NotImplementedError
@abstractmethod
async def add( async def add(
self, self,
data: DATA, data: DATA,
*, mode: str = "append",
mode: Optional[Literal["append", "overwrite"]] = "append", on_bad_vectors: str = "error",
on_bad_vectors: Optional[str] = None, fill_value: float = 0.0,
fill_value: Optional[float] = None,
): ):
"""Add more data to the [Table](Table). """Add more data to the [Table](Table).
@@ -2082,22 +2019,7 @@ class AsyncTable:
The value to use when filling vectors. Only used if on_bad_vectors="fill". The value to use when filling vectors. Only used if on_bad_vectors="fill".
""" """
schema = await self.schema() raise NotImplementedError
if on_bad_vectors is None:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
data = _sanitize_data(
data,
schema,
metadata=schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
if isinstance(data, pa.Table):
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
await self._inner.add(data, mode)
register_event("add")
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
""" """
@@ -2159,23 +2081,94 @@ class AsyncTable:
return LanceMergeInsertBuilder(self, on) return LanceMergeInsertBuilder(self, on)
def vector_search( @abstractmethod
async def search(
self, self,
query_vector: Optional[Union[VEC, Tuple]] = None, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
) -> AsyncVectorQuery: vector_column_name: Optional[str] = None,
""" query_type: str = "auto",
Search the table with a given query vector. ) -> LanceQueryBuilder:
This is a convenience method for preparing a vector query and """Create a search query to find the nearest neighbors
is the same thing as calling `nearestTo` on the builder returned of the given query vector. We currently support [vector search][search]
by `query`. Seer [nearest_to][AsyncQuery.nearest_to] for more details. and [full-text search][experimental-full-text-search].
"""
return self.query().nearest_to(query_vector)
async def _execute_query( All query options are defined in [Query][lancedb.query.Query].
self, query: Query, batch_size: Optional[int] = None
) -> pa.RecordBatchReader: Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> data = [
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
... ]
>>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query)
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width", "vector"])
... .limit(2)
... .to_pandas())
caption original_width vector _distance
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
1 test 3000 [0.3, 6.2, 2.6] 23.089996
Parameters
----------
query: list/np.ndarray/str/PIL.Image.Image, default None
The targetted vector to search for.
- *default None*.
Acceptable types are: list, np.ndarray, PIL.Image.Image
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str, optional
The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
query_type: str
*default "auto"*.
Acceptable types are: "vector", "fts", "hybrid", or "auto"
- If "auto" then the query type is inferred from the query;
- If `query` is a list/np.ndarray then the query type is
"vector";
- If `query` is a PIL.Image.Image then either do vector search,
or raise an error if no corresponding embedding function is found.
- If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns
-------
LanceQueryBuilder
A query builder object representing the query.
Once executed, the query returns
- selected columns
- the vector
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
raise NotImplementedError
@abstractmethod
async def _execute_query(self, query: Query) -> pa.Table:
pass pass
@abstractmethod
async def _do_merge( async def _do_merge(
self, self,
merge: LanceMergeInsertBuilder, merge: LanceMergeInsertBuilder,
@@ -2185,6 +2178,7 @@ class AsyncTable:
): ):
pass pass
@abstractmethod
async def delete(self, where: str): async def delete(self, where: str):
"""Delete rows from the table. """Delete rows from the table.
@@ -2235,60 +2229,63 @@ class AsyncTable:
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def update( async def update(
self, self,
updates: Optional[Dict[str, Any]] = None,
*,
where: Optional[str] = None, where: Optional[str] = None,
updates_sql: Optional[Dict[str, str]] = None, values: Optional[dict] = None,
*,
values_sql: Optional[Dict[str, str]] = None,
): ):
""" """
This can be used to update zero to all rows in the table. This can be used to update zero to all rows depending on how many
rows match the where clause. If no where clause is provided, then
all rows will be updated.
If a filter is provided with `where` then only rows matching the Either `values` or `values_sql` must be provided. You cannot provide
filter will be updated. Otherwise all rows will be updated. both.
Parameters Parameters
---------- ----------
updates: dict, optional
The updates to apply. The keys should be the name of the column to
update. The values should be the new values to assign. This is
required unless updates_sql is supplied.
where: str, optional where: str, optional
An SQL filter that controls which rows are updated. For example, 'x = 2' The SQL where clause to use when updating rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. Only rows that satisfy this filter will be udpated. or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
updates_sql: dict, optional values: dict, optional
The updates to apply, expressed as SQL expression strings. The keys should The values to update. The keys are the column names and the values
be column names. The values should be SQL expressions. These can be SQL are the values to set.
literals (e.g. "7" or "'foo'") or they can be expressions based on the values_sql: dict, optional
previous value of the row (e.g. "x + 1" to increment the x column by 1) The values to update, expressed as SQL expression strings. These can
reference existing columns. For example, {"x": "x + 1"} will increment
the x column by 1.
Examples Examples
-------- --------
>>> import asyncio
>>> import lancedb >>> import lancedb
>>> import pandas as pd >>> import pandas as pd
>>> async def demo_update(): >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
... data = pd.DataFrame({"x": [1, 2], "vector": [[1, 2], [3, 4]]}) >>> db = lancedb.connect("./.lancedb")
... db = await lancedb.connect_async("./.lancedb") >>> table = db.create_table("my_table", data)
... table = await db.create_table("my_table", data) >>> table.to_pandas()
... # x is [1, 2], vector is [[1, 2], [3, 4]] x vector
... await table.update({"vector": [10, 10]}, where="x = 2") 0 1 [1.0, 2.0]
... # x is [1, 2], vector is [[1, 2], [10, 10]] 1 2 [3.0, 4.0]
... await table.update(updates_sql={"x": "x + 1"}) 2 3 [5.0, 6.0]
... # x is [2, 3], vector is [[1, 2], [10, 10]] >>> table.update(where="x = 2", values={"vector": [10, 10]})
>>> asyncio.run(demo_update()) >>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 3 [5.0, 6.0]
2 2 [10.0, 10.0]
>>> table.update(values_sql={"x": "x + 1"})
>>> table.to_pandas()
x vector
0 2 [1.0, 2.0]
1 4 [5.0, 6.0]
2 3 [10.0, 10.0]
""" """
if updates is not None and updates_sql is not None: raise NotImplementedError
raise ValueError("Only one of updates or updates_sql can be provided")
if updates is None and updates_sql is None:
raise ValueError("Either updates or updates_sql must be provided")
if updates is not None:
updates_sql = {k: value_to_sql(v) for k, v in updates.items()}
return await self._inner.update(updates_sql, where)
@abstractmethod
async def cleanup_old_versions( async def cleanup_old_versions(
self, self,
older_than: Optional[timedelta] = None, older_than: Optional[timedelta] = None,
@@ -2320,6 +2317,7 @@ class AsyncTable:
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def compact_files(self, *args, **kwargs): async def compact_files(self, *args, **kwargs):
""" """
Run the compaction process on the table. Run the compaction process on the table.
@@ -2335,6 +2333,7 @@ class AsyncTable:
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def add_columns(self, transforms: Dict[str, str]): async def add_columns(self, transforms: Dict[str, str]):
""" """
Add new columns with defined values. Add new columns with defined values.
@@ -2350,6 +2349,7 @@ class AsyncTable:
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def alter_columns(self, alterations: Iterable[Dict[str, str]]): async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
""" """
Alter column names and nullability. Alter column names and nullability.
@@ -2372,6 +2372,7 @@ class AsyncTable:
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def drop_columns(self, columns: Iterable[str]): async def drop_columns(self, columns: Iterable[str]):
""" """
Drop columns from the table. Drop columns from the table.
@@ -2385,64 +2386,125 @@ class AsyncTable:
""" """
raise NotImplementedError raise NotImplementedError
async def version(self) -> int:
"""
Retrieve the version of the table
LanceDb supports versioning. Every operation that modifies the table increases class AsyncLanceTable(AsyncTable):
version. As long as a version hasn't been deleted you can `[Self::checkout]` def __init__(self, table: LanceDBTable):
that version to view the data at that point. In addition, you can self._inner = table
`[Self::restore]` the version to replace the current table with a previous
version.
"""
return await self._inner.version()
async def checkout(self, version): @property
""" @override
Checks out a specific version of the Table def name(self) -> str:
return self._inner.name()
Any read operation on the table will now access the data at the checked out @override
version. As a consequence, calling this method will disable any read consistency async def schema(self) -> pa.Schema:
interval that was previously set. return await self._inner.schema()
This is a read-only operation that turns the table into a sort of "view" @override
or "detached head". Other table instances will not be affected. To make the async def count_rows(self, filter: Optional[str] = None) -> int:
change permanent you can use the `[Self::restore]` method. raise NotImplementedError
Any operation that modifies the table will fail while the table is in a checked async def to_pandas(self) -> "pd.DataFrame":
out state. return self.to_arrow().to_pandas()
To return the table to a normal state use `[Self::checkout_latest]` @override
""" async def to_arrow(self) -> pa.Table:
await self._inner.checkout(version) raise NotImplementedError
async def checkout_latest(self): async def create_index(
""" self,
Ensures the table is pointing at the latest version metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name: str = VECTOR_COLUMN_NAME,
replace: bool = True,
accelerator: Optional[str] = None,
index_cache_size: Optional[int] = None,
):
raise NotImplementedError
This can be used to manually update a table when the read_consistency_interval @override
is None async def create_scalar_index(
It can also be used to undo a `[Self::checkout]` operation self,
""" column: str,
await self._inner.checkout_latest() *,
replace: bool = True,
):
raise NotImplementedError
async def restore(self): @override
""" async def add(
Restore the table to the currently checked out version self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
raise NotImplementedError
This operation will fail if checkout has not been called previously def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
on = [on] if isinstance(on, str) else list(on.iter())
This operation will overwrite the latest version of the table with a return LanceMergeInsertBuilder(self, on)
previous version. Any changes made since the checked out version will
no longer be visible.
Once the operation concludes the table will no longer be in a checked @override
out state and the read_consistency_interval, if any, will apply. async def search(
""" self,
await self._inner.restore() query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
) -> LanceQueryBuilder:
raise NotImplementedError
async def list_indices(self) -> IndexConfig: @override
""" async def _execute_query(self, query: Query) -> pa.Table:
List all indices that have been created with Self::create_index pass
"""
return await self._inner.list_indices() @override
async def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
pass
@override
async def delete(self, where: str):
raise NotImplementedError
@override
async def update(
self,
where: Optional[str] = None,
values: Optional[dict] = None,
*,
values_sql: Optional[Dict[str, str]] = None,
):
raise NotImplementedError
@override
async def cleanup_old_versions(
self,
older_than: Optional[timedelta] = None,
*,
delete_unverified: bool = False,
) -> CleanupStats:
raise NotImplementedError
@override
async def compact_files(self, *args, **kwargs):
raise NotImplementedError
@override
async def add_columns(self, transforms: Dict[str, str]):
raise NotImplementedError
@override
async def alter_columns(self, alterations: Iterable[Dict[str, str]]):
raise NotImplementedError
@override
async def drop_columns(self, columns: Iterable[str]):
raise NotImplementedError

View File

@@ -26,18 +26,6 @@ import pyarrow as pa
import pyarrow.fs as pa_fs import pyarrow.fs as pa_fs
def safe_import_adlfs():
try:
import adlfs
return adlfs
except ImportError:
return None
adlfs = safe_import_adlfs()
def get_uri_scheme(uri: str) -> str: def get_uri_scheme(uri: str) -> str:
""" """
Get the scheme of a URI. If the URI does not have a scheme, assume it is a file URI. Get the scheme of a URI. If the URI does not have a scheme, assume it is a file URI.
@@ -104,17 +92,6 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
path = get_uri_location(uri) path = get_uri_location(uri)
return fs, path return fs, path
elif get_uri_scheme(uri) == "az" and adlfs is not None:
az_blob_fs = adlfs.AzureBlobFileSystem(
account_name=os.environ.get("AZURE_STORAGE_ACCOUNT_NAME"),
account_key=os.environ.get("AZURE_STORAGE_ACCOUNT_KEY"),
)
fs = pa_fs.PyFileSystem(pa_fs.FSSpecHandler(az_blob_fs))
path = get_uri_location(uri)
return fs, path
return pa_fs.FileSystem.from_uri(uri) return pa_fs.FileSystem.from_uri(uri)

View File

@@ -69,7 +69,7 @@ class _Events:
self.throttled_event_names = ["search_table"] self.throttled_event_names = ["search_table"]
self.throttled_events = set() self.throttled_events = set()
self.max_events = 5 # max events to store in memory self.max_events = 5 # max events to store in memory
self.rate_limit = 60.0 * 60.0 # rate limit (seconds) self.rate_limit = 60.0 * 5 # rate limit (seconds)
self.time = 0.0 self.time = 0.0
if is_git_dir(): if is_git_dir():

View File

@@ -11,9 +11,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
from datetime import timedelta
import lancedb import lancedb
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@@ -185,10 +182,6 @@ async def test_table_names_async(tmp_path):
db = await lancedb.connect_async(tmp_path) db = await lancedb.connect_async(tmp_path)
assert await db.table_names() == ["test1", "test2", "test3"] assert await db.table_names() == ["test1", "test2", "test3"]
assert await db.table_names(limit=1) == ["test1"]
assert await db.table_names(start_after="test1", limit=1) == ["test2"]
assert await db.table_names(start_after="test1") == ["test2", "test3"]
def test_create_mode(tmp_path): def test_create_mode(tmp_path):
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)
@@ -257,28 +250,6 @@ def test_create_exist_ok(tmp_path):
db.create_table("test", schema=bad_schema, exist_ok=True) db.create_table("test", schema=bad_schema, exist_ok=True)
@pytest.mark.asyncio
async def test_connect(tmp_path):
db = await lancedb.connect_async(tmp_path)
assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=None)"
db = await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=5)
)
assert str(db) == f"NativeDatabase(uri={tmp_path}, read_consistency_interval=5s)"
@pytest.mark.asyncio
async def test_close(tmp_path):
db = await lancedb.connect_async(tmp_path)
assert db.is_open()
db.close()
assert not db.is_open()
with pytest.raises(RuntimeError, match="is closed"):
await db.table_names()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_mode_async(tmp_path): async def test_create_mode_async(tmp_path):
db = await lancedb.connect_async(tmp_path) db = await lancedb.connect_async(tmp_path)
@@ -351,39 +322,6 @@ async def test_create_exist_ok_async(tmp_path):
# await db.create_table("test", schema=bad_schema, exist_ok=True) # await db.create_table("test", schema=bad_schema, exist_ok=True)
@pytest.mark.asyncio
async def test_open_table(tmp_path):
db = await lancedb.connect_async(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
await db.create_table("test", data=data)
tbl = await db.open_table("test")
assert tbl.name == "test"
assert (
re.search(
r"NativeTable\(test, uri=.*test\.lance, read_consistency_interval=None\)",
str(tbl),
)
is not None
)
assert await tbl.schema() == pa.schema(
{
"vector": pa.list_(pa.float32(), list_size=2),
"item": pa.utf8(),
"price": pa.float64(),
}
)
with pytest.raises(ValueError, match="was not found"):
await db.open_table("does_not_exist")
def test_delete_table(tmp_path): def test_delete_table(tmp_path):
db = lancedb.connect(tmp_path) db = lancedb.connect(tmp_path)
data = pd.DataFrame( data = pd.DataFrame(

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import sys
from typing import List, Union
import lance import lance
import lancedb import lancedb
@@ -24,8 +23,6 @@ from lancedb.embeddings import (
EmbeddingFunctionRegistry, EmbeddingFunctionRegistry,
with_embeddings, with_embeddings,
) )
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import get_registry, register
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
@@ -115,34 +112,3 @@ def test_embedding_function_rate_limit(tmp_path):
table.add([{"text": "hello world"}]) table.add([{"text": "hello world"}])
table.add([{"text": "hello world"}]) table.add([{"text": "hello world"}])
assert len(table) == 2 assert len(table) == 2
def test_add_optional_vector(tmp_path):
@register("mock-embedding")
class MockEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Generate the embeddings for the given texts
"""
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
registry = get_registry()
model = registry.get("mock-embedding").create()
class LanceSchema(LanceModel):
id: str
vector: Vector(model.ndims()) = model.VectorField(default=None)
text: str = model.SourceField()
db = lancedb.connect(tmp_path)
tbl = db.create_table("optional_vector", schema=LanceSchema)
# add works
expected = LanceSchema(id="id", text="text")
tbl.add([expected])
assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()

View File

@@ -43,7 +43,6 @@ def table(tmp_path) -> ldb.table.LanceTable:
) )
for _ in range(100) for _ in range(100)
] ]
count = [random.randint(1, 10000) for _ in range(100)]
table = db.create_table( table = db.create_table(
"test", "test",
data=pd.DataFrame( data=pd.DataFrame(
@@ -53,7 +52,6 @@ def table(tmp_path) -> ldb.table.LanceTable:
"text": text, "text": text,
"text2": text, "text2": text,
"nested": [{"text": t} for t in text], "nested": [{"text": t} for t in text],
"count": count,
} }
), ),
) )
@@ -81,39 +79,6 @@ def test_search_index(tmp_path, table):
assert len(results[1]) == 10 # _distance assert len(results[1]) == 10 # _distance
def test_search_ordering_field_index_table(tmp_path, table):
table.create_fts_index("text", ordering_field_names=["count"])
rows = (
table.search("puppy", ordering_field_name="count")
.limit(20)
.select(["text", "count"])
.to_list()
)
for r in rows:
assert "puppy" in r["text"]
assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
def test_search_ordering_field_index(tmp_path, table):
index = ldb.fts.create_index(
str(tmp_path / "index"), ["text"], ordering_fields=["count"]
)
ldb.fts.populate_index(index, table, ["text"], ordering_fields=["count"])
index.reload()
results = ldb.fts.search_index(
index, query="puppy", limit=10, ordering_field="count"
)
assert len(results) == 2
assert len(results[0]) == 10 # row_ids
assert len(results[1]) == 10 # _distance
rows = table.to_lance().take(results[0]).to_pylist()
for r in rows:
assert "puppy" in r["text"]
assert sorted(rows, key=lambda x: x["count"], reverse=True) == rows
def test_create_index_from_table(tmp_path, table): def test_create_index_from_table(tmp_path, table):
table.create_fts_index("text") table.create_fts_index("text")
df = table.search("puppy").limit(10).select(["text"]).to_pandas() df = table.search("puppy").limit(10).select(["text"]).to_pandas()
@@ -129,7 +94,6 @@ def test_create_index_from_table(tmp_path, table):
"text": "gorilla", "text": "gorilla",
"text2": "gorilla", "text2": "gorilla",
"nested": {"text": "gorilla"}, "nested": {"text": "gorilla"},
"count": 10,
} }
] ]
) )
@@ -173,11 +137,7 @@ def test_search_index_with_filter(table):
# no duckdb # no duckdb
with mock.patch("builtins.__import__", side_effect=import_mock): with mock.patch("builtins.__import__", side_effect=import_mock):
rs = table.search("puppy").where("id=1").limit(10) rs = table.search("puppy").where("id=1").limit(10).to_list()
# test schema
assert rs.to_arrow().drop("score").schema.equals(table.schema)
rs = rs.to_list()
for r in rs: for r in rs:
assert r["id"] == 1 assert r["id"] == 1
@@ -187,10 +147,6 @@ def test_search_index_with_filter(table):
assert r["id"] == 1 assert r["id"] == 1
assert rs == rs2 assert rs == rs2
rs = table.search("puppy").where("id=1").with_row_id(True).limit(10).to_list()
for r in rs:
assert r["id"] == 1
assert r["_rowid"] is not None
def test_null_input(table): def test_null_input(table):
@@ -202,7 +158,6 @@ def test_null_input(table):
"text": None, "text": None,
"text2": None, "text2": None,
"nested": {"text": None}, "nested": {"text": None},
"count": 7,
} }
] ]
) )
@@ -214,18 +169,10 @@ def test_syntax(table):
table.create_fts_index("text") table.create_fts_index("text")
with pytest.raises(ValueError, match="Syntax Error"): with pytest.raises(ValueError, match="Syntax Error"):
table.search("they could have been dogs OR cats").limit(10).to_list() table.search("they could have been dogs OR cats").limit(10).to_list()
# these should work
# terms queries
table.search('"they could have been dogs" OR cats').limit(10).to_list()
table.search("(they AND could) OR (have AND been AND dogs) OR cats").limit(
10
).to_list()
# phrase queries
table.search("they could have been dogs OR cats").phrase_query().limit(10).to_list() table.search("they could have been dogs OR cats").phrase_query().limit(10).to_list()
# this should work
table.search('"they could have been dogs OR cats"').limit(10).to_list() table.search('"they could have been dogs OR cats"').limit(10).to_list()
# this should work too
table.search('''"the cats OR dogs were not really 'pets' at all"''').limit( table.search('''"the cats OR dogs were not really 'pets' at all"''').limit(
10 10
).to_list() ).to_list()

View File

@@ -1,126 +0,0 @@
# Copyright 2024 Lance Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import lancedb
import numpy as np
import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register
from lancedb.pydantic import LanceModel, Vector
datasets = pytest.importorskip("datasets")
@pytest.fixture(scope="session")
def mock_embedding_function():
@register("random")
class MockTextEmbeddingFunction(TextEmbeddingFunction):
def generate_embeddings(self, texts):
return [np.random.randn(128).tolist() for _ in range(len(texts))]
def ndims(self):
return 128
@pytest.fixture
def mock_hf_dataset():
# Create pyarrow table with `text` and `label` columns
train = datasets.Dataset(
pa.table(
{
"text": ["foo", "bar"],
"label": [0, 1],
}
),
split="train",
)
test = datasets.Dataset(
pa.table(
{
"text": ["fizz", "buzz"],
"label": [0, 1],
}
),
split="test",
)
return datasets.DatasetDict({"train": train, "test": test})
@pytest.fixture
def hf_dataset_with_split():
# Create pyarrow table with `text` and `label` columns
train = datasets.Dataset(
pa.table(
{"text": ["foo", "bar"], "label": [0, 1], "split": ["train", "train"]}
),
split="train",
)
test = datasets.Dataset(
pa.table(
{"text": ["fizz", "buzz"], "label": [0, 1], "split": ["test", "test"]}
),
split="test",
)
return datasets.DatasetDict({"train": train, "test": test})
def test_write_hf_dataset(tmp_path: Path, mock_embedding_function, mock_hf_dataset):
db = lancedb.connect(tmp_path)
emb = get_registry().get("random").create()
class Schema(LanceModel):
text: str = emb.SourceField()
label: int
vector: Vector(emb.ndims()) = emb.VectorField()
train_table = db.create_table("train", schema=Schema)
train_table.add(mock_hf_dataset["train"])
class WithSplit(LanceModel):
text: str = emb.SourceField()
label: int
vector: Vector(emb.ndims()) = emb.VectorField()
split: str
full_table = db.create_table("full", schema=WithSplit)
full_table.add(mock_hf_dataset)
assert len(train_table) == mock_hf_dataset["train"].num_rows
assert len(full_table) == sum(ds.num_rows for ds in mock_hf_dataset.values())
rt_train_table = full_table.to_lance().to_table(
columns=["text", "label"], filter="split='train'"
)
assert rt_train_table.to_pylist() == mock_hf_dataset["train"].data.to_pylist()
def test_bad_hf_dataset(tmp_path: Path, mock_embedding_function, hf_dataset_with_split):
db = lancedb.connect(tmp_path)
emb = get_registry().get("random").create()
class Schema(LanceModel):
text: str = emb.SourceField()
label: int
vector: Vector(emb.ndims()) = emb.VectorField()
split: str
train_table = db.create_table("train", schema=Schema)
# this should still work because we don't add the split column
# if it already exists
train_table.add(hf_dataset_with_split)

View File

@@ -1,69 +0,0 @@
from datetime import timedelta
import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb import AsyncConnection, AsyncTable, connect_async
from lancedb.index import BTree, IvfPq
@pytest_asyncio.fixture
async def db_async(tmp_path) -> AsyncConnection:
return await connect_async(tmp_path, read_consistency_interval=timedelta(seconds=0))
def sample_fixed_size_list_array(nrows, dim):
vector_data = pa.array([float(i) for i in range(dim * nrows)], pa.float32())
return pa.FixedSizeListArray.from_arrays(vector_data, dim)
DIM = 8
NROWS = 256
@pytest_asyncio.fixture
async def some_table(db_async):
data = pa.Table.from_pydict(
{
"id": list(range(256)),
"vector": sample_fixed_size_list_array(NROWS, DIM),
}
)
return await db_async.create_table(
"some_table",
data,
)
@pytest.mark.asyncio
async def test_create_scalar_index(some_table: AsyncTable):
# Can create
await some_table.create_index("id")
# Can recreate if replace=True
await some_table.create_index("id", replace=True)
indices = await some_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "BTree"
assert indices[0].columns == ["id"]
# Can't recreate if replace=False
with pytest.raises(RuntimeError, match="already exists"):
await some_table.create_index("id", replace=False)
# can also specify index type
await some_table.create_index("id", config=BTree())
@pytest.mark.asyncio
async def test_create_vector_index(some_table: AsyncTable):
# Can create
await some_table.create_index("vector")
# Can recreate if replace=True
await some_table.create_index("vector", replace=True)
# Can't recreate if replace=False
with pytest.raises(RuntimeError, match="already exists"):
await some_table.create_index("vector", replace=False)
# Can also specify index type
await some_table.create_index("vector", config=IvfPq(num_partitions=100))
indices = await some_table.list_indices()
assert len(indices) == 1
assert indices[0].index_type == "IvfPq"
assert indices[0].columns == ["vector"]

View File

@@ -16,35 +16,16 @@ import os
import lancedb import lancedb
import pytest import pytest
# AWS:
# You need to setup AWS credentials an a base path to run this test. Example # You need to setup AWS credentials an a base path to run this test. Example
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py # AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
#
# Azure:
# You need to setup Azure credentials an a base path to run this test. Example
# export AZURE_STORAGE_ACCOUNT_NAME="<account>"
# export AZURE_STORAGE_ACCOUNT_KEY="<key>"
# export REMOTE_BASE_URL=az://my_blob/dataset
# pytest tests/test_io.py
@pytest.fixture(autouse=True, scope="module")
def setup():
yield
if remote_url := os.environ.get("REMOTE_BASE_URL"):
db = lancedb.connect(remote_url)
for table in db.table_names():
db.drop_table(table)
@pytest.mark.skipif( @pytest.mark.skipif(
(os.environ.get("REMOTE_BASE_URL") is None), (os.environ.get("TEST_S3_BASE_URL") is None),
reason="please setup remote base url", reason="please setup s3 base url",
) )
def test_remote_io(): def test_s3_io():
db = lancedb.connect(os.environ.get("REMOTE_BASE_URL")) db = lancedb.connect(os.environ.get("TEST_S3_BASE_URL"))
assert db.table_names() == [] assert db.table_names() == []
table = db.create_table( table = db.create_table(

View File

@@ -12,20 +12,16 @@
# limitations under the License. # limitations under the License.
import unittest.mock as mock import unittest.mock as mock
from datetime import timedelta
from typing import Optional
import lance import lance
import lancedb
import numpy as np import numpy as np
import pandas.testing as tm import pandas.testing as tm
import pyarrow as pa import pyarrow as pa
import pytest import pytest
import pytest_asyncio
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import AsyncTable, LanceTable from lancedb.table import LanceTable
class MockTable: class MockTable:
@@ -36,9 +32,9 @@ class MockTable:
def to_lance(self): def to_lance(self):
return lance.dataset(self.uri) return lance.dataset(self.uri)
def _execute_query(self, query, batch_size: Optional[int] = None): def _execute_query(self, query):
ds = self.to_lance() ds = self.to_lance()
return ds.scanner( return ds.to_table(
columns=query.columns, columns=query.columns,
filter=query.filter, filter=query.filter,
prefilter=query.prefilter, prefilter=query.prefilter,
@@ -50,8 +46,7 @@ class MockTable:
"nprobes": query.nprobes, "nprobes": query.nprobes,
"refine_factor": query.refine_factor, "refine_factor": query.refine_factor,
}, },
batch_size=batch_size, )
).to_reader()
@pytest.fixture @pytest.fixture
@@ -70,24 +65,6 @@ def table(tmp_path) -> MockTable:
return MockTable(tmp_path) return MockTable(tmp_path)
@pytest_asyncio.fixture
async def table_async(tmp_path) -> AsyncTable:
conn = await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=0)
)
data = pa.table(
{
"vector": pa.array(
[[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
),
"id": pa.array([1, 2]),
"str_field": pa.array(["a", "b"]),
"float_field": pa.array([1.0, 2.0]),
}
)
return await conn.create_table("test", data)
def test_cast(table): def test_cast(table):
class TestModel(LanceModel): class TestModel(LanceModel):
vector: Vector(2) vector: Vector(2)
@@ -117,25 +94,6 @@ def test_query_builder(table):
assert all(np.array(rs[0]["vector"]) == [1, 2]) assert all(np.array(rs[0]["vector"]) == [1, 2])
def test_query_builder_batches(table):
rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector")
.limit(2)
.select(["id", "vector"])
.to_batches(1)
)
rs_list = []
for item in rs:
rs_list.append(item)
assert isinstance(item, pa.RecordBatch)
assert len(rs_list) == 1
assert len(rs_list[0]["id"]) == 2
assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
assert rs_list[0].to_pandas()["id"][0] == 1
assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
assert rs_list[0].to_pandas()["id"][1] == 2
def test_dynamic_projection(table): def test_dynamic_projection(table):
rs = ( rs = (
LanceVectorQueryBuilder(table, [0, 0], "vector") LanceVectorQueryBuilder(table, [0, 0], "vector")
@@ -220,116 +178,9 @@ def test_query_builder_with_different_vector_column():
nprobes=20, nprobes=20,
refine_factor=None, refine_factor=None,
vector_column="foo_vector", vector_column="foo_vector",
), )
None,
) )
def cosine_distance(vec1, vec2): def cosine_distance(vec1, vec2):
return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
async def check_query(
query: AsyncQueryBase, *, expected_num_rows=None, expected_columns=None
):
num_rows = 0
results = await query.to_batches()
async for batch in results:
if expected_columns is not None:
assert batch.schema.names == expected_columns
num_rows += batch.num_rows
if expected_num_rows is not None:
assert num_rows == expected_num_rows
@pytest.mark.asyncio
async def test_query_async(table_async: AsyncTable):
await check_query(
table_async.query(),
expected_num_rows=2,
expected_columns=["vector", "id", "str_field", "float_field"],
)
await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
await check_query(
table_async.query().select(["id", "vector"]), expected_columns=["id", "vector"]
)
await check_query(
table_async.query().select({"foo": "id", "bar": "id + 1"}),
expected_columns=["foo", "bar"],
)
await check_query(table_async.query().limit(1), expected_num_rows=1)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2
)
# Support different types of inputs for the vector query
for vector_query in [
[1, 2],
[1.0, 2.0],
np.array([1, 2]),
(1, 2),
]:
await check_query(
table_async.query().nearest_to(vector_query), expected_num_rows=2
)
# No easy way to check these vector query parameters are doing what they say. We
# just check that they don't raise exceptions and assume this is tested at a lower
# level.
await check_query(
table_async.query().where("id = 2").nearest_to(pa.array([1, 2])).postfilter(),
expected_num_rows=1,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).refine_factor(1),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).nprobes(10),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).distance_type("dot"),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).distance_type("DoT"),
expected_num_rows=2,
)
# Make sure we can use a vector query as a base query (e.g. call limit on it)
# Also make sure `vector_search` works
await check_query(table_async.vector_search([1, 2]).limit(1), expected_num_rows=1)
# Also check an empty query
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
table = await table_async.to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
table = await table_async.query().to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
table = await table_async.query().where("id < 0").to_arrow()
assert table.num_rows == 0
assert table.num_columns == 4
@pytest.mark.asyncio
async def test_query_to_pandas_async(table_async: AsyncTable):
df = await table_async.to_pandas()
assert df.shape == (2, 4)
df = await table_async.query().to_pandas()
assert df.shape == (2, 4)
df = await table_async.query().where("id < 0").to_pandas()
assert df.shape == (0, 4)

Some files were not shown because too many files have changed in this diff Show More