Compare commits

..

77 Commits

Author SHA1 Message Date
Chang She
fbd0bc7740 bump version for v0.1.5-python 2023-06-02 09:18:26 -07:00
gsilvestrin
f765a453cf Use fsspec to implement table_names with cloud storage support (#117)
Co-authored-by: Will Jones <willjones127@gmail.com>
2023-06-01 16:56:26 -07:00
gsilvestrin
45b3a14f26 Bumping vectordb to v0.1.3 (#124) 2023-06-01 16:36:11 -07:00
Lei Xu
9965b4564d [Python] Support drop table (#123)
Closes #86
2023-06-01 15:58:45 -07:00
gsilvestrin
0719e4b3fb Revert "refactor: pull node binaries into separate packages (#88)" (#122)
This reverts commit e50b642d80.
2023-06-01 13:53:07 -07:00
Jai
091fb9b665 add existence check (#112) 2023-06-01 11:45:26 -07:00
Chang She
03013a4434 Multimodal search demo (#118)
Slow roasted over 12 hours, Pairs well with #111

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-06-01 10:34:08 -07:00
gsilvestrin
3e14b357e7 add openai embedding function to nodejs client (#107)
- openai is an optional dependency for lancedb
- added an example to show how to use it
2023-06-01 10:25:00 -07:00
Lei Xu
99cbda8b07 Generate diffusiondb embeddings (#111) 2023-06-01 10:23:29 -07:00
Will Jones
e50b642d80 refactor: pull node binaries into separate packages (#88)
Changes:

* Refactors the Node module to load the shared library from a separate
package. When a user does `npm install vectordb`, the correct optional
dependency is automatically downloaded by npm.
* Brings Rust and Node versions in alignment at 0.1.2.
* Add scripts and instructions to build Linux and MacOS node artifacts
locally.
* Add instructions for publishing the npm module and crates.
2023-06-01 09:17:19 -07:00
gsilvestrin
6d8cf52e01 Better error granularity for table operations (#113) 2023-06-01 09:04:42 -07:00
Akash
53f3882d6e Fixed documentation link for the Youtube Transcripts Jupyter Notebook (#105)
Changed the link to the Youtube Transcripts jupyter notebook path on the
documentation.

Previously it went inside docs/notebooks (which does not exist). I've
modified it to go inside the notebooks folder instead.
2023-06-01 09:00:40 -07:00
Chang She
2b26775ed1 python v0.1.4 2023-05-31 20:11:25 -07:00
Lei Xu
306ada5cb8 Support S3 and GCS from typescript SDK (#106) 2023-05-30 21:32:17 -07:00
gsilvestrin
d3aa8bfbc5 add embedding functions to the nodejs client (#95) 2023-05-26 18:09:20 -07:00
Chang She
04d97347d7 move tantivy-py installation to be separate from wheel (#97)
pypi does not allow packages to be uploaded that has a direct reference

for now we'll just ask the user to install tantivy separately

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-05-25 17:57:26 -06:00
Chang She
22aa8a93c2 bump version for v0.1.3 2023-05-25 17:01:52 -06:00
Chang She
f485378ea4 Basic full text search capabilities (#62)
This is v1 of integrating full text search index into LanceDB.

# API
The query API is roughly the same as before, except if the input is text
instead of a vector we assume that its fts search.

## Example
If `table` is a LanceDB LanceTable, then:

Build index: `table.create_fts_index("text")`

Query: `df = table.search("puppy").limit(10).select(["text"]).to_df()`

# Implementation
Here we use the tantivy-py package to build the index. We then use the
row id's as the full-text-search index's doc id then we just do a Take
operation to fetch the rows.

# Limitations

1. don't support incremental row appends yet. New data won't show up in
search
2. local filesystem only 
3. requires building tantivy explicitly

---------

Co-authored-by: Chang She <chang@lancedb.com>
2023-05-24 22:25:31 -06:00
gsilvestrin
f923cfe47f add create index to nodejs client (#89) 2023-05-24 16:45:58 -06:00
gsilvestrin
06cb7b6458 add query params to to nodejs client (#87) 2023-05-24 15:48:31 -06:00
gsilvestrin
bdef634954 bugfix: string columns should be converted to Utf8Array (#94) 2023-05-23 14:58:49 -07:00
Will Jones
aac2ffa4b3 Lint and test vectordb node in CI (#92)
Closes #90.
2023-05-22 14:26:06 -07:00
gsilvestrin
e28fe7b468 nodejs append records api (#85) 2023-05-18 15:13:57 -07:00
gsilvestrin
61b9479bd9 JavaScript client initial linux support (#84) 2023-05-16 17:04:06 -07:00
gsilvestrin
961d892c89 Added TypeScript example (#82) 2023-05-16 13:40:52 -07:00
Jai
0b35e6dfa9 node quickstart (#83) 2023-05-16 09:53:04 -07:00
Jai
ca96fc55f6 add link to node quickstart to readme (#81) 2023-05-16 09:24:12 -07:00
gsilvestrin
395c7460d5 nodejs create_table (#75) 2023-05-15 19:00:17 -07:00
Jai
92d810eac4 docs build (#78) 2023-05-14 10:18:28 -07:00
Jai
a55a579b7f nodejs read only example (#77) 2023-05-12 15:50:59 -07:00
gsilvestrin
202924f832 updated node example (#74) 2023-05-11 12:55:02 -07:00
gsilvestrin
648f8123ca Exposing limit parameter (#73) 2023-05-11 09:12:06 -07:00
gsilvestrin
5bb5b0a685 javascript example improvements (#72) 2023-05-10 22:06:44 -07:00
gsilvestrin
c2e73262ef bump version and skipping building the native lib during install (#71) 2023-05-10 15:10:46 -07:00
gsilvestrin
f5bf6181e3 Merge pull request #70 from lancedb/gsilvestrin/nodejs_client-merge
JavaScript / Node.js library for LanceDB
2023-05-10 13:44:52 -07:00
gsilvestrin
c2dc1da509 Removing sample db 2023-05-10 13:40:17 -07:00
gsilvestrin
38e6efc185 JavaScript / Node.js library for LanceDB
- Core rust library
- ffi bridge that exposes rust functionality to javascript
- npm package that provides a TypeScript / JavaScript library
- limitations: it only supports reading for now
2023-05-10 12:51:49 -07:00
Chang She
636a6d3761 Merge pull request #65 from lancedb/jaichopra/add-youtube-transcript-example 2023-05-08 17:45:35 -07:00
Jai Chopra
2a855c9f6a update image url 2023-05-08 17:39:52 -07:00
Jai Chopra
5c47b0c6a2 add youtube transcript example 2023-05-08 17:38:08 -07:00
Jai
d12bc24091 Merge pull request #63 from lancedb/jaichopra/update-readme-ecosystem
update ecosystem in readme
2023-05-07 09:12:25 -07:00
Jai Chopra
c4261b23e6 update blog url 2023-05-07 08:18:24 -07:00
Jai Chopra
ab0abbbfab update ecosystem in readme 2023-05-07 08:17:02 -07:00
Chang She
13c9a2e1c9 Merge pull request #61 from lancedb/jaichopra/langchain-example-doc
add langchain example
2023-05-05 16:06:40 -07:00
Jai Chopra
7e3db16225 add langchain example 2023-05-05 16:00:14 -07:00
Jai
62abe2d96f Merge pull request #57 from lancedb/jaichopra/s3-lambda-docs
S3 Lambda example
2023-05-05 14:08:24 -07:00
Chang She
59014a01e0 bump version for v0.1.2 2023-05-05 11:27:09 -07:00
Jai Chopra
11f423ccf5 clean up 2023-05-04 17:21:53 -07:00
Chang She
47ae17ea05 Merge pull request #58 from lancedb/changhiskhan/parse-schema
Add method to get the URI scheme to support cloud storage
2023-05-04 14:36:45 -07:00
Chang She
b6739f3f66 windows paths 2023-05-04 11:41:05 -07:00
Jai Chopra
6ff3c60cd1 clean up example 2023-05-04 10:14:31 -07:00
Chang She
3a2df0ce45 Add method to get the URI scheme to support cloud storage 2023-05-04 09:47:03 -07:00
Jai Chopra
6556e42e6d update lambda example to lancedb 2023-05-04 08:17:13 -07:00
Jai Chopra
c3d90b2c78 update tagline 2023-05-04 08:17:13 -07:00
Jai Chopra
66f7d5cec9 also update docs index 2023-05-04 08:17:13 -07:00
Jai Chopra
4336ed050d add new feature to readme.md 2023-05-04 08:17:13 -07:00
Lei Xu
976344257c add cargo metadata 2023-05-04 08:17:13 -07:00
Lei Xu
906551b001 initialize the rust core 2023-05-04 08:17:13 -07:00
Chang She
33ac42a51c bump version for v0.1.1 2023-05-04 08:17:13 -07:00
Chang She
c0bc65cdfa Merge pull request #55 from lancedb/jaichopra/update-tagline
update tagline
2023-05-03 21:06:41 -07:00
Jai Chopra
298b81f0b0 update tagline 2023-05-03 19:55:10 -07:00
Jai
fe7a3ccd60 Merge pull request #53 from lancedb/jaichopra/update-major-features-readme
also update docs index
2023-05-03 07:51:54 -07:00
Jai Chopra
baf8d7c1a1 also update docs index 2023-05-03 07:50:44 -07:00
Chang She
2021e1bf6d Merge pull request #52 from lancedb/jaichopra/update-major-features-readme 2023-05-03 07:36:09 -07:00
Jai Chopra
2dbe71cf88 add new feature to readme.md 2023-05-03 07:30:46 -07:00
Jai
7cd36196b4 Update langchain.md 2023-04-27 11:08:29 -07:00
Lei Xu
afe19ade7f Merge pull request #49 from lancedb/lei/rust_core
Rust core directory
2023-04-27 10:40:21 -07:00
Lei Xu
118efdce73 add cargo metadata 2023-04-27 10:36:01 -07:00
Lei Xu
b0426387e7 initialize the rust core 2023-04-27 10:31:50 -07:00
Jai
87fb4d0645 Update langchain.md 2023-04-27 07:13:18 -07:00
Jai
c930b94917 Update s3_lambda.md 2023-04-27 07:12:52 -07:00
Jai
aa23d911f5 Update langchain.md 2023-04-26 14:50:09 -07:00
Jai Chopra
ca8d8e82b7 add simple langchain example 2023-04-26 14:44:20 -07:00
Jai
3d3ba913ed Update s3_lambda.md 2023-04-26 10:19:27 -07:00
Jai
0346d5319e Update s3_lambda.md 2023-04-26 10:18:47 -07:00
Jai
41eadf6fd9 Update s3_lambda.md 2023-04-26 10:18:31 -07:00
Jai Chopra
e784c6311d tree github build script from remote 2023-04-25 21:40:28 -07:00
71 changed files with 15581 additions and 50 deletions

View File

@@ -40,9 +40,8 @@ jobs:
python -m pip install -e .
python -m pip install -r ../docs/requirements.txt
- name: Build docs
working-directory: docs
run: |
mkdocs build
PYTHONPATH=. mkdocs build -f docs/mkdocs.yml
- name: Setup Pages
uses: actions/configure-pages@v2
- name: Upload artifact

101
.github/workflows/node.yml vendored Normal file
View File

@@ -0,0 +1,101 @@
name: Node
on:
push:
branches:
- main
pull_request:
paths:
- node/**
- rust/ffi/node/**
- .github/workflows/node.yml
env:
# Disable full debug symbol generation to speed up CI build and keep memory down
# "1" means line tables only, which is useful for panic tracebacks.
RUSTFLAGS: "-C debuginfo=1"
RUST_BACKTRACE: "1"
jobs:
lint:
name: Lint
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
working-directory: node
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
lfs: true
- uses: actions/setup-node@v3
with:
node-version: 18
cache: 'npm'
cache-dependency-path: node/package-lock.json
- name: Lint
run: |
npm ci
npm run lint
linux:
name: Linux (Node ${{ matrix.node-version }})
timeout-minutes: 30
strategy:
matrix:
node-version: [ "16", "18" ]
runs-on: "ubuntu-22.04"
defaults:
run:
shell: bash
working-directory: node
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
lfs: true
- uses: actions/setup-node@v3
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'
cache-dependency-path: node/package-lock.json
- uses: Swatinem/rust-cache@v2
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Build
run: |
npm ci
npm run build
npm run tsc
- name: Test
run: npm run test
macos:
timeout-minutes: 30
runs-on: "macos-13"
defaults:
run:
shell: bash
working-directory: node
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
lfs: true
- uses: actions/setup-node@v3
with:
node-version: 18
cache: 'npm'
cache-dependency-path: node/package-lock.json
- uses: Swatinem/rust-cache@v2
- name: Install dependencies
run: brew install protobuf
- name: Build
run: |
npm ci
npm run build
npm run tsc
- name: Test
run: |
npm run test

View File

@@ -31,6 +31,7 @@ jobs:
- name: Install lancedb
run: |
pip install -e .
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest
- name: Run tests
run: pytest -x -v --durations=30 tests
@@ -49,10 +50,11 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- name: Install lancedb
run: |
pip install -e .
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest
- name: Run tests
run: pytest -x -v --durations=30 tests

15
.gitignore vendored
View File

@@ -2,6 +2,7 @@
**/*.whl
*.egg-info
**/__pycache__
.DS_Store
rust/target
rust/Cargo.lock
@@ -15,3 +16,17 @@ python/build
python/dist
notebooks/.ipynb_checkpoints
**/.hypothesis
## Javascript
*.node
**/node_modules
**/.DS_Store
node/dist
node/examples/**/package-lock.json
node/examples/**/dist
## Rust
target

3797
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

6
Cargo.toml Normal file
View File

@@ -0,0 +1,6 @@
[workspace]
members = [
"rust/vectordb",
"rust/ffi/node"
]
resolver = "2"

View File

@@ -3,10 +3,10 @@
<img width="275" alt="LanceDB Logo" src="https://user-images.githubusercontent.com/917119/226205734-6063d87a-1ecc-45fe-85be-1dea6383a3d8.png">
**Serverless, low-latency vector database for AI applications**
**Developer-friendly, serverless vector database for AI applications**
<a href="https://lancedb.github.io/lancedb/">Documentation</a>
<a href="https://blog.eto.ai/">Blog</a>
<a href="https://blog.lancedb.com/">Blog</a>
<a href="https://discord.gg/zMM32dvNtd">Discord</a>
<a href="https://twitter.com/lancedb">Twitter</a>
@@ -21,23 +21,41 @@ The key features of LanceDB include:
* Production-scale vector search with no servers to manage.
* Combine attribute-based information with vectors and store them as a single source-of-truth.
* Store, query and filter vectors, metadata and multi-modal data (text, images, videos, point clouds, and more).
* Native Python and Javascript/Typescript support.
* Zero-copy, automatic versioning, manage versions of your data without needing extra infrastructure.
* Ecosystem integrations: Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
LanceDB's core is written in Rust 🦀 and is built using <a href="https://github.com/eto-ai/lance">Lance</a>, an open-source columnar format designed for performant ML workloads.
## Quick Start
**Installation**
**Javascript**
```shell
npm install vectordb
```
```javascript
const lancedb = require('vectordb');
const db = await lancedb.connect('data/sample-lancedb');
const table = await db.createTable('vectors',
[{ id: 1, vector: [0.1, 0.2], item: "foo", price: 10 },
{ id: 2, vector: [1.1, 1.2], item: "bar", price: 50 }])
const query = table.search([0.1, 0.3]);
query.limit = 20;
const results = await query.execute();
```
**Python**
```shell
pip install lancedb
```
**Quickstart**
```python
import lancedb

View File

@@ -8,7 +8,10 @@ theme:
plugins:
- search
- mkdocstrings
- mkdocstrings:
handlers:
python:
paths: [../python]
- mkdocs-jupyter
nav:
@@ -16,6 +19,7 @@ nav:
- Basics: basic.md
- Embeddings: embedding.md
- Indexing: ann_indexes.md
- Full-text search: fts.md
- Integrations: integrations.md
- Python API: python.md

View File

@@ -0,0 +1,7 @@
# Code documentation Q&A bot with LangChain
## use LanceDB's LangChain integration to build a Q&A bot for your documentation
<img id="splash" width="400" alt="langchain" src="https://user-images.githubusercontent.com/917119/236580868-61a246a9-e587-4c2b-8ae5-6fe5f7b7e81e.png">
This example is in a [notebook](https://github.com/lancedb/lancedb/blob/main/notebooks/code_qa_bot.ipynb)

View File

@@ -0,0 +1,99 @@
# YouTube transcript QA bot with NodeJS
## use LanceDB's Javascript API and OpenAI to build a QA bot for YouTube transcripts
<img id="splash" width="400" alt="nodejs" src="https://github.com/lancedb/lancedb/assets/917119/3a140e75-bf8e-438a-a1e4-af14a72bcf98">
This Q&A bot will allow you to search through youtube transcripts using natural language! We'll introduce how you can use LanceDB's Javascript API to store and manage your data easily.
For this example we're using a HuggingFace dataset that contains YouTube transcriptions: `jamescalam/youtube-transcriptions`, to make it easier, we've converted it to a LanceDB `db` already, which you can download and put in a working directory:
```wget -c https://eto-public.s3.us-west-2.amazonaws.com/lancedb_demo.tar.gz -O - | tar -xz -C .```
Now, we'll create a simple app that can:
1. Take a text based query and search for contexts in our corpus, using embeddings generated from the OpenAI Embedding API.
2. Create a prompt with the contexts, and call the OpenAI Completion API to answer the text based query.
Dependencies and setup of OpenAI API:
```javascript
const lancedb = require("vectordb");
const { Configuration, OpenAIApi } = require("openai");
const configuration = new Configuration({
apiKey: process.env.OPENAI_API_KEY,
});
const openai = new OpenAIApi(configuration);
```
First, let's set our question and the context amount. The context amount will be used to query similar documents in our corpus.
```javascript
const QUESTION = "who was the 12th person on the moon and when did they land?";
const CONTEXT_AMOUNT = 3;
```
Now, let's generate an embedding from this question:
```javascript
const embeddingResponse = await openai.createEmbedding({
model: "text-embedding-ada-002",
input: QUESTION,
});
const embedding = embeddingResponse.data["data"][0]["embedding"];
```
Once we have the embedding, we can connect to LanceDB (using the database we downloaded earlier), and search through the chatbot table.
We'll extract 3 similar documents found.
```javascript
const db = await lancedb.connect('./lancedb');
const tbl = await db.openTable('chatbot');
const query = tbl.search(embedding);
query.limit = CONTEXT_AMOUNT;
const context = await query.execute();
```
Let's combine the context together so we can pass it into our prompt:
```javascript
for (let i = 1; i < context.length; i++) {
context[0]["text"] += " " + context[i]["text"];
}
```
Lastly, let's construct the prompt. You could play around with this to create more accurate/better prompts to yield results.
```javascript
const prompt = "Answer the question based on the context below.\n\n" +
"Context:\n" +
`${context[0]["text"]}\n` +
`\n\nQuestion: ${QUESTION}\nAnswer:`;
```
We pass the prompt, along with the context, to the completion API.
```javascript
const completion = await openai.createCompletion({
model: "text-davinci-003",
prompt,
temperature: 0,
max_tokens: 400,
top_p: 1,
frequency_penalty: 0,
presence_penalty: 0,
});
```
And that's it!
```javascript
console.log(completion.data.choices[0].text);
```
The response is (which is non deterministic):
```
The 12th person on the moon was Harrison Schmitt and he landed on December 11, 1972.
```

View File

@@ -0,0 +1,106 @@
# Serverless LanceDB
## Store your data on S3 and use Lambda to compute embeddings and retrieve queries in production easily.
<img id="splash" width="400" alt="s3-lambda" src="https://user-images.githubusercontent.com/917119/234653050-305a1e90-9305-40ab-b014-c823172a948c.png">
This is a great option if you're wanting to scale with your use case and save effort and costs of maintenance.
Let's walk through how to get a simple Lambda function that queries the SIFT dataset on S3.
Before we start, you'll need to ensure you create a secure account access to AWS. We recommend using user policies, as this way AWS can share credentials securely without you having to pass around environment variables into Lambda.
We'll also use a container to ship our Lambda code. This is a good option for Lambda as you don't have the space limits that you would otherwise by building a package yourself.
# Initial setup: creating a LanceDB Table and storing it remotely on S3
We'll use the SIFT vector dataset as an example. To make it easier, we've already made a Lance-format SIFT dataset publicly available, which we can access and use to populate our LanceDB Table.
To do this, download the Lance files locally first from:
```
s3://eto-public/datasets/sift/vec_data.lance
```
Then, we can write a quick Python script to populate our LanceDB Table:
```python
import pylance
sift_dataset = pylance.dataset("/path/to/local/vec_data.lance")
df = sift_dataset.to_table().to_pandas()
import lancedb
db = lancedb.connect(".")
table = db.create_table("vector_example", df)
```
Once we've created our Table, we are free to move this data over to S3 so we can remotely host it.
# Building our Lambda app: a simple event handler for vector search
Now that we've got a remotely hosted LanceDB Table, we'll want to be able to query it from Lambda. To do so, let's create a new `Dockerfile` using the AWS python container base:
```docker
FROM public.ecr.aws/lambda/python:3.10
RUN pip3 install --upgrade pip
RUN pip3 install --no-cache-dir -U numpy --target "${LAMBDA_TASK_ROOT}"
RUN pip3 install --no-cache-dir -U lancedb --target "${LAMBDA_TASK_ROOT}"
COPY app.py ${LAMBDA_TASK_ROOT}
CMD [ "app.handler" ]
```
Now let's make a simple Lambda function that queries the SIFT dataset in `app.py`.
```python
import json
import numpy as np
import lancedb
db = lancedb.connect("s3://eto-public/tables")
table = db.open_table("vector_example")
def handler(event, context):
status_code = 200
if event['query_vector'] is None:
status_code = 404
return {
"statusCode": status_code,
"headers": {
"Content-Type": "application/json"
},
"body": json.dumps({
"Error ": "No vector to query was issued"
})
}
# Shape of SIFT is (128,1M), d=float32
query_vector = np.array(event['query_vector'], dtype=np.float32)
rs = table.search(query_vector).limit(2).to_df()
return {
"statusCode": status_code,
"headers": {
"Content-Type": "application/json"
},
"body": rs.to_json()
}
```
# Deploying the container to ECR
The next step is to build and push the container to ECR, where it can then be used to create a new Lambda function.
It's best to follow the official AWS documentation for how to do this, which you can view here:
```
https://docs.aws.amazon.com/lambda/latest/dg/images-create.html#images-upload
```
# Final step: setting up your Lambda function
Once the container is pushed, you can create a Lambda function by selecting the container.

View File

@@ -0,0 +1,7 @@
# YouTube transcript search
## Search through youtube transcripts using natural language with LanceDB
<img id="splash" width="400" alt="youtube transcript search" src="https://user-images.githubusercontent.com/917119/236965568-def7394d-171c-45f2-939d-8edfeaadd88c.png">
This example is in a [notebook](https://github.com/lancedb/lancedb/blob/main/notebooks/youtube_transcript_search.ipynb)

51
docs/src/fts.md Normal file
View File

@@ -0,0 +1,51 @@
# [EXPERIMENTAL] Full text search
LanceDB now provides experimental support for full text search.
This is currently Python only. We plan to push the integration down to Rust in the future
to make this available for JS as well.
## Installation
To use full text search, you must install optional dependency tantivy-py:
# tantivy 0.19.2
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
## Quickstart
Assume:
1. `table` is a LanceDB Table
2. `text` is the name of the Table column that we want to index
To create the index:
```python
table.create_fts_index("text")
```
To search:
```python
df = table.search("puppy").limit(10).select(["text"]).to_df()
```
LanceDB automatically looks for an FTS index if the input is str.
## Multiple text columns
If you have multiple columns to index, pass them all as a list to `create_fts_index`:
```python
table.create_fts_index(["text1", "text2"])
```
Note that the search API call does not change - you can search over all indexed columns at once.
## Current limitations
1. Currently we do not yet support incremental writes.
If you add data after fts index creation, it won't be reflected
in search results until you do a full reindex.
2. We currently only support local filesystem paths for the fts index.

View File

@@ -6,11 +6,13 @@ The key features of LanceDB include:
* Production-scale vector search with no servers to manage.
* Combine attribute-based information with vectors and store them as a single source-of-truth.
* Store, query and filter vectors, metadata and multi-modal data (text, images, videos, point clouds, and more).
* Native Python and Javascript/Typescript support (coming soon).
* Zero-copy, automatic versioning, manage versions of your data without needing extra infrastructure.
* Ecosystem integrations: Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
LanceDB's core is written in Rust 🦀 and is built using Lance, an open-source columnar format designed for performant ML workloads.
@@ -36,12 +38,13 @@ result = table.search([100, 100]).limit(2).to_df()
## Complete Demos
We will be adding completed demo apps built using LanceDB.
- [YouTube Transcript Search](../notebooks/youtube_transcript_search.ipynb)
- [YouTube Transcript Search](../../notebooks/youtube_transcript_search.ipynb)
## Documentation Quick Links
* [`Basic Operations`](basic.md) - basic functionality of LanceDB.
* [`Embedding Functions`](embedding.md) - functions for working with embeddings.
* [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
* [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
* [`Ecosystem Integrations`](integrations.md) - integrating LanceDB with python data tooling ecosystem.
* [`API Reference`](python.md) - detailed documentation for the LanceDB Python SDK.

View File

@@ -6,9 +6,9 @@
pip install lancedb
```
::: lancedb
::: lancedb.db
::: lancedb.table
::: lancedb.query
::: lancedb.embeddings
::: lancedb.context
## ::: lancedb
## ::: lancedb.db
## ::: lancedb.table
## ::: lancedb.query
## ::: lancedb.embeddings
## ::: lancedb.context

16
node/.eslintrc.js Normal file
View File

@@ -0,0 +1,16 @@
module.exports = {
env: {
browser: true,
es2021: true
},
extends: 'standard-with-typescript',
overrides: [
],
parserOptions: {
project: './tsconfig.json',
ecmaVersion: 'latest',
sourceType: 'module'
},
rules: {
}
}

43
node/README.md Normal file
View File

@@ -0,0 +1,43 @@
# LanceDB
A JavaScript / Node.js library for [LanceDB](https://github.com/lancedb/lancedb).
## Installation
```bash
npm install vectordb
```
## Usage
### Basic Example
```javascript
const lancedb = require('vectordb');
const db = lancedb.connect('<PATH_TO_LANCEDB_DATASET>');
const table = await db.openTable('my_table');
const query = await table.search([0.1, 0.3]).setLimit(20).execute();
console.log(results);
```
The [examples](./examples) folder contains complete examples.
## Development
The LanceDB javascript is built with npm:
```bash
npm run tsc
```
Run the tests with
```bash
npm test
```
To run the linter and have it automatically fix all errors
```bash
npm run lint -- --fix
```

View File

@@ -0,0 +1,41 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
'use strict'
async function example () {
const lancedb = require('vectordb')
// You need to provide an OpenAI API key, here we read it from the OPENAI_API_KEY environment variable
const apiKey = process.env.OPENAI_API_KEY
// The embedding function will create embeddings for the 'text' column(text in this case)
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
const db = await lancedb.connect('data/sample-lancedb')
const data = [
{ id: 1, text: 'Black T-Shirt', price: 10 },
{ id: 2, text: 'Leather Jacket', price: 50 }
]
const table = await db.createTable('vectors', data, embedding)
console.log(await db.tableNames())
const results = await table
.search('keeps me warm')
.limit(1)
.execute()
console.log(results[0].text)
}
example().then(_ => { console.log('All done!') })

View File

@@ -0,0 +1,15 @@
{
"name": "vectordb-example-js-openai",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"author": "Lance Devs",
"license": "Apache-2.0",
"dependencies": {
"vectordb": "file:../..",
"openai": "^3.2.1"
}
}

36
node/examples/js/index.js Normal file
View File

@@ -0,0 +1,36 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
'use strict'
async function example () {
const lancedb = require('vectordb')
const db = await lancedb.connect('data/sample-lancedb')
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 }
]
const table = await db.createTable('vectors', data)
console.log(await db.tableNames())
const results = await table
.search([0.1, 0.3])
.limit(20)
.execute()
console.log(results)
}
example()

View File

@@ -0,0 +1,14 @@
{
"name": "vectordb-example-js",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"author": "Lance Devs",
"license": "Apache-2.0",
"dependencies": {
"vectordb": "file:../.."
}
}

View File

@@ -0,0 +1,22 @@
{
"name": "vectordb-example-ts",
"version": "1.0.0",
"description": "",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"scripts": {
"tsc": "tsc -b",
"build": "tsc"
},
"author": "Lance Devs",
"license": "Apache-2.0",
"devDependencies": {
"@types/node": "^18.16.2",
"ts-node": "^10.9.1",
"ts-node-dev": "^2.0.0",
"typescript": "*"
},
"dependencies": {
"vectordb": "file:../.."
}
}

View File

@@ -0,0 +1,35 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import * as vectordb from 'vectordb';
async function example () {
const db = await vectordb.connect('data/sample-lancedb')
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 }
]
const table = await db.createTable('vectors', data)
console.log(await db.tableNames())
const results = await table
.search([0.1, 0.3])
.limit(20)
.execute()
console.log(results)
}
example().then(_ => { console.log ("All done!") })

View File

@@ -0,0 +1,10 @@
{
"include": ["src/**/*.ts"],
"compilerOptions": {
"target": "es2016",
"module": "commonjs",
"declaration": true,
"outDir": "./dist",
"strict": true
}
}

8
node/gen_test_data.py Normal file
View File

@@ -0,0 +1,8 @@
import lancedb
uri = "sample-lancedb"
db = lancedb.connect(uri)
table = db.create_table("my_table",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])

40
node/native.js Normal file
View File

@@ -0,0 +1,40 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
let nativeLib;
function getPlatformLibrary() {
if (process.platform === "darwin" && process.arch == "arm64") {
return require('./aarch64-apple-darwin.node');
} else if (process.platform === "darwin" && process.arch == "x64") {
return require('./x86_64-apple-darwin.node');
} else if (process.platform === "linux" && process.arch == "x64") {
return require('./x86_64-unknown-linux-gnu.node');
} else {
throw new Error(`vectordb: unsupported platform ${process.platform}_${process.arch}. Please file a bug report at https://github.com/lancedb/lancedb/issues`)
}
}
try {
nativeLib = require('./index.node')
} catch (e) {
if (e.code === "MODULE_NOT_FOUND") {
nativeLib = getPlatformLibrary();
} else {
throw new Error('vectordb: failed to load native library. Please file a bug report at https://github.com/lancedb/lancedb/issues');
}
}
module.exports = nativeLib

7422
node/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

52
node/package.json Normal file
View File

@@ -0,0 +1,52 @@
{
"name": "vectordb",
"version": "0.1.3",
"description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"scripts": {
"tsc": "tsc -b",
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json-render-diagnostics",
"build-release": "npm run build -- --release",
"test": "mocha -recursive dist/test",
"lint": "eslint src --ext .js,.ts"
},
"repository": {
"type": "git",
"url": "https://github.com/lancedb/lancedb/node"
},
"keywords": [
"data-format",
"data-science",
"machine-learning",
"data-analytics"
],
"author": "Lance Devs",
"license": "Apache-2.0",
"devDependencies": {
"@types/chai": "^4.3.4",
"@types/mocha": "^10.0.1",
"@types/node": "^18.16.2",
"@types/sinon": "^10.0.15",
"@types/temp": "^0.9.1",
"@typescript-eslint/eslint-plugin": "^5.59.1",
"cargo-cp-artifact": "^0.1",
"chai": "^4.3.7",
"eslint": "^8.39.0",
"eslint-config-standard-with-typescript": "^34.0.1",
"eslint-plugin-import": "^2.27.5",
"eslint-plugin-n": "^15.7.0",
"eslint-plugin-promise": "^6.1.1",
"mocha": "^10.2.0",
"sinon": "^15.1.0",
"openai": "^3.2.1",
"temp": "^0.9.4",
"ts-node": "^10.9.1",
"ts-node-dev": "^2.0.0",
"typescript": "*"
},
"dependencies": {
"@apache-arrow/ts": "^12.0.0",
"apache-arrow": "^12.0.0"
}
}

85
node/src/arrow.ts Normal file
View File

@@ -0,0 +1,85 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import {
Field,
Float32,
List, type ListBuilder,
makeBuilder,
RecordBatchFileWriter,
Table, Utf8,
type Vector,
vectorFromArray
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table> {
if (data.length === 0) {
throw new Error('At least one record needs to be provided')
}
const columns = Object.keys(data[0])
const records: Record<string, Vector> = {}
for (const columnsKey of columns) {
if (columnsKey === 'vector') {
const listBuilder = newVectorListBuilder()
const vectorSize = (data[0].vector as any[]).length
for (const datum of data) {
if ((datum[columnsKey] as any[]).length !== vectorSize) {
throw new Error(`Invalid vector size, expected ${vectorSize}`)
}
listBuilder.append(datum[columnsKey])
}
records[columnsKey] = listBuilder.finish().toVector()
} else {
const values = []
for (const datum of data) {
values.push(datum[columnsKey])
}
if (columnsKey === embeddings?.sourceColumn) {
const vectors = await embeddings.embed(values as T[])
const listBuilder = newVectorListBuilder()
vectors.map(v => listBuilder.append(v))
records.vector = listBuilder.finish().toVector()
}
if (typeof values[0] === 'string') {
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
records[columnsKey] = vectorFromArray(values, new Utf8())
} else {
records[columnsKey] = vectorFromArray(values)
}
}
}
return new Table(records)
}
// Creates a new Arrow ListBuilder that stores a Vector column
function newVectorListBuilder (): ListBuilder<Float32, any> {
const children = new Field<Float32>('item', new Float32())
const list = new List(children)
return makeBuilder({
type: list
})
}
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = await convertToTable(data, embeddings)
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}

View File

@@ -0,0 +1,28 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* An embedding function that automatically creates vector representation for a given column.
*/
export interface EmbeddingFunction<T> {
/**
* The name of the column that will be used as input for the Embedding Function.
*/
sourceColumn: string
/**
* Creates a vector representation for the given values.
*/
embed: (data: T[]) => Promise<number[][]>
}

View File

@@ -0,0 +1,51 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { type EmbeddingFunction } from '../index'
export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
private readonly _openai: any
private readonly _modelName: string
constructor (sourceColumn: string, openAIKey: string, modelName: string = 'text-embedding-ada-002') {
let openai
try {
// eslint-disable-next-line @typescript-eslint/no-var-requires
openai = require('openai')
} catch {
throw new Error('please install openai using npm install openai')
}
this.sourceColumn = sourceColumn
const configuration = new openai.Configuration({
apiKey: openAIKey
})
this._openai = new openai.OpenAIApi(configuration)
this._modelName = modelName
}
async embed (data: string[]): Promise<number[][]> {
const response = await this._openai.createEmbedding({
model: this._modelName,
input: data
})
const embeddings: number[][] = []
for (let i = 0; i < response.data.data.length; i++) {
embeddings.push(response.data.data[i].embedding as number[])
}
return embeddings
}
sourceColumn: string
}

342
node/src/index.ts Normal file
View File

@@ -0,0 +1,342 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import {
RecordBatchFileWriter,
type Table as ArrowTable,
tableFromIPC,
Vector
} from 'apache-arrow'
import { fromRecordsToBuffer } from './arrow'
import type { EmbeddingFunction } from './embedding/embedding_function'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd, tableCreateVectorIndex } = require('../native.js')
export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai'
/**
* Connect to a LanceDB instance at the given URI
* @param uri The uri of the database.
*/
export async function connect (uri: string): Promise<Connection> {
const db = await databaseNew(uri)
return new Connection(db, uri)
}
/**
* A connection to a LanceDB database.
*/
export class Connection {
private readonly _uri: string
private readonly _db: any
constructor (db: any, uri: string) {
this._uri = uri
this._db = db
}
get uri (): string {
return this._uri
}
/**
* Get the names of all tables in the database.
*/
async tableNames (): Promise<string[]> {
return databaseTableNames.call(this._db)
}
/**
* Open a table in the database.
*
* @param name The name of the table.
*/
async openTable (name: string): Promise<Table>
/**
* Open a table in the database.
*
* @param name The name of the table.
* @param embeddings An embedding function to use on this Table
*/
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await databaseOpenTable.call(this._db, name)
if (embeddings !== undefined) {
return new Table(tbl, name, embeddings)
} else {
return new Table(tbl, name)
}
}
/**
* Creates a new Table and initialize it with new data.
*
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
*/
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table>
/**
* Creates a new Table and initialize it with new data.
*
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
* @param embeddings An embedding function to use on this Table
*/
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings))
if (embeddings !== undefined) {
return new Table(tbl, name, embeddings)
} else {
return new Table(tbl, name)
}
}
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
const writer = RecordBatchFileWriter.writeAll(table)
await tableCreate.call(this._db, name, Buffer.from(await writer.toUint8Array()))
return await this.openTable(name)
}
}
export class Table<T = number[]> {
private readonly _tbl: any
private readonly _name: string
private readonly _embeddings?: EmbeddingFunction<T>
constructor (tbl: any, name: string)
/**
* @param tbl
* @param name
* @param embeddings An embedding function to use when interacting with this table
*/
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
this._name = name
this._embeddings = embeddings
}
get name (): string {
return this._name
}
/**
* Creates a search query to find the nearest neighbors of the given search term
* @param query The query search term
*/
search (query: T): Query<T> {
return new Query(this._tbl, query, this._embeddings)
}
/**
* Insert records into this Table.
*
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
}
/**
* Insert records into this Table, replacing its contents.
*
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
}
/**
* Create an ANN index on this Table vector index.
*
* @param indexParams The parameters of this Index, @see VectorIndexParams.
*/
async create_index (indexParams: VectorIndexParams): Promise<any> {
return tableCreateVectorIndex.call(this._tbl, indexParams)
}
}
interface IvfPQIndexConfig {
/**
* The column to be indexed
*/
column?: string
/**
* A unique name for the index
*/
index_name?: string
/**
* Metric type, L2 or Cosine
*/
metric_type?: MetricType
/**
* The number of partitions this index
*/
num_partitions?: number
/**
* The max number of iterations for kmeans training.
*/
max_iters?: number
/**
* Train as optimized product quantization.
*/
use_opq?: boolean
/**
* Number of subvectors to build PQ code
*/
num_sub_vectors?: number
/**
* The number of bits to present one PQ centroid.
*/
num_bits?: number
/**
* Max number of iterations to train OPQ, if `use_opq` is true.
*/
max_opq_iters?: number
type: 'ivf_pq'
}
export type VectorIndexParams = IvfPQIndexConfig
/**
* A builder for nearest neighbor queries for LanceDB.
*/
export class Query<T = number[]> {
private readonly _tbl: any
private readonly _query: T
private _queryVector?: number[]
private _limit: number
private _refineFactor?: number
private _nprobes: number
private readonly _columns?: string[]
private _filter?: string
private _metricType?: MetricType
private readonly _embeddings?: EmbeddingFunction<T>
constructor (tbl: any, query: T, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
this._query = query
this._limit = 10
this._nprobes = 20
this._refineFactor = undefined
this._columns = undefined
this._filter = undefined
this._metricType = undefined
this._embeddings = embeddings
}
/***
* Sets the number of results that will be returned
* @param value number of results
*/
limit (value: number): Query<T> {
this._limit = value
return this
}
/**
* Refine the results by reading extra elements and re-ranking them in memory.
* @param value refine factor to use in this query.
*/
refineFactor (value: number): Query<T> {
this._refineFactor = value
return this
}
/**
* The number of probes used. A higher number makes search more accurate but also slower.
* @param value The number of probes used.
*/
nprobes (value: number): Query<T> {
this._nprobes = value
return this
}
/**
* A filter statement to be applied to this query.
* @param value A filter in the same format used by a sql WHERE clause.
*/
filter (value: string): Query<T> {
this._filter = value
return this
}
/**
* The MetricType used for this Query.
* @param value The metric to the. @see MetricType for the different options
*/
metricType (value: MetricType): Query<T> {
this._metricType = value
return this
}
/**
* Execute the query and return the results as an Array of Objects
*/
async execute<T = Record<string, unknown>> (): Promise<T[]> {
if (this._embeddings !== undefined) {
this._queryVector = (await this._embeddings.embed([this._query]))[0]
} else {
this._queryVector = this._query as number[]
}
const buffer = await tableSearch.call(this._tbl, this)
const data = tableFromIPC(buffer)
return data.toArray().map((entry: Record<string, unknown>) => {
const newObject: Record<string, unknown> = {}
Object.keys(entry).forEach((key: string) => {
if (entry[key] instanceof Vector) {
newObject[key] = (entry[key] as Vector).toArray()
} else {
newObject[key] = entry[key]
}
})
return newObject as unknown as T
})
}
}
export enum WriteMode {
Overwrite = 'overwrite',
Append = 'append'
}
/**
* Distance metrics type.
*/
export enum MetricType {
/**
* Euclidean distance
*/
L2 = 'l2',
/**
* Cosine distance
*/
Cosine = 'cosine'
}

View File

@@ -0,0 +1,50 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { describe } from 'mocha'
import { assert } from 'chai'
import { OpenAIEmbeddingFunction } from '../../embedding/openai'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { OpenAIApi } = require('openai')
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { stub } = require('sinon')
describe('OpenAPIEmbeddings', function () {
const stubValue = {
data: {
data: [
{
embedding: Array(1536).fill(1.0)
},
{
embedding: Array(1536).fill(2.0)
}
]
}
}
describe('#embed', function () {
it('should create vector embeddings', async function () {
const openAIStub = stub(OpenAIApi.prototype, 'createEmbedding').returns(stubValue)
const f = new OpenAIEmbeddingFunction('text', 'sk-key')
const vectors = await f.embed(['abc', 'def'])
assert.isTrue(openAIStub.calledOnce)
assert.equal(vectors.length, 2)
assert.deepEqual(vectors[0], stubValue.data.data[0].embedding)
assert.deepEqual(vectors[1], stubValue.data.data[1].embedding)
})
})
})

52
node/src/test/io.ts Normal file
View File

@@ -0,0 +1,52 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// IO tests
import { describe } from 'mocha'
import { assert } from 'chai'
import * as lancedb from '../index'
describe('LanceDB S3 client', function () {
if (process.env.TEST_S3_BASE_URL != null) {
const baseUri = process.env.TEST_S3_BASE_URL
it('should have a valid url', async function () {
const uri = `${baseUri}/valid_url`
const table = await createTestDB(uri, 2, 20)
const con = await lancedb.connect(uri)
assert.equal(con.uri, uri)
const results = await table.search([0.1, 0.3]).limit(5).execute()
assert.equal(results.length, 5)
})
} else {
describe.skip('Skip S3 test', function () {})
}
})
async function createTestDB (uri: string, numDimensions: number = 2, numRows: number = 2): Promise<lancedb.Table> {
const con = await lancedb.connect(uri)
const data = []
for (let i = 0; i < numRows; i++) {
const vector = []
for (let j = 0; j < numDimensions; j++) {
vector.push(i + (j * 0.1))
}
data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector })
}
return await con.createTable('vectors', data)
}

207
node/src/test/test.ts Normal file
View File

@@ -0,0 +1,207 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { describe } from 'mocha'
import { assert } from 'chai'
import { track } from 'temp'
import * as lancedb from '../index'
import { type EmbeddingFunction, MetricType, Query } from '../index'
describe('LanceDB client', function () {
describe('when creating a connection to lancedb', function () {
it('should have a valid url', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
assert.equal(con.uri, uri)
})
it('should return the existing table names', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
assert.deepEqual(await con.tableNames(), ['vectors'])
})
})
describe('when querying an existing dataset', function () {
it('should open a table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
assert.equal(table.name, 'vectors')
})
it('execute a query', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 2)
assert.equal(results[0].price, 10)
const vector = results[0].vector as Float32Array
assert.approximately(vector[0], 0.0, 0.2)
assert.approximately(vector[0], 0.1, 0.3)
})
it('limits # of results', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.3]).limit(1).execute()
assert.equal(results.length, 1)
assert.equal(results[0].id, 1)
})
it('uses a filter', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.1]).filter('id == 2').execute()
assert.equal(results.length, 1)
assert.equal(results[0].id, 2)
})
})
describe('when creating a new dataset', function () {
it('creates a new table from javascript objects', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10 },
{ id: 2, vector: [1.1, 1.2], price: 50 }
]
const tableName = `vectors_${Math.floor(Math.random() * 100)}`
const table = await con.createTable(tableName, data)
assert.equal(table.name, tableName)
const results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 2)
})
it('appends records to an existing table ', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10, name: 'a' },
{ id: 2, vector: [1.1, 1.2], price: 50, name: 'b' }
]
const table = await con.createTable('vectors', data)
const results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 2)
const dataAdd = [
{ id: 3, vector: [2.1, 2.2], price: 10, name: 'c' },
{ id: 4, vector: [3.1, 3.2], price: 50, name: 'd' }
]
await table.add(dataAdd)
const resultsAdd = await table.search([0.1, 0.3]).execute()
assert.equal(resultsAdd.length, 4)
})
it('overwrite all records in a table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 2)
const dataOver = [
{ vector: [2.1, 2.2], price: 10, name: 'foo' },
{ vector: [3.1, 3.2], price: 50, name: 'bar' }
]
await table.overwrite(dataOver)
const resultsAdd = await table.search([0.1, 0.3]).execute()
assert.equal(resultsAdd.length, 2)
})
})
describe('when creating a vector index', function () {
it('overwrite all records in a table', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
await table.create_index({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2 })
}).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow
})
describe('when using a custom embedding function', function () {
class TextEmbedding implements EmbeddingFunction<string> {
sourceColumn: string
constructor (targetColumn: string) {
this.sourceColumn = targetColumn
}
_embedding_map = new Map<string, number[]>([
['foo', [2.1, 2.2]],
['bar', [3.1, 3.2]]
])
async embed (data: string[]): Promise<number[][]> {
return data.map(datum => this._embedding_map.get(datum) ?? [0.0, 0.0])
}
}
it('should encode the original data into embeddings', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const embeddings = new TextEmbedding('name')
const data = [
{ price: 10, name: 'foo' },
{ price: 50, name: 'bar' }
]
const table = await con.createTable('vectors', data, embeddings)
const results = await table.search('foo').execute()
assert.equal(results.length, 2)
})
})
})
describe('Query object', function () {
it('sets custom parameters', async function () {
const query = new Query(undefined, [0.1, 0.3])
.limit(1)
.metricType(MetricType.Cosine)
.refineFactor(100)
.nprobes(20) as Record<string, any>
assert.equal(query._limit, 1)
assert.equal(query._metricType, MetricType.Cosine)
assert.equal(query._refineFactor, 100)
assert.equal(query._nprobes, 20)
})
})
async function createTestDB (numDimensions: number = 2, numRows: number = 2): Promise<string> {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = []
for (let i = 0; i < numRows; i++) {
const vector = []
for (let j = 0; j < numDimensions; j++) {
vector.push(i + (j * 0.1))
}
data.push({ id: i + 1, name: `name_${i}`, price: i + 10, is_active: (i % 2 === 0), vector })
}
await con.createTable('vectors', data)
return dir
}

10
node/tsconfig.json Normal file
View File

@@ -0,0 +1,10 @@
{
"include": ["src/**/*.ts"],
"compilerOptions": {
"target": "es2016",
"module": "commonjs",
"declaration": true,
"outDir": "./dist",
"strict": true
}
}

357
notebooks/code_qa_bot.ipynb Normal file
View File

@@ -0,0 +1,357 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "13cb272e",
"metadata": {},
"source": [
"# Code documentation Q&A bot example with LangChain\n",
"\n",
"This Q&A bot will allow you to query your own documentation easily using questions. We'll also demonstrate the use of LangChain and LanceDB using the OpenAI API. \n",
"\n",
"In this example we'll use Pandas 2.0 documentation, but, this could be replaced for your own docs as well"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "66638d6c",
"metadata": {},
"outputs": [],
"source": [
"!pip install --quiet openai langchain\n",
"!pip install --quiet -U lancedb"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "d1cdcac3",
"metadata": {},
"source": [
"First, let's get some setup out of the way. As we're using the OpenAI API, ensure that you've set your key (and organization if needed):"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "58ee1868",
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import os\n",
"\n",
"# Configuring the environment variable OPENAI_API_KEY\n",
"if \"OPENAI_API_KEY\" not in os.environ:\n",
" # OR set the key here as a variable\n",
" openai.api_key = \"sk-...\"\n",
" \n",
"assert len(openai.Model.list()[\"data\"]) > 0"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "34f524d3",
"metadata": {},
"source": [
"# Loading in our code documentation, generating embeddings and storing our documents in LanceDB\n",
"\n",
"We're going to use the power of LangChain to help us create our Q&A bot. It comes with several APIs that can make our development much easier as well as a LanceDB integration for vectorstore."
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "b55d22f1",
"metadata": {},
"outputs": [],
"source": [
"import lancedb\n",
"import re\n",
"import pickle\n",
"from pathlib import Path\n",
"\n",
"from langchain.document_loaders import UnstructuredHTMLLoader\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain.vectorstores import LanceDB\n",
"from langchain.llms import OpenAI\n",
"from langchain.chains import RetrievalQA"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6ccf9b2b",
"metadata": {},
"source": [
"You can download the Pandas documentation from https://pandas.pydata.org/docs/. To make sure we're not littering our repo with docs, we won't include it in the LanceDB repo, so download this and store it locally first."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ae42496c",
"metadata": {},
"source": [
"We'll create a simple helper function that can help to extract metadata, so we can use this downstream when we're wanting to query with filters. In this case, we want to keep the lineage of the uri or path for each document that we process:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "d171d062",
"metadata": {},
"outputs": [],
"source": [
"def get_document_title(document):\n",
" m = str(document.metadata[\"source\"])\n",
" title = re.findall(\"pandas.documentation(.*).html\", m)\n",
" if title[0] is not None:\n",
" return(title[0])\n",
" return ''"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "130162ad",
"metadata": {},
"source": [
"# Pre-processing and loading the documentation\n",
"\n",
"Next, let's pre-process and load the documentation. To make sure we don't need to do this repeatedly if we were updating code, we're caching it using pickle so we can retrieve it again (this could take a few minutes to run the first time yyou do it). We'll also add some more metadata to the docs here such as the title and version of the code:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "33bfe7d8",
"metadata": {},
"outputs": [],
"source": [
"docs_path = Path(\"docs.pkl\")\n",
"docs = []\n",
"\n",
"if not docs_path.exists():\n",
" for p in Path(\"./pandas.documentation\").rglob(\"*.html\"):\n",
" if p.is_dir():\n",
" continue\n",
" loader = UnstructuredHTMLLoader(p)\n",
" raw_document = loader.load()\n",
" \n",
" m = {}\n",
" m[\"title\"] = get_document_title(raw_document[0])\n",
" m[\"version\"] = \"2.0rc0\"\n",
" raw_document[0].metadata = raw_document[0].metadata | m\n",
" raw_document[0].metadata[\"source\"] = str(raw_document[0].metadata[\"source\"])\n",
" docs = docs + raw_document\n",
"\n",
" with docs_path.open(\"wb\") as fh:\n",
" pickle.dump(docs, fh)\n",
"else:\n",
" with docs_path.open(\"rb\") as fh:\n",
" docs = pickle.load(fh)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "c3852dd3",
"metadata": {},
"source": [
"# Generating emebeddings from our docs\n",
"\n",
"Now that we have our raw documents loaded, we need to pre-process them to generate embeddings:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "82230563",
"metadata": {},
"outputs": [],
"source": [
"text_splitter = RecursiveCharacterTextSplitter(\n",
" chunk_size=1000,\n",
" chunk_overlap=200,\n",
")\n",
"documents = text_splitter.split_documents(docs)\n",
"embeddings = OpenAIEmbeddings()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "43e68215",
"metadata": {},
"source": [
"# Storing and querying with LanceDB\n",
"\n",
"Let's connect to LanceDB so we can store our documents. We'll create a Table to store them in:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "74780a58",
"metadata": {},
"outputs": [],
"source": [
"db = lancedb.connect('/tmp/lancedb')\n",
"table = db.create_table(\"pandas_docs\", data=[\n",
" {\"vector\": embeddings.embed_query(\"Hello World\"), \"text\": \"Hello World\", \"id\": \"1\"}\n",
"], mode=\"overwrite\")\n",
"docsearch = LanceDB.from_documents(documents, embeddings, connection=table)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3cb1dc5d",
"metadata": {},
"source": [
"Now let's create our RetrievalQA chain using the LanceDB vector store:"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "6a5891ad",
"metadata": {},
"outputs": [],
"source": [
"qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type=\"stuff\", retriever=docsearch.as_retriever())"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "28d93b85",
"metadata": {},
"source": [
"And thats it! We're all setup. The next step is to run some queries, let's try a few:"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "70d88316",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' The major differences in pandas 2.0 include installing optional dependencies with pip extras, the ability to use any numpy numeric dtype in an Index, and enhancements, notable bug fixes, backwards incompatible API changes, deprecations, and performance improvements.'"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"What are the major differences in pandas 2.0?\"\n",
"qa.run(query)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "85a0397c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' 2.0.0rc0'"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"What's the current version of pandas?\"\n",
"qa.run(query)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "923f86c6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' Optional dependencies can be installed with pip install \"pandas[all]\" or \"pandas[performance]\". This will install all recommended performance dependencies such as numexpr, bottleneck and numba.'"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"How do I make use of installing optional dependencies?\"\n",
"qa.run(query)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "02082f83",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" \\n\\nPandas 2.0 includes a number of API breaking changes, such as increased minimum versions for dependencies, the use of os.linesep for DataFrame.to_csv's line_terminator, and reorganization of the library. See the release notes for a full list of changes.\""
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"query = \"What are the backwards incompatible API changes in Pandas 2.0?\"\n",
"qa.run(query)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75cea547",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

108
notebooks/diffusiondb/datagen.py Executable file
View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python
#
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset hf://poloclub/diffusiondb
"""
import io
from argparse import ArgumentParser
from multiprocessing import Pool
import lance
import lancedb
import pyarrow as pa
from datasets import load_dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
MODEL_ID = "openai/clip-vit-base-patch32"
device = "cuda"
tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
schema = pa.schema(
[
pa.field("prompt", pa.string()),
pa.field("seed", pa.uint32()),
pa.field("step", pa.uint16()),
pa.field("cfg", pa.float32()),
pa.field("sampler", pa.string()),
pa.field("width", pa.uint16()),
pa.field("height", pa.uint16()),
pa.field("timestamp", pa.timestamp("s")),
pa.field("image_nsfw", pa.float32()),
pa.field("prompt_nsfw", pa.float32()),
pa.field("vector", pa.list_(pa.float32(), 512)),
pa.field("image", pa.binary()),
]
)
def pil_to_bytes(img) -> list[bytes]:
buf = io.BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
def generate_clip_embeddings(batch) -> pa.RecordBatch:
image = processor(text=None, images=batch["image"], return_tensors="pt")[
"pixel_values"
].to(device)
img_emb = model.get_image_features(image)
batch["vector"] = img_emb.cpu().tolist()
with Pool() as p:
batch["image_bytes"] = p.map(pil_to_bytes, batch["image"])
return batch
def datagen(args):
"""Generate DiffusionDB dataset, and use CLIP model to generate image embeddings."""
dataset = load_dataset("poloclub/diffusiondb", args.subset)
data = []
for b in dataset.map(
generate_clip_embeddings, batched=True, batch_size=256, remove_columns=["image"]
)["train"]:
b["image"] = b["image_bytes"]
del b["image_bytes"]
data.append(b)
tbl = pa.Table.from_pylist(data, schema=schema)
return tbl
def main():
parser = ArgumentParser()
parser.add_argument(
"-o", "--output", metavar="DIR", help="Output lance directory", required=True
)
parser.add_argument(
"-s",
"--subset",
choices=["2m_all", "2m_first_10k", "2m_first_100k"],
default="2m_first_10k",
help="subset of the hg dataset",
)
args = parser.parse_args()
batches = datagen(args)
lance.write_dataset(batches, args.output)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,9 @@
datasets
Pillow
lancedb
isort
black
transformers
--index-url https://download.pytorch.org/whl/cu118
torch
torchvision

View File

@@ -0,0 +1,240 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install --quiet -U lancedb\n",
"!pip install --quiet gradio transformers torch torchvision"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import PIL\n",
"import duckdb\n",
"import lancedb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## First run setup: Download data and pre-process"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<lance.dataset.LanceDataset at 0x3045db590>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# remove null prompts\n",
"import lance\n",
"import pyarrow.compute as pc\n",
"\n",
"# download s3://eto-public/datasets/diffusiondb/small_10k.lance to this uri\n",
"data = lance.dataset(\"~/datasets/rawdata.lance\").to_table()\n",
"\n",
"# First data processing and full-text-search index\n",
"db = lancedb.connect(\"~/datasets/demo\")\n",
"tbl = db.create_table(\"diffusiondb\", data.filter(~pc.field(\"prompt\").is_null()))\n",
"tbl = tbl.create_fts_index([\"prompt\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create / Open LanceDB Table"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"db = lancedb.connect(\"~/datasets/demo\")\n",
"tbl = db.open_table(\"diffusiondb\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create CLIP embedding function for the text"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast\n",
"\n",
"MODEL_ID = \"openai/clip-vit-base-patch32\"\n",
"\n",
"tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)\n",
"model = CLIPModel.from_pretrained(MODEL_ID)\n",
"processor = CLIPProcessor.from_pretrained(MODEL_ID)\n",
"\n",
"def embed_func(query):\n",
" inputs = tokenizer([query], padding=True, return_tensors=\"pt\")\n",
" text_features = model.get_text_features(**inputs)\n",
" return text_features.detach().numpy()[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Search functions for Gradio"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"def find_image_vectors(query):\n",
" emb = embed_func(query)\n",
" return _extract(tbl.search(emb).limit(9).to_df())\n",
"\n",
"def find_image_keywords(query):\n",
" return _extract(tbl.search(query).limit(9).to_df())\n",
"\n",
"def find_image_sql(query):\n",
" diffusiondb = tbl.to_lance()\n",
" return _extract(duckdb.query(query).to_df())\n",
"\n",
"def _extract(df):\n",
" image_col = \"image\"\n",
" return [(PIL.Image.open(io.BytesIO(row[image_col])), row[\"prompt\"]) for _, row in df.iterrows()]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Gradio interface"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running on local URL: http://127.0.0.1:7867\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n"
]
},
{
"data": {
"text/html": [
"<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import gradio as gr\n",
"\n",
"\n",
"with gr.Blocks() as demo:\n",
"\n",
" with gr.Row():\n",
" with gr.Tab(\"Embeddings\"):\n",
" vector_query = gr.Textbox(value=\"portraits of a person\", show_label=False)\n",
" b1 = gr.Button(\"Submit\")\n",
" with gr.Tab(\"Keywords\"):\n",
" keyword_query = gr.Textbox(value=\"ninja turtle\", show_label=False)\n",
" b2 = gr.Button(\"Submit\")\n",
" with gr.Tab(\"SQL\"):\n",
" sql_query = gr.Textbox(value=\"SELECT * from diffusiondb WHERE image_nsfw >= 2 LIMIT 9\", show_label=False)\n",
" b3 = gr.Button(\"Submit\")\n",
" with gr.Row():\n",
" gallery = gr.Gallery(\n",
" label=\"Found images\", show_label=False, elem_id=\"gallery\"\n",
" ).style(columns=[3], rows=[3], object_fit=\"contain\", height=\"auto\") \n",
" \n",
" b1.click(find_image_vectors, inputs=vector_query, outputs=gallery)\n",
" b2.click(find_image_keywords, inputs=keyword_query, outputs=gallery)\n",
" b3.click(find_image_sql, inputs=sql_query, outputs=gallery)\n",
" \n",
"demo.launch()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

View File

@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .db import LanceDBConnection, URI
from .db import URI, LanceDBConnection
def connect(uri: URI) -> LanceDBConnection:

View File

@@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Union, List
from typing import List, Union
import numpy as np
import pandas as pd

View File

@@ -13,11 +13,16 @@
from __future__ import annotations
import os
from pathlib import Path
import pyarrow as pa
import os
from .common import URI, DATA
import pyarrow as pa
from pyarrow import fs
from .common import DATA, URI
from .table import LanceTable
from .util import get_uri_scheme, get_uri_location
class LanceDBConnection:
@@ -26,6 +31,8 @@ class LanceDBConnection:
"""
def __init__(self, uri: URI):
is_local = isinstance(uri, Path) or get_uri_scheme(uri) == "file"
if is_local:
if isinstance(uri, str):
uri = Path(uri)
uri = uri.expanduser().absolute()
@@ -43,7 +50,20 @@ class LanceDBConnection:
-------
A list of table names.
"""
return [p.stem for p in Path(self.uri).glob("*.lance")]
try:
filesystem, path = fs.FileSystem.from_uri(self.uri)
except pa.ArrowInvalid:
raise NotImplementedError(
"Unsupported scheme: " + self.uri
)
try:
paths = filesystem.get_file_info(fs.FileSelector(get_uri_location(self.uri)))
except FileNotFoundError:
# It is ok if the file does not exist since it will be created
paths = []
tables = [os.path.splitext(file_info.base_name)[0] for file_info in paths if file_info.extension == 'lance']
return tables
def __len__(self) -> int:
return len(self.table_names())
@@ -104,3 +124,15 @@ class LanceDBConnection:
A LanceTable object representing the table.
"""
return LanceTable(self, name)
def drop_table(self, name: str):
"""Drop a table from the database.
Parameters
----------
name: str
The name of the table.
"""
filesystem, path = pa.fs.FileSystem.from_uri(self.uri)
table_path = os.path.join(path, name + ".lance")
filesystem.delete_dir(table_path)

View File

@@ -13,14 +13,13 @@
import math
import sys
from retry import retry
from typing import Callable, Union
from lance.vector import vec_to_table
import numpy as np
import pandas as pd
import pyarrow as pa
from lance.vector import vec_to_table
from retry import retry
def with_embeddings(
@@ -68,7 +67,9 @@ class EmbeddingFunction:
if len(self.rate_limiter_kwargs) > 0:
v = int(sys.version_info.minor)
if v >= 11:
print("WARNING: rate limit only support up to 3.10, proceeding without rate limiter")
print(
"WARNING: rate limit only support up to 3.10, proceeding without rate limiter"
)
else:
import ratelimiter

128
python/lancedb/fts.py Normal file
View File

@@ -0,0 +1,128 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Full text search index using tantivy-py"""
import os
from typing import List, Tuple
import pyarrow as pa
try:
import tantivy
except ImportError:
raise ImportError(
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
)
from .table import LanceTable
def create_index(index_path: str, text_fields: List[str]) -> tantivy.Index:
"""
Create a new Index (not populated)
Parameters
----------
index_path : str
Path to the index directory
text_fields : List[str]
List of text fields to index
Returns
-------
index : tantivy.Index
The index object (not yet populated)
"""
# Declaring our schema.
schema_builder = tantivy.SchemaBuilder()
# special field that we'll populate with row_id
schema_builder.add_integer_field("doc_id", stored=True)
# data fields
for name in text_fields:
schema_builder.add_text_field(name, stored=True)
schema = schema_builder.build()
os.makedirs(index_path, exist_ok=True)
index = tantivy.Index(schema, path=index_path)
return index
def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -> int:
"""
Populate an index with data from a LanceTable
Parameters
----------
index : tantivy.Index
The index object
table : LanceTable
The table to index
fields : List[str]
List of fields to index
"""
# first check the fields exist and are string or large string type
for name in fields:
f = table.schema.field(name) # raises KeyError if not found
if not pa.types.is_string(f.type) and not pa.types.is_large_string(f.type):
raise TypeError(f"Field {name} is not a string type")
# create a tantivy writer
writer = index.writer()
# write data into index
dataset = table.to_lance()
row_id = 0
for b in dataset.to_batches(columns=fields):
for i in range(b.num_rows):
doc = tantivy.Document()
doc.add_integer("doc_id", row_id)
for name in fields:
doc.add_text(name, b[name][i].as_py())
writer.add_document(doc)
row_id += 1
# commit changes
writer.commit()
return row_id
def search_index(
index: tantivy.Index, query: str, limit: int = 10
) -> Tuple[Tuple[int], Tuple[float]]:
"""
Search an index for a query
Parameters
----------
index : tantivy.Index
The index object
query : str
The query string
limit : int
The maximum number of results to return
Returns
-------
ids_and_score: list[tuple[int], tuple[float]]
A tuple of two tuples, the first containing the document ids
and the second containing the scores
"""
searcher = index.searcher()
query = index.parse_query(query)
# get top results
results = searcher.search(query, limit)
return tuple(
zip(
*[
(searcher.doc(doc_address)["doc_id"][0], score)
for score, doc_address in results.hits
]
)
)

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
import numpy as np
import pandas as pd
import pyarrow as pa
from .common import VECTOR_COLUMN_NAME
@@ -131,7 +132,6 @@ class LanceQueryBuilder:
vector and the returned vector.
"""
ds = self._table.to_lance()
# TODO indexed search
tbl = ds.to_table(
columns=self._columns,
filter=self._where,
@@ -145,3 +145,26 @@ class LanceQueryBuilder:
},
)
return tbl.to_pandas()
class LanceFtsQueryBuilder(LanceQueryBuilder):
def to_df(self) -> pd.DataFrame:
try:
import tantivy
except ImportError:
raise ImportError(
"Please install tantivy-py `pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985` to use the full text search feature."
)
from .fts import search_index
# get the index path
index_path = self._table._get_fts_index_path()
# open the index
index = tantivy.Index.open(index_path)
# get the scores and doc ids
row_ids, scores = search_index(index, self._query, self._limit)
scores = pa.array(scores)
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores)
return output_tbl.to_pandas()

View File

@@ -14,17 +14,20 @@
from __future__ import annotations
import os
import shutil
from functools import cached_property
from typing import List, Union
import lance
import numpy as np
import pandas as pd
from lance import LanceDataset
import pyarrow as pa
from lance import LanceDataset
from lance.vector import vec_to_table
from .query import LanceQueryBuilder
from .common import DATA, VECTOR_COLUMN_NAME, VEC
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .query import LanceFtsQueryBuilder, LanceQueryBuilder
from .util import get_uri_scheme
def _sanitize_data(data, schema):
@@ -130,6 +133,27 @@ class LanceTable:
)
self._reset_dataset()
def create_fts_index(self, field_names: Union[str, List[str]]):
"""Create a full-text search index on the table.
Warning - this API is highly experimental and is highly likely to change
in the future.
Parameters
----------
field_names: str or list of str
The name(s) of the field to index.
"""
from .fts import create_index, populate_index
if isinstance(field_names, str):
field_names = [field_names]
index = create_index(self._get_fts_index_path(), field_names)
populate_index(index, self, field_names)
def _get_fts_index_path(self):
return os.path.join(self._dataset_uri, "_indices", "tantivy")
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri, version=self._version)
@@ -158,7 +182,7 @@ class LanceTable:
self._reset_dataset()
return len(self)
def search(self, query: VEC) -> LanceQueryBuilder:
def search(self, query: Union[VEC, str]) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
@@ -174,6 +198,10 @@ class LanceTable:
and also the "score" column which is the distance between the query
vector and the returned vector.
"""
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(self, query)
if isinstance(query, list):
query = np.array(query)
if isinstance(query, np.ndarray):
@@ -225,8 +253,7 @@ def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table
vector_column_name: str
The name of the vector column.
"""
i = data.column_names.index(vector_column_name)
if i < 0:
if vector_column_name not in data.column_names:
raise ValueError(f"Missing vector column: {vector_column_name}")
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_fixed_size_list(vec_arr.type):
@@ -238,4 +265,4 @@ def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table
values = values.cast(pa.float32())
list_size = len(values) / len(data)
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
return data.set_column(i, vector_column_name, vec_arr)
return data.set_column(data.column_names.index(vector_column_name), vector_column_name, vec_arr)

63
python/lancedb/util.py Normal file
View File

@@ -0,0 +1,63 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from urllib.parse import ParseResult, urlparse
from pyarrow import fs
def get_uri_scheme(uri: str) -> str:
"""
Get the scheme of a URI. If the URI does not have a scheme, assume it is a file URI.
Parameters
----------
uri : str
The URI to parse.
Returns
-------
str: The scheme of the URI.
"""
parsed = urlparse(uri)
scheme = parsed.scheme
if not scheme:
scheme = "file"
elif scheme in ["s3a", "s3n"]:
scheme = "s3"
elif len(scheme) == 1:
# Windows drive names are parsed as the scheme
# e.g. "c:\path" -> ParseResult(scheme="c", netloc="", path="/path", ...)
# So we add special handling here for schemes that are a single character
scheme = "file"
return scheme
def get_uri_location(uri: str) -> str:
"""
Get the location of a URI. If the parameter is not a url, assumes it is just a path
Parameters
----------
uri : str
The URI to parse.
Returns
-------
str: Location part of the URL, without scheme
"""
parsed = urlparse(uri)
if not parsed.netloc:
return parsed.path
else:
return parsed.netloc + parsed.path

View File

@@ -1,10 +1,10 @@
[project]
name = "lancedb"
version = "0.1.1"
dependencies = ["pylance>=0.4.4", "ratelimiter", "retry", "tqdm"]
version = "0.1.5"
dependencies = ["pylance>=0.4.17", "ratelimiter", "retry", "tqdm"]
description = "lancedb"
authors = [
{ name = "Lance Devs", email = "dev@eto.ai" },
{ name = "LanceDB Devs", email = "dev@lancedb.com" },
]
license = { file = "LICENSE" }
readme = "README.md"

View File

@@ -11,10 +11,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import lancedb
import pandas as pd
import pytest
import lancedb
def test_basic(tmp_path):
db = lancedb.connect(tmp_path)
@@ -96,3 +97,26 @@ def test_create_mode(tmp_path):
)
tbl = db.create_table("test", data=new_data, mode="overwrite")
assert tbl.to_pandas().item.tolist() == ["fizz", "buzz"]
def test_delete_table(tmp_path):
db = lancedb.connect(tmp_path)
data = pd.DataFrame(
{
"vector": [[3.1, 4.1], [5.9, 26.5]],
"item": ["foo", "bar"],
"price": [10.0, 20.0],
}
)
db.create_table("test", data=data)
with pytest.raises(Exception):
db.create_table("test", data=data)
assert db.table_names() == ["test"]
db.drop_table("test")
assert db.table_names() == []
db.create_table("test", data=data)
assert db.table_names() == ["test"]

View File

@@ -14,7 +14,6 @@ import sys
import numpy as np
import pyarrow as pa
from lancedb.embeddings import with_embeddings

84
python/tests/test_fts.py Normal file
View File

@@ -0,0 +1,84 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import lancedb.fts
import numpy as np
import pandas as pd
import pytest
import tantivy
import lancedb as ldb
@pytest.fixture
def table(tmp_path) -> ldb.table.LanceTable:
db = ldb.connect(tmp_path)
vectors = [np.random.randn(128) for _ in range(100)]
nouns = ("puppy", "car", "rabbit", "girl", "monkey")
verbs = ("runs", "hits", "jumps", "drives", "barfs")
adv = ("crazily.", "dutifully.", "foolishly.", "merrily.", "occasionally.")
adj = ("adorable", "clueless", "dirty", "odd", "stupid")
text = [
" ".join(
[
nouns[random.randrange(0, 5)],
verbs[random.randrange(0, 5)],
adv[random.randrange(0, 5)],
adj[random.randrange(0, 5)],
]
)
for _ in range(100)
]
table = db.create_table(
"test", data=pd.DataFrame({"vector": vectors, "text": text, "text2": text})
)
return table
def test_create_index(tmp_path):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
assert isinstance(index, tantivy.Index)
assert os.path.exists(str(tmp_path / "index"))
def test_populate_index(tmp_path, table):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
assert ldb.fts.populate_index(index, table, ["text"]) == len(table)
def test_search_index(tmp_path, table):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
ldb.fts.populate_index(index, table, ["text"])
index.reload()
results = ldb.fts.search_index(index, query="puppy", limit=10)
assert len(results) == 2
assert len(results[0]) == 10 # row_ids
assert len(results[1]) == 10 # scores
def test_create_index_from_table(tmp_path, table):
table.create_fts_index("text")
df = table.search("puppy").limit(10).select(["text"]).to_df()
assert len(df) == 10
assert "text" in df.columns
def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"])
df = table.search("puppy").limit(10).to_df()
assert len(df) == 10
assert "text" in df.columns
assert "text2" in df.columns

49
python/tests/test_io.py Normal file
View File

@@ -0,0 +1,49 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import lancedb
# You need to setup AWS credentials an a base path to run this test. Example
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py
@pytest.mark.skipif(
(os.environ.get("TEST_S3_BASE_URL") is None),
reason="please setup s3 base url",
)
def test_s3_io():
db = lancedb.connect(os.environ.get("TEST_S3_BASE_URL"))
assert db.table_names() == []
table = db.create_table(
"test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
],
)
rs = table.search([100, 100]).limit(1).to_df()
assert len(rs) == 1
assert rs["item"].iloc[0] == "bar"
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
assert len(rs) == 1
assert rs["item"].iloc[0] == "foo"
assert db.table_names() == ["test"]
assert "test" in db
assert len(db) == 1
assert db.open_table("test").name == db["test"].name

View File

@@ -12,14 +12,12 @@
# limitations under the License.
import lance
from lancedb.query import LanceQueryBuilder
import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest
from lancedb.query import LanceQueryBuilder
class MockTable:

View File

@@ -16,7 +16,6 @@ from pathlib import Path
import pandas as pd
import pyarrow as pa
import pytest
from lancedb.table import LanceTable

30
python/tests/test_util.py Normal file
View File

@@ -0,0 +1,30 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lancedb.util import get_uri_scheme
def test_normalize_uri():
uris = [
"relative/path",
"/absolute/path",
"file:///absolute/path",
"s3://bucket/path",
"gs://bucket/path",
"c:\\windows\\path",
]
schemes = ["file", "file", "file", "s3", "gs", "file"]
for uri, expected_scheme in zip(uris, schemes):
parsed_scheme = get_uri_scheme(uri)
assert parsed_scheme == expected_scheme

21
rust/ffi/node/Cargo.toml Normal file
View File

@@ -0,0 +1,21 @@
[package]
name = "vectordb-node"
version = "0.1.0"
description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0"
edition = "2018"
exclude = ["index.node"]
[lib]
crate-type = ["cdylib"]
[dependencies]
arrow-array = "37.0"
arrow-ipc = "37.0"
arrow-schema = "37.0"
once_cell = "1"
futures = "0.3"
lance = "0.4.17"
vectordb = { path = "../../vectordb" }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }

3
rust/ffi/node/README.md Normal file
View File

@@ -0,0 +1,3 @@
The LanceDB node bridge (vectordb-node) allows javascript applications to access LanceDB datasets.
It is build using [Neon](https://neon-bindings.com). See the node project for an example of how it is used / tests

View File

@@ -0,0 +1,60 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::io::Cursor;
use std::ops::Deref;
use std::sync::Arc;
use arrow_array::cast::as_list_array;
use arrow_array::{Array, FixedSizeListArray, RecordBatch};
use arrow_ipc::reader::FileReader;
use arrow_schema::{DataType, Field, Schema};
use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
let column = record_batch
.column_by_name("vector")
.expect("vector column is missing");
let arr = as_list_array(column.deref());
let list_size = arr.values().len() / record_batch.num_rows();
let r = FixedSizeListArray::try_new(arr.values(), list_size as i32).unwrap();
let schema = Arc::new(Schema::new(vec![Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
list_size as i32,
),
true,
)]));
let mut new_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(r)]).unwrap();
if record_batch.num_columns() > 1 {
let rb = record_batch.drop_column("vector").unwrap();
new_batch = new_batch.merge(&rb).unwrap();
}
new_batch
}
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Vec<RecordBatch> {
let mut batches: Vec<RecordBatch> = Vec::new();
let fr = FileReader::try_new(Cursor::new(slice), None);
let file_reader = fr.unwrap();
for b in file_reader {
let record_batch = convert_record_batch(b.unwrap());
batches.push(record_batch);
}
batches
}

View File

@@ -0,0 +1,36 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use neon::prelude::*;
pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
vec: &Vec<String>,
cx: &mut C,
) -> JsResult<'a, JsArray> {
let a = JsArray::new(cx, vec.len() as u32);
for (i, s) in vec.iter().enumerate() {
let v = cx.string(s);
a.set(cx, i as u32, v)?;
}
Ok(a)
}
pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
let mut query_vec: Vec<f32> = Vec::new();
for i in 0..array.len(cx) {
let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
query_vec.push(entry.value(cx) as f32);
}
query_vec
}

View File

@@ -0,0 +1,15 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod vector;

View File

@@ -0,0 +1,128 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::TryFrom;
use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::MetricType;
use neon::context::FunctionContext;
use neon::prelude::*;
use vectordb::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};
use crate::{runtime, JsTable};
pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let index_params = cx.argument::<JsObject>(0)?;
let index_params_builder = get_index_params_builder(&mut cx, index_params).unwrap();
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let table = js_table.table.clone();
rt.block_on(async move {
let add_result = table
.lock()
.unwrap()
.create_index(&index_params_builder)
.await;
deferred.settle_with(&channel, move |mut cx| {
add_result
.map(|_| cx.undefined())
.or_else(|err| cx.throw_error(err.to_string()))
});
});
Ok(promise)
}
fn get_index_params_builder(
cx: &mut FunctionContext,
obj: Handle<JsObject>,
) -> Result<impl VectorIndexBuilder, String> {
let idx_type = obj
.get::<JsString, _, _>(cx, "type")
.map_err(|t| t.to_string())?
.value(cx);
match idx_type.as_str() {
"ivf_pq" => {
let mut index_builder: IvfPQIndexBuilder = IvfPQIndexBuilder::new();
let mut pq_params = PQBuildParams::default();
obj.get_opt::<JsString, _, _>(cx, "column")
.map_err(|t| t.to_string())?
.map(|s| index_builder.column(s.value(cx)));
obj.get_opt::<JsString, _, _>(cx, "index_name")
.map_err(|t| t.to_string())?
.map(|s| index_builder.index_name(s.value(cx)));
obj.get_opt::<JsString, _, _>(cx, "metric_type")
.map_err(|t| t.to_string())?
.map(|s| MetricType::try_from(s.value(cx).as_str()))
.map(|mt| {
let metric_type = mt.unwrap();
index_builder.metric_type(metric_type);
pq_params.metric_type = metric_type;
});
let num_partitions = obj
.get_opt::<JsNumber, _, _>(cx, "num_partitions")
.map_err(|t| t.to_string())?
.map(|s| s.value(cx) as usize);
let max_iters = obj
.get_opt::<JsNumber, _, _>(cx, "max_iters")
.map_err(|t| t.to_string())?
.map(|s| s.value(cx) as usize);
num_partitions.map(|np| {
let max_iters = max_iters.unwrap_or(50);
let ivf_params = IvfBuildParams {
num_partitions: np,
max_iters,
};
index_builder.ivf_params(ivf_params)
});
obj.get_opt::<JsBoolean, _, _>(cx, "use_opq")
.map_err(|t| t.to_string())?
.map(|s| pq_params.use_opq = s.value(cx));
obj.get_opt::<JsNumber, _, _>(cx, "num_sub_vectors")
.map_err(|t| t.to_string())?
.map(|s| pq_params.num_sub_vectors = s.value(cx) as usize);
obj.get_opt::<JsNumber, _, _>(cx, "num_bits")
.map_err(|t| t.to_string())?
.map(|s| pq_params.num_bits = s.value(cx) as usize);
obj.get_opt::<JsNumber, _, _>(cx, "max_iters")
.map_err(|t| t.to_string())?
.map(|s| pq_params.max_iters = s.value(cx) as usize);
obj.get_opt::<JsNumber, _, _>(cx, "max_opq_iters")
.map_err(|t| t.to_string())?
.map(|s| pq_params.max_opq_iters = s.value(cx) as usize);
Ok(index_builder)
}
t => Err(format!("{} is not a valid index type", t).to_string()),
}
}

268
rust/ffi/node/src/lib.rs Normal file
View File

@@ -0,0 +1,268 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::convert::TryFrom;
use std::ops::Deref;
use std::sync::{Arc, Mutex};
use arrow_array::{Float32Array, RecordBatchReader};
use arrow_ipc::writer::FileWriter;
use futures::{TryFutureExt, TryStreamExt};
use lance::arrow::RecordBatchBuffer;
use lance::dataset::WriteMode;
use lance::index::vector::MetricType;
use neon::prelude::*;
use neon::types::buffer::TypedArray;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;
use vectordb::database::Database;
use vectordb::error::Error;
use vectordb::table::Table;
use crate::arrow::arrow_buffer_to_record_batch;
mod arrow;
mod convert;
mod index;
struct JsDatabase {
database: Arc<Database>,
}
impl Finalize for JsDatabase {}
struct JsTable {
table: Arc<Mutex<Table>>,
}
impl Finalize for JsTable {}
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string())))
}
fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
let path = cx.argument::<JsString>(0)?.value(&mut cx);
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
rt.spawn(async move {
let database = Database::connect(&path).await;
deferred.settle_with(&channel, move |mut cx| {
let db = JsDatabase {
database: Arc::new(database.or_else(|err| cx.throw_error(err.to_string()))?),
};
Ok(cx.boxed(db))
});
});
Ok(promise)
}
fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
let db = cx
.this()
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let database = db.database.clone();
rt.spawn(async move {
let tables_rst = database.table_names().await;
deferred.settle_with(&channel, move |mut cx| {
let tables = tables_rst.or_else(|err| cx.throw_error(err.to_string()))?;
let table_names = convert::vec_str_to_array(&tables, &mut cx);
table_names
});
});
Ok(promise)
}
fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let db = cx
.this()
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let database = db.database.clone();
let (deferred, promise) = cx.promise();
rt.spawn(async move {
let table_rst = database.open_table(&table_name).await;
deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new(
table_rst.or_else(|err| cx.throw_error(err.to_string()))?,
));
Ok(cx.boxed(JsTable { table }))
});
});
Ok(promise)
}
fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let query_obj = cx.argument::<JsObject>(0)?;
let limit = query_obj
.get::<JsNumber, _, _>(&mut cx, "_limit")?
.value(&mut cx);
let filter = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
.map(|s| s.value(&mut cx));
let refine_factor = query_obj
.get_opt::<JsNumber, _, _>(&mut cx, "_refineFactor")?
.map(|s| s.value(&mut cx))
.map(|i| i as u32);
let nprobes = query_obj
.get::<JsNumber, _, _>(&mut cx, "_nprobes")?
.value(&mut cx) as usize;
let metric_type = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap());
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let table = js_table.table.clone();
let query_vector = query_obj.get::<JsArray, _, _>(&mut cx, "_queryVector")?;
let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);
rt.spawn(async move {
let builder = table
.lock()
.unwrap()
.search(Float32Array::from(query))
.limit(limit as usize)
.refine_factor(refine_factor)
.nprobes(nprobes)
.filter(filter)
.metric_type(metric_type);
let record_batch_stream = builder.execute();
let results = record_batch_stream
.and_then(|stream| stream.try_collect::<Vec<_>>().map_err(Error::from))
.await;
deferred.settle_with(&channel, move |mut cx| {
let results = results.or_else(|err| cx.throw_error(err.to_string()))?;
let vector: Vec<u8> = Vec::new();
if results.is_empty() {
return cx.buffer(0);
}
let schema = results.get(0).unwrap().schema();
let mut fr = FileWriter::try_new(vector, schema.deref())
.or_else(|err| cx.throw_error(err.to_string()))?;
for batch in results.iter() {
fr.write(batch)
.or_else(|err| cx.throw_error(err.to_string()))?;
}
fr.finish().or_else(|err| cx.throw_error(err.to_string()))?;
let buf = fr
.into_inner()
.or_else(|err| cx.throw_error(err.to_string()))?;
Ok(JsBuffer::external(&mut cx, buf))
});
});
Ok(promise)
}
fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
let db = cx
.this()
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let buffer = cx.argument::<JsBuffer>(1)?;
let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let database = db.database.clone();
rt.block_on(async move {
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(batches));
let table_rst = database.create_table(&table_name, batch_reader).await;
deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new(
table_rst.or_else(|err| cx.throw_error(err.to_string()))?,
));
Ok(cx.boxed(JsTable { table }))
});
});
Ok(promise)
}
fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
let write_mode_map: HashMap<&str, WriteMode> = HashMap::from([
("create", WriteMode::Create),
("append", WriteMode::Append),
("overwrite", WriteMode::Overwrite),
]);
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let buffer = cx.argument::<JsBuffer>(0)?;
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let table = js_table.table.clone();
let write_mode = write_mode_map.get(write_mode.as_str()).cloned();
rt.block_on(async move {
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(batches));
let add_result = table.lock().unwrap().add(batch_reader, write_mode).await;
deferred.settle_with(&channel, move |mut cx| {
let added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
Ok(cx.number(added as f64))
});
});
Ok(promise)
}
#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("databaseNew", database_new)?;
cx.export_function("databaseTableNames", database_table_names)?;
cx.export_function("databaseOpenTable", database_open_table)?;
cx.export_function("tableSearch", table_search)?;
cx.export_function("tableCreate", table_create)?;
cx.export_function("tableAdd", table_add)?;
cx.export_function(
"tableCreateVectorIndex",
index::vector::table_create_vector_index,
)?;
Ok(())
}

22
rust/vectordb/Cargo.toml Normal file
View File

@@ -0,0 +1,22 @@
[package]
name = "vectordb"
version = "0.0.1"
edition = "2021"
description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
arrow-array = "37.0"
arrow-data = "37.0"
arrow-schema = "37.0"
object_store = "0.5.6"
snafu = "0.7.4"
lance = "0.4.17"
tokio = { version = "1.23", features = ["rt-multi-thread"] }
[dev-dependencies]
tempfile = "3.5.0"
rand = { version = "0.8.3", features = ["small_rng"] }

View File

@@ -0,0 +1,149 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fs::create_dir_all;
use std::path::Path;
use arrow_array::RecordBatchReader;
use lance::io::object_store::ObjectStore;
use snafu::prelude::*;
use crate::error::{CreateDirSnafu, Result};
use crate::table::Table;
pub struct Database {
object_store: ObjectStore,
pub(crate) uri: String,
}
const LANCE_EXTENSION: &str = "lance";
/// A connection to LanceDB
impl Database {
/// Connects to LanceDB
///
/// # Arguments
///
/// * `path` - URI where the database is located, can be a local file or a supported remote cloud storage
///
/// # Returns
///
/// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Database> {
let object_store = ObjectStore::new(uri).await?;
if object_store.is_local() {
Self::try_create_dir(uri).context(CreateDirSnafu { path: uri })?;
}
Ok(Database {
uri: uri.to_string(),
object_store,
})
}
/// Try to create a local directory to store the lancedb dataset
fn try_create_dir(path: &str) -> core::result::Result<(), std::io::Error> {
let path = Path::new(path);
if !path.try_exists()? {
create_dir_all(&path)?;
}
Ok(())
}
/// Get the names of all tables in the database.
///
/// # Returns
///
/// * A [Vec<String>] with all table names.
pub async fn table_names(&self) -> Result<Vec<String>> {
let f = self
.object_store
.read_dir("/")
.await?
.iter()
.map(|fname| Path::new(fname))
.filter(|path| {
let is_lance = path
.extension()
.map(|e| e.to_str().map(|e| e == LANCE_EXTENSION))
.flatten();
is_lance.unwrap_or(false)
})
.map(|p| {
p.file_stem()
.map(|s| s.to_str().map(|s| String::from(s)))
.flatten()
})
.flatten()
.collect();
Ok(f)
}
pub async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader>,
) -> Result<Table> {
Table::create(&self.uri, name, batches).await
}
/// Open a table in the database.
///
/// # Arguments
/// * `name` - The name of the table.
///
/// # Returns
///
/// * A [Table] object.
pub async fn open_table(&self, name: &str) -> Result<Table> {
Table::open(&self.uri, name).await
}
}
#[cfg(test)]
mod tests {
use std::fs::create_dir_all;
use tempfile::tempdir;
use crate::database::Database;
#[tokio::test]
async fn test_connect() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = Database::connect(uri).await.unwrap();
assert_eq!(db.uri, uri);
}
#[tokio::test]
async fn test_table_names() {
let tmp_dir = tempdir().unwrap();
create_dir_all(tmp_dir.path().join("table1.lance")).unwrap();
create_dir_all(tmp_dir.path().join("table2.lance")).unwrap();
create_dir_all(tmp_dir.path().join("invalidlance")).unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = Database::connect(uri).await.unwrap();
let tables = db.table_names().await.unwrap();
assert_eq!(tables.len(), 2);
assert!(tables.contains(&String::from("table1")));
assert!(tables.contains(&String::from("table2")));
}
#[tokio::test]
async fn test_connect_s3() {
// let db = Database::connect("s3://bucket/path/to/database").await.unwrap();
}
}

View File

@@ -0,0 +1,61 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use snafu::Snafu;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum Error {
#[snafu(display("LanceDBError: Invalid table name: {name}"))]
InvalidTableName { name: String },
#[snafu(display("LanceDBError: Table '{name}' was not found"))]
TableNotFound { name: String },
#[snafu(display("LanceDBError: Table '{name}' already exists"))]
TableAlreadyExists { name: String },
#[snafu(display("LanceDBError: Unable to created lance dataset at {path}: {source}"))]
CreateDir {
path: String,
source: std::io::Error,
},
#[snafu(display("LanceDBError: {message}"))]
Store { message: String },
#[snafu(display("LanceDBError: {message}"))]
Lance { message: String },
}
pub type Result<T> = std::result::Result<T, Error>;
impl From<lance::Error> for Error {
fn from(e: lance::Error) -> Self {
Self::Lance {
message: e.to_string(),
}
}
}
impl From<object_store::Error> for Error {
fn from(e: object_store::Error) -> Self {
Self::Store {
message: e.to_string(),
}
}
}
impl From<object_store::path::Error> for Error {
fn from(e: object_store::path::Error) -> Self {
Self::Store {
message: e.to_string(),
}
}
}

View File

@@ -0,0 +1,15 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod vector;

View File

@@ -0,0 +1,163 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::{MetricType, VectorIndexParams};
pub trait VectorIndexBuilder {
fn get_column(&self) -> Option<String>;
fn get_index_name(&self) -> Option<String>;
fn build(&self) -> VectorIndexParams;
}
pub struct IvfPQIndexBuilder {
column: Option<String>,
index_name: Option<String>,
metric_type: Option<MetricType>,
ivf_params: Option<IvfBuildParams>,
pq_params: Option<PQBuildParams>,
}
impl IvfPQIndexBuilder {
pub fn new() -> IvfPQIndexBuilder {
IvfPQIndexBuilder {
column: None,
index_name: None,
metric_type: None,
ivf_params: None,
pq_params: None,
}
}
}
impl IvfPQIndexBuilder {
pub fn column(&mut self, column: String) -> &mut IvfPQIndexBuilder {
self.column = Some(column);
self
}
pub fn index_name(&mut self, index_name: String) -> &mut IvfPQIndexBuilder {
self.index_name = Some(index_name);
self
}
pub fn metric_type(&mut self, metric_type: MetricType) -> &mut IvfPQIndexBuilder {
self.metric_type = Some(metric_type);
self
}
pub fn ivf_params(&mut self, ivf_params: IvfBuildParams) -> &mut IvfPQIndexBuilder {
self.ivf_params = Some(ivf_params);
self
}
pub fn pq_params(&mut self, pq_params: PQBuildParams) -> &mut IvfPQIndexBuilder {
self.pq_params = Some(pq_params);
self
}
}
impl VectorIndexBuilder for IvfPQIndexBuilder {
fn get_column(&self) -> Option<String> {
self.column.clone()
}
fn get_index_name(&self) -> Option<String> {
self.index_name.clone()
}
fn build(&self) -> VectorIndexParams {
let ivf_params = self.ivf_params.clone().unwrap_or(IvfBuildParams::default());
let pq_params = self.pq_params.clone().unwrap_or(PQBuildParams::default());
VectorIndexParams::with_ivf_pq_params(pq_params.metric_type, ivf_params, pq_params)
}
}
#[cfg(test)]
mod tests {
use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams;
use lance::index::vector::{MetricType, StageParams};
use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};
#[test]
fn test_builder_no_params() {
let index_builder = IvfPQIndexBuilder::new();
assert!(index_builder.get_column().is_none());
assert!(index_builder.get_index_name().is_none());
let index_params = index_builder.build();
assert_eq!(index_params.stages.len(), 2);
if let StageParams::Ivf(ivf_params) = index_params.stages.get(0).unwrap() {
let default = IvfBuildParams::default();
assert_eq!(ivf_params.num_partitions, default.num_partitions);
assert_eq!(ivf_params.max_iters, default.max_iters);
} else {
panic!("Expected first stage to be ivf")
}
if let StageParams::PQ(pq_params) = index_params.stages.get(1).unwrap() {
assert_eq!(pq_params.use_opq, false);
} else {
panic!("Expected second stage to be pq")
}
}
#[test]
fn test_builder_all_params() {
let mut index_builder = IvfPQIndexBuilder::new();
index_builder
.column("c".to_owned())
.metric_type(MetricType::Cosine)
.index_name("index".to_owned());
assert_eq!(index_builder.column.clone().unwrap(), "c");
assert_eq!(index_builder.metric_type.unwrap(), MetricType::Cosine);
assert_eq!(index_builder.index_name.clone().unwrap(), "index");
let ivf_params = IvfBuildParams::new(500);
let mut pq_params = PQBuildParams::default();
pq_params.use_opq = true;
pq_params.max_iters = 1;
pq_params.num_bits = 8;
pq_params.num_sub_vectors = 50;
pq_params.metric_type = MetricType::Cosine;
pq_params.max_opq_iters = 2;
index_builder.ivf_params(ivf_params);
index_builder.pq_params(pq_params);
let index_params = index_builder.build();
assert_eq!(index_params.stages.len(), 2);
if let StageParams::Ivf(ivf_params) = index_params.stages.get(0).unwrap() {
assert_eq!(ivf_params.num_partitions, 500);
} else {
assert!(false, "Expected first stage to be ivf")
}
if let StageParams::PQ(pq_params) = index_params.stages.get(1).unwrap() {
assert_eq!(pq_params.use_opq, true);
assert_eq!(pq_params.max_iters, 1);
assert_eq!(pq_params.num_bits, 8);
assert_eq!(pq_params.num_sub_vectors, 50);
assert_eq!(pq_params.metric_type, MetricType::Cosine);
assert_eq!(pq_params.max_opq_iters, 2);
} else {
assert!(false, "Expected second stage to be pq")
}
}
}

19
rust/vectordb/src/lib.rs Normal file
View File

@@ -0,0 +1,19 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod database;
pub mod error;
pub mod index;
pub mod query;
pub mod table;

218
rust/vectordb/src/query.rs Normal file
View File

@@ -0,0 +1,218 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow_array::Float32Array;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::Dataset;
use lance::index::vector::MetricType;
use crate::error::Result;
/// A builder for nearest neighbor queries for LanceDB.
pub struct Query {
pub dataset: Arc<Dataset>,
pub query_vector: Float32Array,
pub limit: usize,
pub filter: Option<String>,
pub nprobes: usize,
pub refine_factor: Option<u32>,
pub metric_type: Option<MetricType>,
pub use_index: bool,
}
impl Query {
/// Creates a new Query object
///
/// # Arguments
///
/// * `dataset` - The table / dataset the query will be run against.
/// * `vector` The vector used for this query.
///
/// # Returns
///
/// * A [Query] object.
pub(crate) fn new(dataset: Arc<Dataset>, vector: Float32Array) -> Self {
Query {
dataset,
query_vector: vector,
limit: 10,
nprobes: 20,
refine_factor: None,
metric_type: None,
use_index: false,
filter: None,
}
}
/// Execute the queries and return its results.
///
/// # Returns
///
/// * A [DatasetRecordBatchStream] with the query's results.
pub async fn execute(&self) -> Result<DatasetRecordBatchStream> {
let mut scanner: Scanner = self.dataset.scan();
scanner.nearest(
crate::table::VECTOR_COLUMN_NAME,
&self.query_vector,
self.limit,
)?;
scanner.nprobs(self.nprobes);
scanner.use_index(self.use_index);
self.filter.as_ref().map(|f| scanner.filter(f));
self.refine_factor.map(|rf| scanner.refine(rf));
self.metric_type.map(|mt| scanner.distance_metric(mt));
Ok(scanner.try_into_stream().await?)
}
/// Set the maximum number of results to return.
///
/// # Arguments
///
/// * `limit` - The maximum number of results to return.
pub fn limit(mut self, limit: usize) -> Query {
self.limit = limit;
self
}
/// Set the vector used for this query.
///
/// # Arguments
///
/// * `vector` - The vector that will be used for search.
pub fn query_vector(mut self, query_vector: Float32Array) -> Query {
self.query_vector = query_vector;
self
}
/// Set the number of probes to use.
///
/// # Arguments
///
/// * `nprobes` - The number of probes to use.
pub fn nprobes(mut self, nprobes: usize) -> Query {
self.nprobes = nprobes;
self
}
/// Set the refine factor to use.
///
/// # Arguments
///
/// * `refine_factor` - The refine factor to use.
pub fn refine_factor(mut self, refine_factor: Option<u32>) -> Query {
self.refine_factor = refine_factor;
self
}
/// Set the distance metric to use.
///
/// # Arguments
///
/// * `metric_type` - The distance metric to use. By default [MetricType::L2] is used.
pub fn metric_type(mut self, metric_type: Option<MetricType>) -> Query {
self.metric_type = metric_type;
self
}
/// Whether to use an ANN index if available
///
/// # Arguments
///
/// * `use_index` - Sets Whether to use an ANN index if available
pub fn use_index(mut self, use_index: bool) -> Query {
self.use_index = use_index;
self
}
pub fn filter(mut self, filter: Option<String>) -> Query {
self.filter = filter;
self
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{Float32Array, RecordBatch, RecordBatchReader};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use lance::arrow::RecordBatchBuffer;
use lance::dataset::Dataset;
use lance::index::vector::MetricType;
use crate::query::Query;
#[tokio::test]
async fn test_setters_getters() {
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let ds = Dataset::write(&mut batches, ":memory:", None)
.await
.unwrap();
let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = Query::new(Arc::new(ds), vector.clone());
assert_eq!(query.query_vector, vector);
let new_vector = Float32Array::from_iter_values([9.8, 8.7]);
let query = query
.query_vector(new_vector.clone())
.limit(100)
.nprobes(1000)
.use_index(true)
.metric_type(Some(MetricType::Cosine))
.refine_factor(Some(999));
assert_eq!(query.query_vector, new_vector);
assert_eq!(query.limit, 100);
assert_eq!(query.nprobes, 1000);
assert_eq!(query.use_index, true);
assert_eq!(query.metric_type, Some(MetricType::Cosine));
assert_eq!(query.refine_factor, Some(999));
}
#[tokio::test]
async fn test_execute() {
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let ds = Dataset::write(&mut batches, ":memory:", None)
.await
.unwrap();
let vector = Float32Array::from_iter_values([0.1; 128]);
let query = Query::new(Arc::new(ds), vector.clone());
let result = query.execute().await;
assert_eq!(result.is_ok(), true);
}
fn make_test_batches() -> RecordBatchBuffer {
let dim: usize = 128;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("key", DataType::Int32, false),
ArrowField::new(
"vector",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float32, true)),
dim as i32,
),
true,
),
ArrowField::new("uri", DataType::Utf8, true),
]));
RecordBatchBuffer::new(vec![RecordBatch::new_empty(schema.clone())])
}
}

388
rust/vectordb/src/table.rs Normal file
View File

@@ -0,0 +1,388 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use std::sync::Arc;
use arrow_array::{Float32Array, RecordBatchReader};
use lance::dataset::{Dataset, WriteMode, WriteParams};
use lance::index::IndexType;
use snafu::prelude::*;
use crate::error::{Error, InvalidTableNameSnafu, Result};
use crate::index::vector::VectorIndexBuilder;
use crate::query::Query;
pub const VECTOR_COLUMN_NAME: &str = "vector";
pub const LANCE_FILE_EXTENSION: &str = "lance";
/// A table in a LanceDB database.
#[derive(Debug)]
pub struct Table {
name: String,
uri: String,
dataset: Arc<Dataset>,
}
impl std::fmt::Display for Table {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Table({})", self.name)
}
}
impl Table {
/// Opens an existing Table
///
/// # Arguments
///
/// * `base_path` - The base path where the table is located
/// * `name` The Table name
///
/// # Returns
///
/// * A [Table] object.
pub async fn open(base_uri: &str, name: &str) -> Result<Self> {
let path = Path::new(base_uri);
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let uri = table_uri
.as_path()
.to_str()
.context(InvalidTableNameSnafu { name })?;
let dataset = Dataset::open(&uri).await.map_err(|e| match e {
lance::Error::DatasetNotFound { .. } => Error::TableNotFound {
name: name.to_string(),
},
e => Error::Lance {
message: e.to_string(),
},
})?;
Ok(Table {
name: name.to_string(),
uri: uri.to_string(),
dataset: Arc::new(dataset),
})
}
/// Creates a new Table
///
/// # Arguments
///
/// * `base_path` - The base path where the table is located
/// * `name` The Table name
/// * `batches` RecordBatch to be saved in the database
///
/// # Returns
///
/// * A [Table] object.
pub async fn create(
base_uri: &str,
name: &str,
mut batches: Box<dyn RecordBatchReader>,
) -> Result<Self> {
let base_path = Path::new(base_uri);
let table_uri = base_path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let uri = table_uri
.as_path()
.to_str()
.context(InvalidTableNameSnafu { name })?
.to_string();
let dataset = Dataset::write(&mut batches, &uri, Some(WriteParams::default()))
.await
.map_err(|e| match e {
lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
name: name.to_string(),
},
e => Error::Lance {
message: e.to_string(),
},
})?;
Ok(Table {
name: name.to_string(),
uri,
dataset: Arc::new(dataset),
})
}
/// Create index on the table.
pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> {
use lance::index::DatasetIndexExt;
let dataset = self
.dataset
.create_index(
&[index_builder
.get_column()
.unwrap_or(VECTOR_COLUMN_NAME.to_string())
.as_str()],
IndexType::Vector,
index_builder.get_index_name(),
&index_builder.build(),
)
.await?;
self.dataset = Arc::new(dataset);
Ok(())
}
/// Insert records into this Table
///
/// # Arguments
///
/// * `batches` RecordBatch to be saved in the Table
/// * `write_mode` Append / Overwrite existing records. Default: Append
/// # Returns
///
/// * The number of rows added
pub async fn add(
&mut self,
mut batches: Box<dyn RecordBatchReader>,
write_mode: Option<WriteMode>,
) -> Result<usize> {
let mut params = WriteParams::default();
params.mode = write_mode.unwrap_or(WriteMode::Append);
self.dataset = Arc::new(Dataset::write(&mut batches, &self.uri, Some(params)).await?);
Ok(batches.count())
}
/// Creates a new Query object that can be executed.
///
/// # Arguments
///
/// * `vector` The vector used for this query.
///
/// # Returns
///
/// * A [Query] object.
pub fn search(&self, query_vector: Float32Array) -> Query {
Query::new(self.dataset.clone(), query_vector)
}
/// Returns the number of rows in this Table
pub async fn count_rows(&self) -> Result<usize> {
Ok(self.dataset.count_rows().await?)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{
Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchReader,
};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field, Schema};
use lance::arrow::RecordBatchBuffer;
use lance::dataset::{Dataset, WriteMode};
use lance::index::vector::ivf::IvfBuildParams;
use lance::index::vector::pq::PQBuildParams;
use rand::Rng;
use tempfile::tempdir;
use super::*;
use crate::index::vector::IvfPQIndexBuilder;
#[tokio::test]
async fn test_open() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
.await
.unwrap();
let table = Table::open(uri, "test").await.unwrap();
assert_eq!(table.name, "test")
}
#[tokio::test]
async fn test_open_not_found() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let table = Table::open(uri, "test").await;
assert!(matches!(table.unwrap_err(), Error::TableNotFound { .. }));
}
#[test]
fn test_object_store_path() {
use std::path::Path as StdPath;
let p = StdPath::new("s3://bucket/path/to/file");
let c = p.join("subfile");
assert_eq!(c.to_str().unwrap(), "s3://bucket/path/to/file/subfile");
}
#[tokio::test]
async fn test_create_already_exists() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone();
Table::create(&uri, "test", batches).await.unwrap();
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let result = Table::create(&uri, "test", batches).await;
assert!(matches!(
result.unwrap_err(),
Error::TableAlreadyExists { .. }
));
}
#[tokio::test]
async fn test_add() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone();
let mut table = Table::create(&uri, "test", batches).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> =
Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema,
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.unwrap()]));
table.add(new_batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 20);
assert_eq!(table.name, "test");
}
#[tokio::test]
async fn test_add_overwrite() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone();
let mut table = Table::create(uri, "test", batches).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> =
Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema,
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.unwrap()]));
table
.add(new_batches, Some(WriteMode::Overwrite))
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.name, "test");
}
#[tokio::test]
async fn test_search() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
.await
.unwrap();
let table = Table::open(uri, "test").await.unwrap();
let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = table.search(vector.clone());
assert_eq!(vector, query.query_vector);
}
fn make_test_batches() -> RecordBatchBuffer {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)
.unwrap()])
}
#[tokio::test]
async fn test_create_index() {
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use rand;
use std::iter::repeat_with;
use arrow_array::Float32Array;
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let dimension = 16;
let schema = Arc::new(ArrowSchema::new(vec![Field::new(
"embeddings",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dimension,
),
false,
)]));
let mut rng = rand::thread_rng();
let float_arr = Float32Array::from(
repeat_with(|| rng.gen::<f32>())
.take(512 * dimension as usize)
.collect::<Vec<f32>>(),
);
let vectors = Arc::new(create_fixed_size_list(float_arr, dimension).unwrap());
let batches = RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema.clone(),
vec![vectors.clone()],
)
.unwrap()]);
let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches);
let mut table = Table::create(uri, "test", reader).await.unwrap();
let mut i = IvfPQIndexBuilder::new();
let index_builder = i
.column("embeddings".to_string())
.index_name("my_index".to_string())
.ivf_params(IvfBuildParams::new(256))
.pq_params(PQBuildParams::default());
table.create_index(index_builder).await.unwrap();
assert_eq!(table.dataset.load_indices().await.unwrap().len(), 1);
assert_eq!(table.count_rows().await.unwrap(), 512);
assert_eq!(table.name, "test");
}
fn create_fixed_size_list<T: Array>(values: T, list_size: i32) -> Result<FixedSizeListArray> {
let list_type = DataType::FixedSizeList(
Arc::new(Field::new("item", values.data_type().clone(), true)),
list_size,
);
let data = ArrayDataBuilder::new(list_type)
.len(values.len() / list_size as usize)
.add_child_data(values.into_data())
.build()
.unwrap();
Ok(FixedSizeListArray::from(data))
}
}