Compare commits

...

44 Commits

Author SHA1 Message Date
Lance Release
ce2242e06d [python] Bump version: 0.5.1 → 0.5.2 2024-02-02 21:33:02 +00:00
Weston Pace
778339388a chore: bump pylance version to latest in pyproject.toml (#918) 2024-02-02 13:32:12 -08:00
Weston Pace
7f8637a0b4 feat: add merge_insert to the node and rust APIs (#915) 2024-02-02 13:16:51 -08:00
QianZhu
09cd08222d make it explicit about the vector column data type (#916)
<img width="837" alt="Screenshot 2024-02-01 at 4 23 34 PM"
src="https://github.com/lancedb/lancedb/assets/1305083/4f0f5c5a-2a24-4b00-aad1-ef80a593d964">
[
<img width="838" alt="Screenshot 2024-02-01 at 4 26 03 PM"
src="https://github.com/lancedb/lancedb/assets/1305083/ca073bc8-b518-4be3-811d-8a7184416f07">
](url)

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
2024-02-02 09:02:02 -08:00
Bert
a248d7feec fix: add request retry to python client (#917)
Adds capability to the remote python SDK to retry requests (fixes #911)

This can be configured through environment:
- `LANCE_CLIENT_MAX_RETRIES`= total number of retries. Set to 0 to
disable retries. default = 3
- `LANCE_CLIENT_CONNECT_RETRIES` = number of times to retry request in
case of TCP connect failure. default = 3
- `LANCE_CLIENT_READ_RETRIES` = number of times to retry request in case
of HTTP request failure. default = 3
- `LANCE_CLIENT_RETRY_STATUSES` = http statuses for which the request
will be retried. passed as comma separated list of ints. default `500,
502, 503`
- `LANCE_CLIENT_RETRY_BACKOFF_FACTOR` = controls time between retry
requests. see
[here](23f2287eb5/src/urllib3/util/retry.py (L141-L146)).
default = 0.25

Only read requests will be retried:
- list table names
- query
- describe table
- list table indices

This does not add retry capabilities for writes as it could possibly
cause issues in the case where the retried write isn't idempotent. For
example, in the case where the LB times-out the request but the server
completes the request anyway, we might not want to blindly retry an
insert request.
2024-02-02 11:27:29 -05:00
Weston Pace
cc9473a94a docs: add cleanup_old_versions and compact_files to Table for documentation purposes (#900)
Closes #819
2024-02-01 15:06:00 -08:00
Weston Pace
d77e95a4f4 feat: upgrade to lance 0.9.11 and expose merge_insert (#906)
This adds the python bindings requested in #870 The javascript/rust
bindings will be added in a future PR.
2024-02-01 11:36:29 -08:00
Lei Xu
62f053ac92 ci: bump to new version of python action to use node 20 gIthub action runtime (#909)
Github action is deprecating old node-16 runtime.
2024-02-01 11:36:03 -08:00
JacobLinCool
34e10caad2 fix the repo link on npm, add links for homepage and bug report (#910)
- fix the repo link on npm
- add links for homepage and bug report
2024-01-31 21:07:11 -08:00
QianZhu
f5726e2d0c arrow table/f16 example (#907) 2024-01-31 14:41:28 -08:00
Lance Release
12b4fb42fc Updating package-lock.json 2024-01-31 21:18:24 +00:00
Lance Release
1328cd46f1 Updating package-lock.json 2024-01-31 20:29:38 +00:00
Lance Release
0c940ed9f8 Bump version: 0.4.6 → 0.4.7 2024-01-31 20:29:28 +00:00
Lei Xu
5f59e51583 fix(node): pass AWS credentials to db level operations (#908)
Passed the following tests

```ts
const keyId = process.env.AWS_ACCESS_KEY_ID;
const secretKey = process.env.AWS_SECRET_ACCESS_KEY;
const sessionToken = process.env.AWS_SESSION_TOKEN;
const region = process.env.AWS_REGION;

const db = await lancedb.connect({
  uri: "s3://bucket/path",
  awsCredentials: {
    accessKeyId: keyId,
    secretKey: secretKey,
    sessionToken: sessionToken,
  },
  awsRegion: region,
} as lancedb.ConnectionOptions);

  console.log(await db.createTable("test", [{ vector: [1, 2, 3] }]));
  console.log(await db.tableNames());
  console.log(await db.dropTable("test"))
```
2024-01-31 12:05:01 -08:00
Will Jones
8d0ea29f89 docs: provide AWS S3 cleanup and permissions advice (#903)
Adding some more quick advice for how to setup AWS S3 with LanceDB.

---------

Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com>
2024-01-31 09:24:54 -08:00
Abraham Lopez
b9468bb980 chore: update JS/TS example in README (#898)
- The JS/TS library actually expects named parameters via an object in
`.createTable()` rather than individual arguments
- Added example on how to search rows by criteria without a vector
search. TS type of `.search()` currently has the `query` parameter as
non-optional so we have to pass undefined for now.
2024-01-30 11:09:45 -08:00
Lei Xu
a42df158a3 ci: change apple silicon runner to free OSS macos-14 target (#901) 2024-01-30 11:05:42 -08:00
Raghav Dixit
9df6905d86 chore(python): GTE embedding function model name update (#902)
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
2024-01-30 23:56:29 +05:30
Ayush Chaurasia
3ffed89793 feat(python): Hybrid search & Reranker API (#824)
based on https://github.com/lancedb/lancedb/pull/713
- The Reranker api can be plugged into vector only or fts only search
but this PR doesn't do that (see example -
https://txt.cohere.com/rerank/)


### Default reranker -- `LinearCombinationReranker(weight=0.7,
fill=1.0)`

```
table.search("hello", query_type="hybrid").rerank(normalize="score").to_pandas()
```
### Available rerankers
LinearCombinationReranker
```
from lancedb.rerankers import LinearCombinationReranker

# Same as default 
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=LinearCombinationReranker()
                                     ).to_pandas()

# with custom params
reranker = LinearCombinationReranker(weight=0.3, fill=1.0)
table.search("hello", query_type="hybrid").rerank(
                                      normalize="score", 
                                      reranker=reranker
                                     ).to_pandas()
```

Cohere Reranker
```
from lancedb.rerankers import CohereReranker

# default model.. English and multi-lingual supported. See docstring for available custom params
table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank",  # score or rank
                                      reranker=CohereReranker()
                                     ).to_pandas()

```

CrossEncoderReranker

```
from lancedb.rerankers import CrossEncoderReranker

table.search("hello", query_type="hybrid").rerank(
                                      normalize="rank", 
                                      reranker=CrossEncoderReranker()
                                     ).to_pandas()

```

## Using custom Reranker
```
from lancedb.reranker import Reranker

class CustomReranker(Reranker):
    def rerank_hybrid(self, vector_result, fts_result):
           combined_res = self.merge_results(vector_results, fts_results) # or use custom combination logic
           # Custom rerank logic here
           
           return combined_res
```

- [x] Expand testing
- [x] Make sure usage makes sense
- [x] Run simple benchmarks for correctness (Seeing weird result from
cohere reranker in the toy example)
- Support diverse rerankers by default:
- [x] Cross encoding
- [x] Cohere
- [x] Reciprocal Rank Fusion

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
Co-authored-by: Prashanth Rao <35005448+prrao87@users.noreply.github.com>
2024-01-30 19:10:33 +05:30
Prashanth Rao
f150768739 Fix image bgcolor (#891)
Minor fix to change the background color for an image in the docs. It's
now readable in both light and dark modes (earlier version made it
impossible to read in dark mode).
2024-01-30 16:50:29 +05:30
Ayush Chaurasia
b432ecf2f6 doc: Add documentation chatbot for LanceDB (#890)
<img width="1258" alt="Screenshot 2024-01-29 at 10 05 52 PM"
src="https://github.com/lancedb/lancedb/assets/15766192/7c108fde-e993-415c-ad01-72010fd5fe31">
2024-01-30 11:24:57 +05:30
Raghav Dixit
d1a7257810 feat(python): Embedding fn support for gte-mlx/gte-large (#873)
have added testing and an example in the docstring, will be pushing a
separate PR in recipe repo for rag example

---------

Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
2024-01-30 11:21:57 +05:30
Ayush Chaurasia
5c5e23bbb9 chore(python): Temporarily extend remote connection timeout (#888)
Context - https://etoai.slack.com/archives/C05NC5YSW5V/p1706371205883149
2024-01-29 17:34:33 +05:30
Lei Xu
e5796a4836 doc: fix js example of create index (#886) 2024-01-28 17:02:36 -08:00
Lei Xu
b9c5323265 doc: use snippet for rust code example and make sure rust examples run through CI (#885) 2024-01-28 14:30:30 -08:00
Lei Xu
e41a52863a fix: fix doc build to include the source snippet correctly (#883) 2024-01-28 11:55:58 -08:00
Chang She
13acc8a480 doc(rust): minor fixes for Rust quick start. (#878) 2024-01-28 11:40:52 -08:00
Lei Xu
22b9eceb12 chore: convert all js doc test to use snippet. (#881) 2024-01-28 11:39:25 -08:00
Lei Xu
5f62302614 doc: use code snippet for typescript examples (#880)
The typescript code is in a fully function file, that will be run via the CI.
2024-01-27 22:52:37 -08:00
Ayush Chaurasia
d84e0d1db8 feat(python): Aws Bedrock embeddings integration (#822)
Supports amazon titan, cohere english & cohere multi-lingual base
models.
2024-01-28 02:04:15 +05:30
Lei Xu
ac94b2a420 chore: upgrade lance, pylance and datafusion (#879) 2024-01-27 12:31:38 -08:00
Lei Xu
b49bc113c4 chore: add one rust SDK e2e example (#876)
Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
2024-01-26 22:41:20 -08:00
Lei Xu
77b5b1cf0e doc: update quick start for full rust example (#872) 2024-01-26 16:19:43 -08:00
Lei Xu
e910809de0 chore: bump github actions to v4 due to GHA warnings of node version deprecation (#874) 2024-01-26 15:52:47 -08:00
Lance Release
90b5b55126 Updating package-lock.json 2024-01-26 23:35:58 +00:00
Lance Release
488e4f8452 Updating package-lock.json 2024-01-26 22:40:46 +00:00
Lance Release
ba6f949515 Bump version: 0.4.5 → 0.4.6 2024-01-26 22:40:36 +00:00
Lei Xu
3dd8522bc9 feat(rust): provide connect and connect_with_options in Rust SDK (#871)
* Bring the feature parity of Rust connect methods.
* A global connect method that can connect to local and remote / cloud
table, as the same as in js/python today.
2024-01-26 11:40:11 -08:00
Lei Xu
e01ef63488 chore(rust): simplified version of optimize (#869)
Consolidate various optimize() into one method, similar to postgres
VACCUM in the process of preparing Rust API for public use
2024-01-26 11:36:04 -08:00
Lei Xu
a6cf24b359 feat(napi): Issue queries as node SDK (#868)
* Query as a fluent API and `AsyncIterator<RecordBatch>`
* Much more docs
* Add tests for auto infer vector search columns with different
dimensions.
2024-01-25 22:14:14 -08:00
Lance Release
9a07c9aad8 Updating package-lock.json 2024-01-25 21:49:36 +00:00
Lance Release
d405798952 Updating package-lock.json 2024-01-25 20:54:55 +00:00
Lance Release
e8a8b92b2a Bump version: 0.4.4 → 0.4.5 2024-01-25 20:54:44 +00:00
Lei Xu
66362c6506 fix: release build for node sdk (#861) 2024-01-25 12:51:32 -08:00
94 changed files with 4665 additions and 703 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.4.4 current_version = 0.4.7
commit = True commit = True
message = Bump version: {current_version} → {new_version} message = Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -16,7 +16,7 @@ jobs:
# Only runs on tags that matches the make-release action # Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v') if: startsWith(github.ref, 'refs/tags/v')
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
with: with:
workspaces: rust workspaces: rust

View File

@@ -27,9 +27,9 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: "3.10" python-version: "3.10"
cache: "pip" cache: "pip"
@@ -42,7 +42,7 @@ jobs:
- name: Set up node - name: Set up node
uses: actions/setup-node@v3 uses: actions/setup-node@v3
with: with:
node-version: ${{ matrix.node-version }} node-version: 20
cache: 'npm' cache: 'npm'
cache-dependency-path: node/package-lock.json cache-dependency-path: node/package-lock.json
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
@@ -62,8 +62,9 @@ jobs:
run: | run: |
npx typedoc --plugin typedoc-plugin-markdown --out ../docs/src/javascript src/index.ts npx typedoc --plugin typedoc-plugin-markdown --out ../docs/src/javascript src/index.ts
- name: Build docs - name: Build docs
working-directory: docs
run: | run: |
PYTHONPATH=. mkdocs build -f docs/mkdocs.yml PYTHONPATH=. mkdocs build
- name: Setup Pages - name: Setup Pages
uses: actions/configure-pages@v2 uses: actions/configure-pages@v2
- name: Upload artifact - name: Upload artifact
@@ -72,4 +73,4 @@ jobs:
path: "docs/site" path: "docs/site"
- name: Deploy to GitHub Pages - name: Deploy to GitHub Pages
id: deployment id: deployment
uses: actions/deploy-pages@v1 uses: actions/deploy-pages@v1

View File

@@ -18,24 +18,20 @@ on:
env: env:
# Disable full debug symbol generation to speed up CI build and keep memory down # Disable full debug symbol generation to speed up CI build and keep memory down
# "1" means line tables only, which is useful for panic tracebacks. # "1" means line tables only, which is useful for panic tracebacks.
RUSTFLAGS: "-C debuginfo=1" RUSTFLAGS: "-C debuginfo=1 -C target-cpu=native -C target-feature=+f16c,+avx2,+fma"
RUST_BACKTRACE: "1" RUST_BACKTRACE: "1"
jobs: jobs:
test-python: test-python:
name: Test doc python code name: Test doc python code
runs-on: ${{ matrix.os }} runs-on: "ubuntu-latest"
strategy:
matrix:
python-minor-version: [ "11" ]
os: ["ubuntu-22.04"]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: 3.${{ matrix.python-minor-version }} python-version: 3.11
cache: "pip" cache: "pip"
cache-dependency-path: "docs/test/requirements.txt" cache-dependency-path: "docs/test/requirements.txt"
- name: Build Python - name: Build Python
@@ -52,45 +48,33 @@ jobs:
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
test-node: test-node:
name: Test doc nodejs code name: Test doc nodejs code
runs-on: ${{ matrix.os }} runs-on: "ubuntu-latest"
strategy:
matrix:
node-version: [ "18" ]
os: ["ubuntu-22.04"]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
- name: Set up Node - name: Set up Node
uses: actions/setup-node@v3 uses: actions/setup-node@v4
with: with:
node-version: ${{ matrix.node-version }} node-version: 20
- name: Install dependecies needed for ubuntu - name: Install dependecies needed for ubuntu
if: ${{ matrix.os == 'ubuntu-22.04' }}
run: | run: |
sudo apt install -y protobuf-compiler libssl-dev sudo apt install -y protobuf-compiler libssl-dev
- name: Install node dependencies
run: |
cd docs/test
npm install
- name: Rust cache - name: Rust cache
uses: swatinem/rust-cache@v2 uses: swatinem/rust-cache@v2
- name: Install LanceDB - name: Install node dependencies
run: | run: |
cd docs/test/node_modules/vectordb cd node
npm ci npm ci
npm run build-release npm run build-release
npm run tsc cd ../docs
- name: Create test files npm install
run: |
cd docs/test
node md_testing.js
- name: Test - name: Test
env: env:
LANCEDB_URI: ${{ secrets.LANCEDB_URI }} LANCEDB_URI: ${{ secrets.LANCEDB_URI }}
LANCEDB_DEV_API_KEY: ${{ secrets.LANCEDB_DEV_API_KEY }} LANCEDB_DEV_API_KEY: ${{ secrets.LANCEDB_DEV_API_KEY }}
run: | run: |
cd docs/test/node cd docs
for d in *; do cd "$d"; echo "$d".js; node "$d".js; cd ..; done npm t

View File

@@ -26,7 +26,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Check out main - name: Check out main
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
ref: main ref: main
persist-credentials: false persist-credentials: false
@@ -37,10 +37,10 @@ jobs:
run: | run: |
git config user.name 'Lance Release' git config user.name 'Lance Release'
git config user.email 'lance-dev@lancedb.com' git config user.email 'lance-dev@lancedb.com'
- name: Set up Python 3.10 - name: Set up Python 3.11
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: "3.10" python-version: "3.11"
- name: Bump version, create tag and commit - name: Bump version, create tag and commit
run: | run: |
pip install bump2version pip install bump2version

View File

@@ -32,7 +32,7 @@ jobs:
shell: bash shell: bash
working-directory: node working-directory: node
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
@@ -57,7 +57,7 @@ jobs:
shell: bash shell: bash
working-directory: node working-directory: node
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
@@ -89,7 +89,7 @@ jobs:
shell: bash shell: bash
working-directory: node working-directory: node
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
@@ -128,7 +128,7 @@ jobs:
# this one is for dynamodb # this one is for dynamodb
DYNAMODB_ENDPOINT: http://localhost:4566 DYNAMODB_ENDPOINT: http://localhost:4566
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true

View File

@@ -29,7 +29,7 @@ jobs:
shell: bash shell: bash
working-directory: nodejs working-directory: nodejs
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
@@ -61,7 +61,7 @@ jobs:
shell: bash shell: bash
working-directory: nodejs working-directory: nodejs
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
@@ -84,13 +84,13 @@ jobs:
run: npm run test run: npm run test
macos: macos:
timeout-minutes: 30 timeout-minutes: 30
runs-on: "macos-13" runs-on: "macos-14"
defaults: defaults:
run: run:
shell: bash shell: bash
working-directory: nodejs working-directory: nodejs
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true

View File

@@ -15,7 +15,7 @@ jobs:
working-directory: node working-directory: node
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3
with: with:
node-version: 20 node-version: 20
@@ -45,13 +45,13 @@ jobs:
runner: macos-13 runner: macos-13
- arch: aarch64-apple-darwin - arch: aarch64-apple-darwin
# xlarge is implicitly arm64. # xlarge is implicitly arm64.
runner: macos-13-xlarge runner: macos-14
runs-on: ${{ matrix.config.runner }} runs-on: ${{ matrix.config.runner }}
# Only runs on tags that matches the make-release action # Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v') if: startsWith(github.ref, 'refs/tags/v')
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install system dependencies - name: Install system dependencies
run: brew install protobuf run: brew install protobuf
- name: Install npm dependencies - name: Install npm dependencies
@@ -66,7 +66,7 @@ jobs:
name: native-darwin name: native-darwin
path: | path: |
node/dist/lancedb-vectordb-darwin*.tgz node/dist/lancedb-vectordb-darwin*.tgz
node-linux: node-linux:
name: node-linux (${{ matrix.config.arch}}-unknown-linux-gnu name: node-linux (${{ matrix.config.arch}}-unknown-linux-gnu
@@ -83,7 +83,7 @@ jobs:
runner: buildjet-4vcpu-ubuntu-2204-arm runner: buildjet-4vcpu-ubuntu-2204-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Build Linux Artifacts - name: Build Linux Artifacts
run: | run: |
bash ci/build_linux_artifacts.sh ${{ matrix.config.arch }} bash ci/build_linux_artifacts.sh ${{ matrix.config.arch }}
@@ -104,7 +104,7 @@ jobs:
target: [x86_64-pc-windows-msvc] target: [x86_64-pc-windows-msvc]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Install Protoc v21.12 - name: Install Protoc v21.12
working-directory: C:\ working-directory: C:\
run: | run: |
@@ -154,7 +154,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
ref: main ref: main
persist-credentials: false persist-credentials: false

View File

@@ -14,9 +14,9 @@ jobs:
shell: bash shell: bash
working-directory: python working-directory: python
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: "3.8" python-version: "3.8"
- name: Build distribution - name: Build distribution

View File

@@ -26,7 +26,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Check out main - name: Check out main
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
ref: main ref: main
persist-credentials: false persist-credentials: false
@@ -37,10 +37,10 @@ jobs:
run: | run: |
git config user.name 'Lance Release' git config user.name 'Lance Release'
git config user.email 'lance-dev@lancedb.com' git config user.email 'lance-dev@lancedb.com'
- name: Set up Python 3.10 - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: "3.10" python-version: "3.11"
- name: Bump version, create tag and commit - name: Bump version, create tag and commit
working-directory: python working-directory: python
run: | run: |

View File

@@ -18,19 +18,19 @@ jobs:
timeout-minutes: 30 timeout-minutes: 30
strategy: strategy:
matrix: matrix:
python-minor-version: [ "8", "9", "10", "11" ] python-minor-version: [ "8", "11" ]
runs-on: "ubuntu-22.04" runs-on: "ubuntu-22.04"
defaults: defaults:
run: run:
shell: bash shell: bash
working-directory: python working-directory: python
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: 3.${{ matrix.python-minor-version }} python-version: 3.${{ matrix.python-minor-version }}
- name: Install lancedb - name: Install lancedb
@@ -55,7 +55,7 @@ jobs:
- name: x86 Mac - name: x86 Mac
runner: macos-13 runner: macos-13
- name: Arm Mac - name: Arm Mac
runner: macos-13-xlarge runner: macos-14
- name: x86 Windows - name: x86 Windows
runner: windows-latest runner: windows-latest
runs-on: "${{ matrix.config.runner }}" runs-on: "${{ matrix.config.runner }}"
@@ -64,12 +64,12 @@ jobs:
shell: bash shell: bash
working-directory: python working-directory: python
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: "3.11" python-version: "3.11"
- name: Install lancedb - name: Install lancedb
@@ -87,12 +87,12 @@ jobs:
shell: bash shell: bash
working-directory: python working-directory: python
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 uses: actions/setup-python@v5
with: with:
python-version: 3.9 python-version: 3.9
- name: Install lancedb - name: Install lancedb

View File

@@ -32,7 +32,7 @@ jobs:
shell: bash shell: bash
working-directory: rust working-directory: rust
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
@@ -55,7 +55,7 @@ jobs:
shell: bash shell: bash
working-directory: rust working-directory: rust
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
@@ -70,18 +70,20 @@ jobs:
run: cargo build --all-features run: cargo build --all-features
- name: Run tests - name: Run tests
run: cargo test --all-features run: cargo test --all-features
- name: Run examples
run: cargo run --example simple
macos: macos:
timeout-minutes: 30 timeout-minutes: 30
strategy: strategy:
matrix: matrix:
mac-runner: [ "macos-13", "macos-13-xlarge" ] mac-runner: [ "macos-13", "macos-14" ]
runs-on: "${{ matrix.mac-runner }}" runs-on: "${{ matrix.mac-runner }}"
defaults: defaults:
run: run:
shell: bash shell: bash
working-directory: rust working-directory: rust
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
lfs: true lfs: true
@@ -99,7 +101,7 @@ jobs:
windows: windows:
runs-on: windows-2022 runs-on: windows-2022
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
with: with:
workspaces: rust workspaces: rust

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v4
with: with:
ref: main ref: main
persist-credentials: false persist-credentials: false

View File

@@ -11,19 +11,19 @@ license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=0.9.9", "features" = ["dynamodb"] } lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.9" } lance-index = { "version" = "=0.9.12" }
lance-linalg = { "version" = "=0.9.9" } lance-linalg = { "version" = "=0.9.12" }
lance-testing = { "version" = "=0.9.9" } lance-testing = { "version" = "=0.9.12" }
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "49.0.0", optional = false } arrow = { version = "50.0", optional = false }
arrow-array = "49.0" arrow-array = "50.0"
arrow-data = "49.0" arrow-data = "50.0"
arrow-ipc = "49.0" arrow-ipc = "50.0"
arrow-ord = "49.0" arrow-ord = "50.0"
arrow-schema = "49.0" arrow-schema = "50.0"
arrow-arith = "49.0" arrow-arith = "50.0"
arrow-cast = "49.0" arrow-cast = "50.0"
async-trait = "0" async-trait = "0"
chrono = "0.4.23" chrono = "0.4.23"
half = { "version" = "=2.3.1", default-features = false, features = [ half = { "version" = "=2.3.1", default-features = false, features = [

View File

@@ -51,12 +51,19 @@ npm install vectordb
const lancedb = require('vectordb'); const lancedb = require('vectordb');
const db = await lancedb.connect('data/sample-lancedb'); const db = await lancedb.connect('data/sample-lancedb');
const table = await db.createTable('vectors', const table = await db.createTable({
[{ id: 1, vector: [0.1, 0.2], item: "foo", price: 10 }, name: 'vectors',
{ id: 2, vector: [1.1, 1.2], item: "bar", price: 50 }]) data: [
{ id: 1, vector: [0.1, 0.2], item: "foo", price: 10 },
{ id: 2, vector: [1.1, 1.2], item: "bar", price: 50 }
]
})
const query = table.search([0.1, 0.3]).limit(2); const query = table.search([0.1, 0.3]).limit(2);
const results = await query.execute(); const results = await query.execute();
// You can also search for rows by specific criteria without involving a vector search.
const rowsByCriteria = await table.search(undefined).where("price >= 10").execute();
``` ```
**Python** **Python**

View File

@@ -33,3 +33,12 @@ You can run a local server to test the docs prior to deployment by navigating to
cd docs cd docs
mkdocs serve mkdocs serve
``` ```
### Run doctest for typescript example
```bash
cd lancedb/docs
npm i
npm run build
npm run all
```

View File

@@ -67,7 +67,9 @@ markdown_extensions:
line_spans: __span line_spans: __span
pygments_lang_class: true pygments_lang_class: true
- pymdownx.inlinehilite - pymdownx.inlinehilite
- pymdownx.snippets - pymdownx.snippets:
base_path: ..
dedent_subsections: true
- pymdownx.superfences - pymdownx.superfences
- pymdownx.tabbed: - pymdownx.tabbed:
alternate_style: true alternate_style: true
@@ -88,6 +90,7 @@ nav:
- Building an ANN index: ann_indexes.md - Building an ANN index: ann_indexes.md
- Vector Search: search.md - Vector Search: search.md
- Full-text search: fts.md - Full-text search: fts.md
- Hybrid search: hybrid_search.md
- Filtering: sql.md - Filtering: sql.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb - Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Configuring Storage: guides/storage.md - Configuring Storage: guides/storage.md
@@ -130,6 +133,7 @@ nav:
- ⚙️ API reference: - ⚙️ API reference:
- 🐍 Python: python/python.md - 🐍 Python: python/python.md
- 👾 JavaScript: javascript/modules.md - 👾 JavaScript: javascript/modules.md
- 🦀 Rust: https://docs.rs/vectordb/latest/vectordb/
- ☁️ LanceDB Cloud: - ☁️ LanceDB Cloud:
- Overview: cloud/index.md - Overview: cloud/index.md
- API reference: - API reference:
@@ -148,6 +152,7 @@ nav:
- Building an ANN index: ann_indexes.md - Building an ANN index: ann_indexes.md
- Vector Search: search.md - Vector Search: search.md
- Full-text search: fts.md - Full-text search: fts.md
- Hybrid search: hybrid_search.md
- Filtering: sql.md - Filtering: sql.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb - Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Configuring Storage: guides/storage.md - Configuring Storage: guides/storage.md
@@ -195,6 +200,9 @@ extra_css:
- styles/global.css - styles/global.css
- styles/extra.css - styles/extra.css
extra_javascript:
- "extra_js/init_ask_ai_widget.js"
extra: extra:
analytics: analytics:
provider: google provider: google

132
docs/package-lock.json generated Normal file
View File

@@ -0,0 +1,132 @@
{
"name": "lancedb-docs-test",
"version": "1.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "lancedb-docs-test",
"version": "1.0.0",
"license": "Apache 2",
"dependencies": {
"apache-arrow": "file:../node/node_modules/apache-arrow",
"vectordb": "file:../node"
},
"devDependencies": {
"@types/node": "^20.11.8",
"typescript": "^5.3.3"
}
},
"../node": {
"name": "vectordb",
"version": "0.4.6",
"cpu": [
"x64",
"arm64"
],
"license": "Apache-2.0",
"os": [
"darwin",
"linux",
"win32"
],
"dependencies": {
"@apache-arrow/ts": "^14.0.2",
"@neon-rs/load": "^0.0.74",
"apache-arrow": "^14.0.2",
"axios": "^1.4.0"
},
"devDependencies": {
"@neon-rs/cli": "^0.0.160",
"@types/chai": "^4.3.4",
"@types/chai-as-promised": "^7.1.5",
"@types/mocha": "^10.0.1",
"@types/node": "^18.16.2",
"@types/sinon": "^10.0.15",
"@types/temp": "^0.9.1",
"@types/uuid": "^9.0.3",
"@typescript-eslint/eslint-plugin": "^5.59.1",
"cargo-cp-artifact": "^0.1",
"chai": "^4.3.7",
"chai-as-promised": "^7.1.1",
"eslint": "^8.39.0",
"eslint-config-standard-with-typescript": "^34.0.1",
"eslint-plugin-import": "^2.26.0",
"eslint-plugin-n": "^15.7.0",
"eslint-plugin-promise": "^6.1.1",
"mocha": "^10.2.0",
"openai": "^4.24.1",
"sinon": "^15.1.0",
"temp": "^0.9.4",
"ts-node": "^10.9.1",
"ts-node-dev": "^2.0.0",
"typedoc": "^0.24.7",
"typedoc-plugin-markdown": "^3.15.3",
"typescript": "*",
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.6",
"@lancedb/vectordb-darwin-x64": "0.4.6",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.6",
"@lancedb/vectordb-linux-x64-gnu": "0.4.6",
"@lancedb/vectordb-win32-x64-msvc": "0.4.6"
}
},
"../node/node_modules/apache-arrow": {
"version": "14.0.2",
"license": "Apache-2.0",
"dependencies": {
"@types/command-line-args": "5.2.0",
"@types/command-line-usage": "5.0.2",
"@types/node": "20.3.0",
"@types/pad-left": "2.1.1",
"command-line-args": "5.2.1",
"command-line-usage": "7.0.1",
"flatbuffers": "23.5.26",
"json-bignum": "^0.0.3",
"pad-left": "^2.1.0",
"tslib": "^2.5.3"
},
"bin": {
"arrow2csv": "bin/arrow2csv.js"
}
},
"node_modules/@types/node": {
"version": "20.11.8",
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.11.8.tgz",
"integrity": "sha512-i7omyekpPTNdv4Jb/Rgqg0RU8YqLcNsI12quKSDkRXNfx7Wxdm6HhK1awT3xTgEkgxPn3bvnSpiEAc7a7Lpyow==",
"dev": true,
"dependencies": {
"undici-types": "~5.26.4"
}
},
"node_modules/apache-arrow": {
"resolved": "../node/node_modules/apache-arrow",
"link": true
},
"node_modules/typescript": {
"version": "5.3.3",
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.3.3.tgz",
"integrity": "sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==",
"dev": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
},
"engines": {
"node": ">=14.17"
}
},
"node_modules/undici-types": {
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
"dev": true
},
"node_modules/vectordb": {
"resolved": "../node",
"link": true
}
}
}

20
docs/package.json Normal file
View File

@@ -0,0 +1,20 @@
{
"name": "lancedb-docs-test",
"version": "1.0.0",
"description": "auto-generated tests from doc",
"author": "dev@lancedb.com",
"license": "Apache 2",
"dependencies": {
"apache-arrow": "file:../node/node_modules/apache-arrow",
"vectordb": "file:../node"
},
"scripts": {
"build": "tsc -b && cd ../node && npm run build-release",
"example": "npm run build && node",
"test": "npm run build && ls dist/*.js | xargs -n 1 node"
},
"devDependencies": {
"@types/node": "^20.11.8",
"typescript": "^5.3.3"
}
}

View File

@@ -7,7 +7,7 @@ for brute-force scanning of the entire vector space.
A vector index is faster but less accurate than exhaustive search (kNN or flat search). A vector index is faster but less accurate than exhaustive search (kNN or flat search).
LanceDB provides many parameters to fine-tune the index's size, the speed of queries, and the accuracy of results. LanceDB provides many parameters to fine-tune the index's size, the speed of queries, and the accuracy of results.
Currently, LanceDB does *not* automatically create the ANN index. Currently, LanceDB does _not_ automatically create the ANN index.
LanceDB has optimized code for kNN as well. For many use-cases, datasets under 100K vectors won't require index creation at all. LanceDB has optimized code for kNN as well. For many use-cases, datasets under 100K vectors won't require index creation at all.
If you can live with <100ms latency, skipping index creation is a simpler workflow while guaranteeing 100% recall. If you can live with <100ms latency, skipping index creation is a simpler workflow while guaranteeing 100% recall.
@@ -17,16 +17,17 @@ In the future we will look to automatically create and configure the ANN index a
Lance can support multiple index types, the most widely used one is `IVF_PQ`. Lance can support multiple index types, the most widely used one is `IVF_PQ`.
* `IVF_PQ`: use **Inverted File Index (IVF)** to first divide the dataset into `N` partitions, - `IVF_PQ`: use **Inverted File Index (IVF)** to first divide the dataset into `N` partitions,
and then use **Product Quantization** to compress vectors in each partition. and then use **Product Quantization** to compress vectors in each partition.
* `DiskANN` (**Experimental**): organize the vector as a on-disk graph, where the vertices approximately - `DiskANN` (**Experimental**): organize the vector as a on-disk graph, where the vertices approximately
represent the nearest neighbors of each vector. represent the nearest neighbors of each vector.
## Creating an IVF_PQ Index ## Creating an IVF_PQ Index
Lance supports `IVF_PQ` index type by default. Lance supports `IVF_PQ` index type by default.
=== "Python" === "Python"
Creating indexes is done via the [create_index](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.create_index) method. Creating indexes is done via the [create_index](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.create_index) method.
```python ```python
@@ -46,25 +47,20 @@ Lance supports `IVF_PQ` index type by default.
tbl.create_index(num_partitions=256, num_sub_vectors=96) tbl.create_index(num_partitions=256, num_sub_vectors=96)
``` ```
=== "Javascript" === "Typescript"
```javascript
const vectordb = require('vectordb')
const db = await vectordb.connect('data/sample-lancedb')
let data = [] ```typescript
for (let i = 0; i < 10_000; i++) { --8<--- "docs/src/ann_indexes.ts:import"
data.push({vector: Array(1536).fill(i), id: `${i}`, content: "", longId: `${i}`},)
} --8<-- "docs/src/ann_indexes.ts:ingest"
const table = await db.createTable('my_vectors', data)
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 256, num_sub_vectors: 96 })
``` ```
- **metric** (default: "L2"): The distance metric to use. By default it uses euclidean distance "`L2`". - **metric** (default: "L2"): The distance metric to use. By default it uses euclidean distance "`L2`".
We also support "cosine" and "dot" distance as well. We also support "cosine" and "dot" distance as well.
- **num_partitions** (default: 256): The number of partitions of the index. - **num_partitions** (default: 256): The number of partitions of the index.
- **num_sub_vectors** (default: 96): The number of sub-vectors (M) that will be created during Product Quantization (PQ). - **num_sub_vectors** (default: 96): The number of sub-vectors (M) that will be created during Product Quantization (PQ).
For D dimensional vector, it will be divided into `M` of `D/M` sub-vectors, each of which is presented by For D dimensional vector, it will be divided into `M` of `D/M` sub-vectors, each of which is presented by
a single PQ code. a single PQ code.
<figure markdown> <figure markdown>
![IVF PQ](./assets/ivf_pq.png) ![IVF PQ](./assets/ivf_pq.png)
@@ -78,7 +74,7 @@ Using GPU for index creation requires [PyTorch>2.0](https://pytorch.org/) being
You can specify the GPU device to train IVF partitions via You can specify the GPU device to train IVF partitions via
- **accelerator**: Specify to ``cuda`` or ``mps`` (on Apple Silicon) to enable GPU training. - **accelerator**: Specify to `cuda` or `mps` (on Apple Silicon) to enable GPU training.
=== "Linux" === "Linux"
@@ -106,10 +102,9 @@ You can specify the GPU device to train IVF partitions via
Trouble shootings: Trouble shootings:
If you see ``AssertionError: Torch not compiled with CUDA enabled``, you need to [install If you see `AssertionError: Torch not compiled with CUDA enabled`, you need to [install
PyTorch with CUDA support](https://pytorch.org/get-started/locally/). PyTorch with CUDA support](https://pytorch.org/get-started/locally/).
## Querying an ANN Index ## Querying an ANN Index
Querying vector indexes is done via the [search](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.search) function. Querying vector indexes is done via the [search](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.search) function.
@@ -127,6 +122,7 @@ There are a couple of parameters that can be used to fine-tune the search:
Note: refine_factor is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored. Note: refine_factor is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored.
=== "Python" === "Python"
```python ```python
tbl.search(np.random.random((1536))) \ tbl.search(np.random.random((1536))) \
.limit(2) \ .limit(2) \
@@ -134,41 +130,35 @@ There are a couple of parameters that can be used to fine-tune the search:
.refine_factor(10) \ .refine_factor(10) \
.to_pandas() .to_pandas()
``` ```
```
```text
vector item _distance vector item _distance
0 [0.44949695, 0.8444449, 0.06281311, 0.23338133... item 1141 103.575333 0 [0.44949695, 0.8444449, 0.06281311, 0.23338133... item 1141 103.575333
1 [0.48587373, 0.269207, 0.15095535, 0.65531915,... item 3953 108.393867 1 [0.48587373, 0.269207, 0.15095535, 0.65531915,... item 3953 108.393867
``` ```
=== "Javascript" === "Typescript"
```javascript
const results_1 = await table ```typescript
.search(Array(1536).fill(1.2)) --8<-- "docs/src/ann_indexes.ts:search1"
.limit(2)
.nprobes(20)
.refineFactor(10)
.execute()
``` ```
The search will return the data requested in addition to the distance of each item. The search will return the data requested in addition to the distance of each item.
### Filtering (where clause) ### Filtering (where clause)
You can further filter the elements returned by a search using a where clause. You can further filter the elements returned by a search using a where clause.
=== "Python" === "Python"
```python ```python
tbl.search(np.random.random((1536))).where("item != 'item 1141'").to_pandas() tbl.search(np.random.random((1536))).where("item != 'item 1141'").to_pandas()
``` ```
=== "Javascript" === "Typescript"
```javascript ```javascript
const results_2 = await table --8<-- "docs/src/ann_indexes.ts:search2"
.search(Array(1536).fill(1.2))
.where("id != '1141'")
.limit(2)
.execute()
``` ```
### Projections (select clause) ### Projections (select clause)
@@ -176,23 +166,23 @@ You can further filter the elements returned by a search using a where clause.
You can select the columns returned by the query using a select clause. You can select the columns returned by the query using a select clause.
=== "Python" === "Python"
```python ```python
tbl.search(np.random.random((1536))).select(["vector"]).to_pandas() tbl.search(np.random.random((1536))).select(["vector"]).to_pandas()
``` ```
```
vector _distance
```text
vector _distance
0 [0.30928212, 0.022668175, 0.1756372, 0.4911822... 93.971092 0 [0.30928212, 0.022668175, 0.1756372, 0.4911822... 93.971092
1 [0.2525465, 0.01723831, 0.261568, 0.002007689,... 95.173485 1 [0.2525465, 0.01723831, 0.261568, 0.002007689,... 95.173485
... ...
``` ```
=== "Javascript" === "Typescript"
```javascript
const results_3 = await table ```typescript
.search(Array(1536).fill(1.2)) --8<-- "docs/src/ann_indexes.ts:search3"
.select(["id"])
.limit(2)
.execute()
``` ```
## FAQ ## FAQ
@@ -221,4 +211,4 @@ On `SIFT-1M` dataset, our benchmark shows that keeping each partition 1K-4K rows
`num_sub_vectors` specifies how many Product Quantization (PQ) short codes to generate on each vector. Because `num_sub_vectors` specifies how many Product Quantization (PQ) short codes to generate on each vector. Because
PQ is a lossy compression of the original vector, a higher `num_sub_vectors` usually results in PQ is a lossy compression of the original vector, a higher `num_sub_vectors` usually results in
less space distortion, and thus yields better accuracy. However, a higher `num_sub_vectors` also causes heavier I/O and less space distortion, and thus yields better accuracy. However, a higher `num_sub_vectors` also causes heavier I/O and
more PQ computation, and thus, higher latency. `dimension / num_sub_vectors` should be a multiple of 8 for optimum SIMD efficiency. more PQ computation, and thus, higher latency. `dimension / num_sub_vectors` should be a multiple of 8 for optimum SIMD efficiency.

53
docs/src/ann_indexes.ts Normal file
View File

@@ -0,0 +1,53 @@
// --8<-- [start:import]
import * as vectordb from "vectordb";
// --8<-- [end:import]
(async () => {
// --8<-- [start:ingest]
const db = await vectordb.connect("data/sample-lancedb");
let data = [];
for (let i = 0; i < 10_000; i++) {
data.push({
vector: Array(1536).fill(i),
id: `${i}`,
content: "",
longId: `${i}`,
});
}
const table = await db.createTable("my_vectors", data);
await table.createIndex({
type: "ivf_pq",
column: "vector",
num_partitions: 16,
num_sub_vectors: 48,
});
// --8<-- [end:ingest]
// --8<-- [start:search1]
const results_1 = await table
.search(Array(1536).fill(1.2))
.limit(2)
.nprobes(20)
.refineFactor(10)
.execute();
// --8<-- [end:search1]
// --8<-- [start:search2]
const results_2 = await table
.search(Array(1536).fill(1.2))
.where("id != '1141'")
.limit(2)
.execute();
// --8<-- [end:search2]
// --8<-- [start:search3]
const results_3 = await table
.search(Array(1536).fill(1.2))
.select(["id"])
.limit(2)
.execute();
// --8<-- [end:search3]
console.log("Ann indexes: done");
})();

Binary file not shown.

Before

Width:  |  Height:  |  Size: 266 KiB

After

Width:  |  Height:  |  Size: 107 KiB

View File

@@ -11,43 +11,78 @@
## Installation ## Installation
=== "Python" === "Python"
```shell ```shell
pip install lancedb pip install lancedb
``` ```
=== "Javascript" === "Typescript"
```shell ```shell
npm install vectordb npm install vectordb
``` ```
=== "Rust"
!!! warning "Rust SDK is experimental, might introduce breaking changes in the near future"
```shell
cargo add vectordb
```
!!! info "To use the vectordb create, you first need to install protobuf."
=== "macOS"
```shell
brew install protobuf
```
=== "Ubuntu/Debian"
```shell
sudo apt install -y protobuf-compiler libssl-dev
```
!!! info "Please also make sure you're using the same version of Arrow as in the [vectordb crate](https://github.com/lancedb/lancedb/blob/main/Cargo.toml)"
## How to connect to a database ## How to connect to a database
=== "Python" === "Python"
```python ```python
import lancedb import lancedb
uri = "data/sample-lancedb" uri = "data/sample-lancedb"
db = lancedb.connect(uri) db = lancedb.connect(uri)
``` ```
LanceDB will create the directory if it doesn't exist (including parent directories). === "Typescript"
If you need a reminder of the uri, use the `db.uri` property. ```typescript
--8<-- "docs/src/basic_legacy.ts:import"
=== "Javascript" --8<-- "docs/src/basic_legacy.ts:open_db"
```javascript ```
const lancedb = require("vectordb");
const uri = "data/sample-lancedb"; === "Rust"
const db = await lancedb.connect(uri);
```
LanceDB will create the directory if it doesn't exist (including parent directories).
If you need a reminder of the uri, you can call `db.uri()`. ```rust
#[tokio::main]
async fn main() -> Result<()> {
--8<-- "rust/vectordb/examples/simple.rs:connect"
}
```
!!! info "See [examples/simple.rs](https://github.com/lancedb/lancedb/tree/main/rust/vectordb/examples/simple.rs) for a full working example."
LanceDB will create the directory if it doesn't exist (including parent directories).
If you need a reminder of the uri, you can call `db.uri()`.
## How to create a table ## How to create a table
=== "Python" === "Python"
```python ```python
tbl = db.create_table("my_table", tbl = db.create_table("my_table",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
@@ -59,6 +94,7 @@
to the `create_table` method. to the `create_table` method.
You can also pass in a pandas DataFrame directly: You can also pass in a pandas DataFrame directly:
```python ```python
import pandas as pd import pandas as pd
df = pd.DataFrame([{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, df = pd.DataFrame([{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
@@ -66,19 +102,26 @@
tbl = db.create_table("table_from_df", data=df) tbl = db.create_table("table_from_df", data=df)
``` ```
=== "Javascript" === "Typescript"
```javascript
const tb = await db.createTable( ```typescript
"myTable", --8<-- "docs/src/basic_legacy.ts:create_table"
[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}]
)
``` ```
If the table already exists, LanceDB will raise an error by default. If the table already exists, LanceDB will raise an error by default.
If you want to overwrite the table, you can pass in `mode="overwrite"` If you want to overwrite the table, you can pass in `mode="overwrite"`
to the `createTable` function. to the `createTable` function.
=== "Rust"
```rust
use arrow_schema::{DataType, Schema, Field};
use arrow_array::{RecordBatch, RecordBatchIterator};
--8<-- "rust/vectordb/examples/simple.rs:create_table"
```
If the table already exists, LanceDB will raise an error by default.
!!! info "Under the hood, LanceDB is converting the input data into an Apache Arrow table and persisting it to disk in [Lance format](https://www.github.com/lancedb/lance)." !!! info "Under the hood, LanceDB is converting the input data into an Apache Arrow table and persisting it to disk in [Lance format](https://www.github.com/lancedb/lance)."
@@ -88,76 +131,145 @@ Sometimes you may not have the data to insert into the table at creation time.
In this case, you can create an empty table and specify the schema. In this case, you can create an empty table and specify the schema.
=== "Python" === "Python"
```python ```python
import pyarrow as pa import pyarrow as pa
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))]) schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
tbl = db.create_table("empty_table", schema=schema) tbl = db.create_table("empty_table", schema=schema)
``` ```
=== "Typescript"
```typescript
--8<-- "docs/src/basic_legacy.ts:create_empty_table"
```
=== "Rust"
```rust
--8<-- "rust/vectordb/examples/simple.rs:create_empty_table"
```
## How to open an existing table ## How to open an existing table
Once created, you can open a table using the following code: Once created, you can open a table using the following code:
=== "Python" === "Python"
```python
tbl = db.open_table("my_table")
```
If you forget the name of your table, you can always get a listing of all table names: ```python
tbl = db.open_table("my_table")
```
```python === "Typescript"
print(db.table_names())
``` ```typescript
const tbl = await db.openTable("myTable");
```
=== "Rust"
```rust
--8<-- "rust/vectordb/examples/simple.rs:open_with_existing_file"
```
If you forget the name of your table, you can always get a listing of all table names:
=== "Python"
```python
print(db.table_names())
```
=== "Javascript" === "Javascript"
```javascript
const tbl = await db.openTable("myTable");
```
If you forget the name of your table, you can always get a listing of all table names: ```javascript
console.log(await db.tableNames());
```
```javascript === "Rust"
console.log(await db.tableNames());
``` ```rust
--8<-- "rust/vectordb/examples/simple.rs:list_names"
```
## How to add data to a table ## How to add data to a table
After a table has been created, you can always add more data to it using After a table has been created, you can always add more data to it using
=== "Python" === "Python"
```python
# Option 1: Add a list of dicts to a table ```python
data = [{"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
{"vector": [9.5, 56.2], "item": "buzz", "price": 200.0}]
tbl.add(data)
# Option 2: Add a pandas DataFrame to a table # Option 1: Add a list of dicts to a table
df = pd.DataFrame(data) data = [{"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
tbl.add(data) {"vector": [9.5, 56.2], "item": "buzz", "price": 200.0}]
``` tbl.add(data)
=== "Javascript" # Option 2: Add a pandas DataFrame to a table
```javascript df = pd.DataFrame(data)
await tbl.add([{vector: [1.3, 1.4], item: "fizz", price: 100.0}, tbl.add(data)
{vector: [9.5, 56.2], item: "buzz", price: 200.0}]) ```
```
=== "Typescript"
```typescript
--8<-- "docs/src/basic_legacy.ts:add"
```
=== "Rust"
```rust
--8<-- "rust/vectordb/examples/simple.rs:add"
```
## How to search for (approximate) nearest neighbors ## How to search for (approximate) nearest neighbors
Once you've embedded the query, you can find its nearest neighbors using the following code: Once you've embedded the query, you can find its nearest neighbors using the following code:
=== "Python" === "Python"
```python
tbl.search([100, 100]).limit(2).to_pandas()
```
This returns a pandas DataFrame with the results. ```python
tbl.search([100, 100]).limit(2).to_pandas()
```
=== "Javascript" This returns a pandas DataFrame with the results.
```javascript
const query = await tbl.search([100, 100]).limit(2).execute(); === "Typescript"
```
```typescript
--8<-- "docs/src/basic_legacy.ts:search"
```
=== "Rust"
```rust
use futures::TryStreamExt;
--8<-- "rust/vectordb/examples/simple.rs:search"
```
By default, LanceDB runs a brute-force scan over dataset to find the K nearest neighbours (KNN).
For tables with more than 50K vectors, creating an ANN index is recommended to speed up search performance.
=== "Python"
```py
tbl.create_index()
```
=== "Typescript"
```{.typescript .ignore}
--8<-- "docs/src/basic_legacy.ts:create_index"
```
=== "Rust"
```rust
--8<-- "rust/vectordb/examples/simple.rs:create_index"
```
Check [Approximate Nearest Neighbor (ANN) Indexes](/ann_indices.md) section for more details.
## How to delete rows from a table ## How to delete rows from a table
@@ -166,20 +278,27 @@ which rows to delete, provide a filter that matches on the metadata columns.
This can delete any number of rows that match the filter. This can delete any number of rows that match the filter.
=== "Python" === "Python"
```python
tbl.delete('item = "fizz"')
```
=== "Javascript" ```python
```javascript tbl.delete('item = "fizz"')
await tbl.delete('item = "fizz"') ```
```
=== "Typescript"
```typescript
--8<-- "docs/src/basic_legacy.ts:delete"
```
=== "Rust"
```rust
--8<-- "rust/vectordb/examples/simple.rs:delete"
```
The deletion predicate is a SQL expression that supports the same expressions The deletion predicate is a SQL expression that supports the same expressions
as the `where()` clause on a search. They can be as simple or complex as needed. as the `where()` clause on a search. They can be as simple or complex as needed.
To see what expressions are supported, see the [SQL filters](sql.md) section. To see what expressions are supported, see the [SQL filters](sql.md) section.
=== "Python" === "Python"
Read more: [lancedb.table.Table.delete][] Read more: [lancedb.table.Table.delete][]
@@ -193,6 +312,7 @@ To see what expressions are supported, see the [SQL filters](sql.md) section.
Use the `drop_table()` method on the database to remove a table. Use the `drop_table()` method on the database to remove a table.
=== "Python" === "Python"
```python ```python
db.drop_table("my_table") db.drop_table("my_table")
``` ```
@@ -201,13 +321,20 @@ Use the `drop_table()` method on the database to remove a table.
By default, if the table does not exist an exception is raised. To suppress this, By default, if the table does not exist an exception is raised. To suppress this,
you can pass in `ignore_missing=True`. you can pass in `ignore_missing=True`.
=== "JavaScript" === "Typescript"
```javascript
await db.dropTable('myTable') ```typescript
--8<-- "docs/src/basic_legacy.ts:drop_table"
``` ```
This permanently removes the table and is not recoverable, unlike deleting rows. This permanently removes the table and is not recoverable, unlike deleting rows.
If the table does not exist an exception is raised. If the table does not exist an exception is raised.
=== "Rust"
```rust
--8<-- "rust/vectordb/examples/simple.rs:drop_table"
```
!!! note "Bundling `vectordb` apps with Webpack" !!! note "Bundling `vectordb` apps with Webpack"

92
docs/src/basic_legacy.ts Normal file
View File

@@ -0,0 +1,92 @@
// --8<-- [start:import]
import * as lancedb from "vectordb";
import { Schema, Field, Float32, FixedSizeList, Int32, Float16 } from "apache-arrow";
// --8<-- [end:import]
import * as fs from "fs";
import { Table as ArrowTable, Utf8 } from "apache-arrow";
const example = async () => {
fs.rmSync("data/sample-lancedb", { recursive: true, force: true });
// --8<-- [start:open_db]
const lancedb = require("vectordb");
const uri = "data/sample-lancedb";
const db = await lancedb.connect(uri);
// --8<-- [end:open_db]
// --8<-- [start:create_table]
const tbl = await db.createTable(
"myTable",
[
{ vector: [3.1, 4.1], item: "foo", price: 10.0 },
{ vector: [5.9, 26.5], item: "bar", price: 20.0 },
],
{ writeMode: lancedb.WriteMode.Overwrite }
);
// --8<-- [end:create_table]
// --8<-- [start:add]
const newData = Array.from({ length: 500 }, (_, i) => ({
vector: [i, i + 1],
item: "fizz",
price: i * 0.1,
}));
await tbl.add(newData);
// --8<-- [end:add]
// --8<-- [start:create_index]
await tbl.createIndex({
type: "ivf_pq",
num_partitions: 2,
num_sub_vectors: 2,
});
// --8<-- [end:create_index]
// --8<-- [start:create_empty_table]
const schema = new Schema([
new Field("id", new Int32()),
new Field("name", new Utf8()),
]);
const empty_tbl = await db.createTable({ name: "empty_table", schema });
// --8<-- [end:create_empty_table]
// --8<-- [start:create_f16_table]
const dim = 16
const total = 10
const f16_schema = new Schema([
new Field('id', new Int32()),
new Field(
'vector',
new FixedSizeList(dim, new Field('item', new Float16(), true)),
false
)
])
const data = lancedb.makeArrowTable(
Array.from(Array(total), (_, i) => ({
id: i,
vector: Array.from(Array(dim), Math.random)
})),
{ f16_schema }
)
const table = await db.createTable('f16_tbl', data)
// --8<-- [end:create_f16_table]
// --8<-- [start:search]
const query = await tbl.search([100, 100]).limit(2).execute();
// --8<-- [end:search]
console.log(query);
// --8<-- [start:delete]
await tbl.delete('item = "fizz"');
// --8<-- [end:delete]
// --8<-- [start:drop_table]
await db.dropTable("myTable");
// --8<-- [end:drop_table]
};
async function main() {
await example();
console.log("Basic example: done");
}
main();

View File

@@ -119,7 +119,7 @@ texts = [{"text": "Capitalism has been dominant in the Western world since the e
tbl.add(texts) tbl.add(texts)
``` ```
## Gemini Embedding Function ### Gemini Embeddings
With Google's Gemini, you can represent text (words, sentences, and blocks of text) in a vectorized form, making it easier to compare and contrast embeddings. For example, two texts that share a similar subject matter or sentiment should have similar embeddings, which can be identified through mathematical comparison techniques such as cosine similarity. For more on how and why you should use embeddings, refer to the Embeddings guide. With Google's Gemini, you can represent text (words, sentences, and blocks of text) in a vectorized form, making it easier to compare and contrast embeddings. For example, two texts that share a similar subject matter or sentiment should have similar embeddings, which can be identified through mathematical comparison techniques such as cosine similarity. For more on how and why you should use embeddings, refer to the Embeddings guide.
The Gemini Embedding Model API supports various task types: The Gemini Embedding Model API supports various task types:
@@ -155,6 +155,51 @@ tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas() rs = tbl.search("hello").limit(1).to_pandas()
``` ```
### AWS Bedrock Text Embedding Functions
AWS Bedrock supports multiple base models for generating text embeddings. You need to setup the AWS credentials to use this embedding function.
You can do so by using `awscli` and also add your session_token:
```shell
aws configure
aws configure set aws_session_token "<your_session_token>"
```
to ensure that the credentials are set up correctly, you can run the following command:
```shell
aws sts get-caller-identity
```
Supported Embedding modelIDs are:
* `amazon.titan-embed-text-v1`
* `cohere.embed-english-v3`
* `cohere.embed-multilingual-v3`
Supported paramters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| **name** | str | "amazon.titan-embed-text-v1" | The model ID of the bedrock model to use. Supported base models for Text Embeddings: amazon.titan-embed-text-v1, cohere.embed-english-v3, cohere.embed-multilingual-v3 |
| **region** | str | "us-east-1" | Optional name of the AWS Region in which the service should be called (e.g., "us-east-1"). |
| **profile_name** | str | None | Optional name of the AWS profile to use for calling the Bedrock service. If not specified, the default profile will be used. |
| **assumed_role** | str | None | Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not specified, the current active credentials will be used. |
| **role_session_name** | str | "lancedb-embeddings" | Optional name of the AWS IAM role session to use for calling the Bedrock service. If not specified, a "lancedb-embeddings" name will be used. |
| **runtime** | bool | True | Optional choice of getting different client to perform operations with the Amazon Bedrock service. |
| **max_retries** | int | 7 | Optional number of retries to perform when a request fails. |
Usage Example:
```python
model = get_registry().get("bedrock-text").create()
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("tmp_path")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas()
```
## Multi-modal embedding functions ## Multi-modal embedding functions
Multi-modal embedding functions allow you to query your table using both images and text. Multi-modal embedding functions allow you to query your table using both images and text.

View File

@@ -79,7 +79,10 @@ def qanda_langchain(query):
download_docs() download_docs()
docs = store_docs() docs = store_docs()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200,) text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
)
documents = text_splitter.split_documents(docs) documents = text_splitter.split_documents(docs)
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()

View File

@@ -0,0 +1,11 @@
document.addEventListener("DOMContentLoaded", function () {
var script = document.createElement("script");
script.src = "https://widget.kapa.ai/kapa-widget.bundle.js";
script.setAttribute("data-website-id", "c5881fae-cec0-490b-b45e-d83d131d4f25");
script.setAttribute("data-project-name", "LanceDB");
script.setAttribute("data-project-color", "#000000");
script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/108903835?s=200&v=4");
script.setAttribute("data-modal-example-questions","Help me create an IVF_PQ index,How do I do an exhaustive search?,How do I create a LanceDB table?,Can I use my own embedding function?");
script.async = true;
document.head.appendChild(script);
});

View File

@@ -68,6 +68,82 @@ Alternatively, if you are using AWS SSO, you can use the `AWS_PROFILE` and `AWS_
You can see a full list of environment variables [here](https://docs.rs/object_store/latest/object_store/aws/struct.AmazonS3Builder.html#method.from_env). You can see a full list of environment variables [here](https://docs.rs/object_store/latest/object_store/aws/struct.AmazonS3Builder.html#method.from_env).
!!! tip "Automatic cleanup for failed writes"
LanceDB uses [multi-part uploads](https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html) when writing data to S3 in order to maximize write speed. LanceDB will abort these uploads when it shuts down gracefully, such as when cancelled by keyboard interrupt. However, in the rare case that LanceDB crashes, it is possible that some data will be left lingering in your account. To cleanup this data, we recommend (as AWS themselves do) that you setup a lifecycle rule to delete in-progress uploads after 7 days. See the AWS guide:
**[Configuring a bucket lifecycle configuration to delete incomplete multipart uploads](https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpu-abort-incomplete-mpu-lifecycle-config.html)**
#### AWS IAM Permissions
If a bucket is private, then an IAM policy must be specified to allow access to it. For many development scenarios, using broad permissions such as a PowerUser account is more than sufficient for working with LanceDB. However, in many production scenarios, you may wish to have as narrow as possible permissions.
For **read and write access**, LanceDB will need a policy such as:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:PutObject",
"s3:GetObject",
"s3:DeleteObject",
],
"Resource": "arn:aws:s3:::<bucket>/<prefix>/*"
},
{
"Effect": "Allow",
"Action": [
"s3:ListBucket",
"s3:GetBucketLocation"
],
"Resource": "arn:aws:s3:::<bucket>",
"Condition": {
"StringLike": {
"s3:prefix": [
"<prefix>/*"
]
}
}
}
]
}
```
For **read-only access**, LanceDB will need a policy such as:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"s3:GetObject",
],
"Resource": "arn:aws:s3:::<bucket>/<prefix>/*"
},
{
"Effect": "Allow",
"Action": [
"s3:ListBucket",
"s3:GetBucketLocation"
],
"Resource": "arn:aws:s3:::<bucket>",
"Condition": {
"StringLike": {
"s3:prefix": [
"<prefix>/*"
]
}
}
}
]
}
```
#### S3-compatible stores #### S3-compatible stores
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify two environment variables: `AWS_ENDPOINT` and `AWS_DEFAULT_REGION`. `AWS_ENDPOINT` should be the URL of the S3-compatible store, and `AWS_DEFAULT_REGION` should be the region to use. LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify two environment variables: `AWS_ENDPOINT` and `AWS_DEFAULT_REGION`. `AWS_ENDPOINT` should be the URL of the S3-compatible store, and `AWS_DEFAULT_REGION` should be the region to use.

View File

@@ -16,9 +16,22 @@ This guide will show how to create tables, insert data into them, and update the
db = lancedb.connect("./.lancedb") db = lancedb.connect("./.lancedb")
``` ```
=== "Javascript"
Initialize a VectorDB connection and create a table using one of the many methods listed below.
```javascript
const lancedb = require("vectordb");
const uri = "data/sample-lancedb";
const db = await lancedb.connect(uri);
```
LanceDB allows ingesting data from various sources - `dict`, `list[dict]`, `pd.DataFrame`, `pa.Table` or a `Iterator[pa.RecordBatch]`. Let's take a look at some of the these. LanceDB allows ingesting data from various sources - `dict`, `list[dict]`, `pd.DataFrame`, `pa.Table` or a `Iterator[pa.RecordBatch]`. Let's take a look at some of the these.
### From list of tuples or dictionaries ### From list of tuples or dictionaries
=== "Python"
```python ```python
import lancedb import lancedb
@@ -32,7 +45,6 @@ This guide will show how to create tables, insert data into them, and update the
db["my_table"].head() db["my_table"].head()
``` ```
!!! info "Note" !!! info "Note"
If the table already exists, LanceDB will raise an error by default. If the table already exists, LanceDB will raise an error by default.
@@ -51,6 +63,27 @@ This guide will show how to create tables, insert data into them, and update the
db.create_table("name", data, mode="overwrite") db.create_table("name", data, mode="overwrite")
``` ```
=== "Javascript"
You can create a LanceDB table in JavaScript using an array of JSON records as follows.
```javascript
const tb = await db.createTable("my_table", [{
"vector": [3.1, 4.1],
"item": "foo",
"price": 10.0
}, {
"vector": [5.9, 26.5],
"item": "bar",
"price": 20.0
}]);
```
!!! info "Note"
If the table already exists, LanceDB will raise an error by default. If you want to overwrite the table, you need to specify the `WriteMode` in the createTable function.
```javascript
const table = await con.createTable(tableName, data, { writeMode: WriteMode.Overwrite })
```
### From a Pandas DataFrame ### From a Pandas DataFrame
```python ```python
@@ -67,7 +100,9 @@ This guide will show how to create tables, insert data into them, and update the
db["my_table"].head() db["my_table"].head()
``` ```
!!! info "Note" !!! info "Note"
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly. Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
The **`vector`** column needs to be a [Vector](../python/pydantic.md#vector-field) (defined as [pyarrow.FixedSizeList](https://arrow.apache.org/docs/python/generated/pyarrow.list_.html)) type.
```python ```python
custom_schema = pa.schema([ custom_schema = pa.schema([
@@ -79,7 +114,7 @@ This guide will show how to create tables, insert data into them, and update the
table = db.create_table("my_table", data, schema=custom_schema) table = db.create_table("my_table", data, schema=custom_schema)
``` ```
### From a Polars DataFrame ### From a Polars DataFrame
LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library
written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow
@@ -97,26 +132,44 @@ This guide will show how to create tables, insert data into them, and update the
table = db.create_table("pl_table", data=data) table = db.create_table("pl_table", data=data)
``` ```
### From PyArrow Tables ### From an Arrow Table
You can also create LanceDB tables directly from PyArrow tables === "Python"
You can also create LanceDB tables directly from Arrow tables.
LanceDB supports float16 data type!
```python ```python
table = pa.Table.from_arrays( import pyarrows as pa
[ import numpy as np
pa.array([[3.1, 4.1, 5.1, 6.1], [5.9, 26.5, 4.7, 32.8]],
pa.list_(pa.float32(), 4)), dim = 16
pa.array(["foo", "bar"]), total = 2
pa.array([10.0, 20.0]), schema = pa.schema(
], [
["vector", "item", "price"], pa.field("vector", pa.list_(pa.float16(), dim)),
) pa.field("text", pa.string())
]
)
data = pa.Table.from_arrays(
[
pa.array([np.random.randn(dim).astype(np.float16) for _ in range(total)],
pa.list_(pa.float16(), dim)),
pa.array(["foo", "bar"])
],
["vector", "text"],
)
tbl = db.create_table("f16_tbl", data, schema=schema)
```
db = lancedb.connect("db") === "Javascript"
You can also create LanceDB tables directly from Arrow tables.
LanceDB supports Float16 data type!
tbl = db.create_table("my_table", table) ```javascript
--8<-- "docs/src/basic_legacy.ts:create_f16_table"
``` ```
### From Pydantic Models ### From Pydantic Models
When you create an empty table without data, you must specify the table schema. When you create an empty table without data, you must specify the table schema.
LanceDB supports creating tables by specifying a PyArrow schema or a specialized LanceDB supports creating tables by specifying a PyArrow schema or a specialized
Pydantic model called `LanceModel`. Pydantic model called `LanceModel`.
@@ -261,37 +314,6 @@ This guide will show how to create tables, insert data into them, and update the
You can also use iterators of other types like Pandas DataFrame or Pylists directly in the above example. You can also use iterators of other types like Pandas DataFrame or Pylists directly in the above example.
=== "JavaScript"
Initialize a VectorDB connection and create a table using one of the many methods listed below.
```javascript
const lancedb = require("vectordb");
const uri = "data/sample-lancedb";
const db = await lancedb.connect(uri);
```
You can create a LanceDB table in JavaScript using an array of JSON records as follows.
```javascript
const tb = await db.createTable("my_table", [{
"vector": [3.1, 4.1],
"item": "foo",
"price": 10.0
}, {
"vector": [5.9, 26.5],
"item": "bar",
"price": 20.0
}]);
```
!!! info "Note"
If the table already exists, LanceDB will raise an error by default. If you want to overwrite the table, you need to specify the `WriteMode` in the createTable function.
```javascript
const table = await con.createTable(tableName, data, { writeMode: WriteMode.Overwrite })
```
## Open existing tables ## Open existing tables
=== "Python" === "Python"

172
docs/src/hybrid_search.md Normal file
View File

@@ -0,0 +1,172 @@
# Hybrid Search
LanceDB supports both semantic and keyword-based search. In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques.
## Hybrid search in LanceDB
You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic .
```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydanatic import LanceModel, Vector
db = lancedb.connect("~/.lancedb")
# Ingest embedding function in LanceDB table
embeddings = get_registry().get("openai").create()
class Documents(LanceModel):
vector: Vector(embeddings.ndims) = embeddings.VectorField()
text: str = embeddings.SourceField()
table = db.create_table("documents", schema=Documents)
data = [
{ "text": "rebel spaceships striking from a hidden base"},
{ "text": "have won their first victory against the evil Galactic Empire"},
{ "text": "during the battle rebel spies managed to steal secret plans"},
{ "text": "to the Empire's ultimate weapon the Death Star"}
]
# ingest docs with auto-vectorization
table.add(data)
# hybrid search with default re-ranker
results = table.search("flower moon", query_type="hybrid").to_pandas()
```
By default, LanceDB uses `LinearCombinationReranker(weights=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
### `rerank()` arguments
* `normalize`: `str`, default `"score"`:
The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
* `reranker`: `Reranker`, default `LinearCombinationReranker(weights=0.7)`.
The reranker to use. If not specified, the default reranker is used.
## Available Rerankers
LanceDB provides a number of re-rankers out of the box. You can use any of these re-rankers by passing them to the `rerank()` method. Here's a list of available re-rankers:
### Linear Combination Reranker
This is the default re-ranker used by LanceDB. It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e, 70% weight for semantic search and 30% weight for full-text search.
```python
from lancedb.rerankers import LinearCombinationReranker
reranker = LinearCombinationReranker(weights=0.3) # Use 0.3 as the weight for vector search
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
----------------
* `weight`: `float`, default `0.7`:
The weight to use for the semantic search score. The weight for the full-text search score is `1 - weights`.
* `fill`: `float`, default `1.0`:
The score to give to results that are only in one of the two result sets.This is treated as penalty, so a higher value means a lower score.
TODO: We should just hardcode this-- its pretty confusing as we invert scores to calculate final score
* `return_score` : str, default `"relevance"`
options are "relevance" or "all"
The type of score to return. If "relevance", will return only the `_relevance_score. If "all", will return all scores from the vector and FTS search along with the relevance score.
### Cohere Reranker
This re-ranker uses the [Cohere](https://cohere.ai/) API to combine the results of semantic and full-text search. You can use this re-ranker by passing `CohereReranker()` to the `rerank()` method. Note that you'll need to set the `COHERE_API_KEY` environment variable to use this re-ranker.
```python
from lancedb.rerankers import CohereReranker
reranker = CohereReranker()
results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
----------------
* `model_name`` : str, default `"rerank-english-v2.0"``
The name of the cross encoder model to use. Available cohere models are:
- rerank-english-v2.0
- rerank-multilingual-v2.0
* `column` : str, default `"text"`
The name of the column to use as input to the cross encoder model.
* `top_n` : str, default `None`
The number of results to return. If None, will return all results.
!!! Note
Only returns `_relevance_score`. Does not support `return_score = "all"`.
### Cross Encoder Reranker
This reranker uses the [Sentence Transformers](https://www.sbert.net/) library to combine the results of semantic and full-text search. You can use it by passing `CrossEncoderReranker()` to the `rerank()` method.
```python
from lancedb.rerankers import CrossEncoderReranker
reranker = CrossEncoderReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
----------------
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
* `column` : str, default `"text"`
The name of the column to use as input to the cross encoder model.
* `device` : str, default `None`
The device to use for the cross encoder model. If None, will use "cuda" if available, otherwise "cpu".
!!! Note
Only returns `_relevance_score`. Does not support `return_score = "all"`.
## Building Custom Rerankers
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
The `Reranker` base interface comes with a `merge_results()` method that can be used to combine the results of semantic and full-text search. This is a vanilla merging algorithm that simply concatenates the results and removes the duplicates without taking the scores into consideration. It only keeps the first copy of the row encountered. This works well in cases that don't require the scores of semantic and full-text search to combine the results. If you want to use the scores or want to support `return_score="all"`, you'll need to implement your own merging algorithm.
```python
from lancedb.rerankers import Reranker
import pyarrow as pa
class MyReranker(Reranker):
def __init__(self, param1, param2, ..., return_score="relevance"):
super().__init__(return_score)
self.param1 = param1
self.param2 = param2
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
# Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results)
# Do something with the combined results
# ...
# Return the combined results
return combined_result
```
You can also accept additional arguments like a filter along with fts and vector search results
```python
from lancedb.rerankers import Reranker
import pyarrow as pa
class MyReranker(Reranker):
...
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
# Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results)
# Do something with the combined results & filter
# ...
# Return the combined results
return combined_result
```

View File

@@ -13,7 +13,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 50, "execution_count": 2,
"id": "c1b4e34b-a49c-471d-a343-a5940bb5138a", "id": "c1b4e34b-a49c-471d-a343-a5940bb5138a",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -23,7 +23,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 3,
"id": "4e5a8d07-d9a1-48c1-913a-8e0629289579", "id": "4e5a8d07-d9a1-48c1-913a-8e0629289579",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -44,7 +44,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 4,
"id": "5df12f66-8d99-43ad-8d0b-22189ec0a6b9", "id": "5df12f66-8d99-43ad-8d0b-22189ec0a6b9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -62,7 +62,7 @@
"long: [[-122.7,-74.1]]" "long: [[-122.7,-74.1]]"
] ]
}, },
"execution_count": 2, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -90,7 +90,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 5,
"id": "f4d87ae9-0ccb-48eb-b31d-bb8f2370e47e", "id": "f4d87ae9-0ccb-48eb-b31d-bb8f2370e47e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -108,7 +108,7 @@
"long: [[-122.7,-74.1]]" "long: [[-122.7,-74.1]]"
] ]
}, },
"execution_count": 3, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -135,10 +135,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 6,
"id": "25f34bcf-fca0-4431-8601-eac95d1bd347", "id": "25f34bcf-fca0-4431-8601-eac95d1bd347",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2024-01-31T18:59:33Z WARN lance::dataset] No existing dataset at /Users/qian/Work/LanceDB/lancedb/docs/src/notebooks/.lancedb/table3.lance, it will be created\n"
]
},
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
@@ -148,7 +155,7 @@
"long: float" "long: float"
] ]
}, },
"execution_count": 8, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -171,45 +178,51 @@
"id": "4df51925-7ca2-4005-9c72-38b3d26240c6", "id": "4df51925-7ca2-4005-9c72-38b3d26240c6",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### From PyArrow Tables\n", "### From an Arrow Table\n",
"\n", "\n",
"You can also create LanceDB tables directly from pyarrow tables" "You can also create LanceDB tables directly from pyarrow tables. LanceDB supports float16 type."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 7,
"id": "90a880f6-be43-4c9d-ba65-0b05197c0f6f", "id": "90a880f6-be43-4c9d-ba65-0b05197c0f6f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"vector: fixed_size_list<item: float>[2]\n", "vector: fixed_size_list<item: halffloat>[16]\n",
" child 0, item: float\n", " child 0, item: halffloat\n",
"item: string\n", "text: string"
"price: double"
] ]
}, },
"execution_count": 12, "execution_count": 7,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"table = pa.Table.from_arrays(\n", "import numpy as np\n",
" [\n",
" pa.array([[3.1, 4.1], [5.9, 26.5]],\n",
" pa.list_(pa.float32(), 2)),\n",
" pa.array([\"foo\", \"bar\"]),\n",
" pa.array([10.0, 20.0]),\n",
" ],\n",
" [\"vector\", \"item\", \"price\"],\n",
" )\n",
"\n", "\n",
"db = lancedb.connect(\"db\")\n", "dim = 16\n",
"total = 2\n",
"schema = pa.schema(\n",
" [\n",
" pa.field(\"vector\", pa.list_(pa.float16(), dim)),\n",
" pa.field(\"text\", pa.string())\n",
" ]\n",
")\n",
"data = pa.Table.from_arrays(\n",
" [\n",
" pa.array([np.random.randn(dim).astype(np.float16) for _ in range(total)],\n",
" pa.list_(pa.float16(), dim)),\n",
" pa.array([\"foo\", \"bar\"])\n",
" ],\n",
" [\"vector\", \"text\"],\n",
")\n",
"\n", "\n",
"tbl = db.create_table(\"test1\", table, mode=\"overwrite\")\n", "tbl = db.create_table(\"f16_tbl\", data, schema=schema)\n",
"tbl.schema" "tbl.schema"
] ]
}, },
@@ -225,7 +238,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 8,
"id": "d81121d7-e4b7-447c-a48c-974b6ebb464a", "id": "d81121d7-e4b7-447c-a48c-974b6ebb464a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -240,7 +253,7 @@
"imdb_id: int64 not null" "imdb_id: int64 not null"
] ]
}, },
"execution_count": 13, "execution_count": 8,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -282,7 +295,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 9,
"id": "bc247142-4e3c-41a2-b94c-8e00d2c2a508", "id": "bc247142-4e3c-41a2-b94c-8e00d2c2a508",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -292,7 +305,7 @@
"LanceTable(table4)" "LanceTable(table4)"
] ]
}, },
"execution_count": 14, "execution_count": 9,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -333,7 +346,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 10,
"id": "25ad3523-e0c9-4c28-b3df-38189c4e0e5f", "id": "25ad3523-e0c9-4c28-b3df-38189c4e0e5f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -346,7 +359,7 @@
"price: double not null" "price: double not null"
] ]
}, },
"execution_count": 16, "execution_count": 10,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -385,7 +398,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 11,
"id": "2814173a-eacc-4dd8-a64d-6312b44582cc", "id": "2814173a-eacc-4dd8-a64d-6312b44582cc",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -411,7 +424,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 12,
"id": "df9e13c0-41f6-437f-9dfa-2fd71d3d9c45", "id": "df9e13c0-41f6-437f-9dfa-2fd71d3d9c45",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -421,7 +434,7 @@
"['table6', 'table4', 'table5', 'movielens_small']" "['table6', 'table4', 'table5', 'movielens_small']"
] ]
}, },
"execution_count": 18, "execution_count": 12,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -432,7 +445,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 13,
"id": "9343f5ad-6024-42ee-ac2f-6c1471df8679", "id": "9343f5ad-6024-42ee-ac2f-6c1471df8679",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -541,7 +554,7 @@
"9 [5.9, 26.5] bar 20.0" "9 [5.9, 26.5] bar 20.0"
] ]
}, },
"execution_count": 20, "execution_count": 13,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -564,7 +577,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 14,
"id": "8a56250f-73a1-4c26-a6ad-5c7a0ce3a9ab", "id": "8a56250f-73a1-4c26-a6ad-5c7a0ce3a9ab",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -590,7 +603,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 15,
"id": "030c7057-b98e-4e2f-be14-b8c1f927f83c", "id": "030c7057-b98e-4e2f-be14-b8c1f927f83c",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -621,7 +634,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 16,
"id": "e7a17de2-08d2-41b7-bd05-f63d1045ab1f", "id": "e7a17de2-08d2-41b7-bd05-f63d1045ab1f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -629,16 +642,16 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"32\n" "22\n"
] ]
}, },
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"17" "12"
] ]
}, },
"execution_count": 24, "execution_count": 16,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -661,7 +674,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 17,
"id": "fe3310bd-08f4-4a22-a63b-b3127d22f9f7", "id": "fe3310bd-08f4-4a22-a63b-b3127d22f9f7",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -681,25 +694,20 @@
"8 [3.1, 4.1] foo 10.0\n", "8 [3.1, 4.1] foo 10.0\n",
"9 [3.1, 4.1] foo 10.0\n", "9 [3.1, 4.1] foo 10.0\n",
"10 [3.1, 4.1] foo 10.0\n", "10 [3.1, 4.1] foo 10.0\n",
"11 [3.1, 4.1] foo 10.0\n", "11 [3.1, 4.1] foo 10.0\n"
"12 [3.1, 4.1] foo 10.0\n",
"13 [3.1, 4.1] foo 10.0\n",
"14 [3.1, 4.1] foo 10.0\n",
"15 [3.1, 4.1] foo 10.0\n",
"16 [3.1, 4.1] foo 10.0\n"
] ]
}, },
{ {
"ename": "OSError", "ename": "OSError",
"evalue": "LanceError(IO): Error during planning: column foo does not exist", "evalue": "LanceError(IO): Error during planning: column foo does not exist, /Users/runner/work/lance/lance/rust/lance-core/src/error.rs:212:23",
"output_type": "error", "output_type": "error",
"traceback": [ "traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[30], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m to_remove \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(v) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m to_remove)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(tbl\u001b[38;5;241m.\u001b[39mto_pandas())\n\u001b[0;32m----> 4\u001b[0m \u001b[43mtbl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mitem IN (\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mto_remove\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m tbl\u001b[38;5;241m.\u001b[39mto_pandas()\n", "Cell \u001b[0;32mIn[17], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m to_remove \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(v) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m to_remove)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(tbl\u001b[38;5;241m.\u001b[39mto_pandas())\n\u001b[0;32m----> 4\u001b[0m \u001b[43mtbl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mitem IN (\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mto_remove\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/lancedb/lancedb/python/lancedb/table.py:610\u001b[0m, in \u001b[0;36mLanceTable.delete\u001b[0;34m(self, where)\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdelete\u001b[39m(\u001b[38;5;28mself\u001b[39m, where: \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 610\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwhere\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Work/LanceDB/lancedb/docs/doc-venv/lib/python3.11/site-packages/lancedb/table.py:872\u001b[0m, in \u001b[0;36mLanceTable.delete\u001b[0;34m(self, where)\u001b[0m\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdelete\u001b[39m(\u001b[38;5;28mself\u001b[39m, where: \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 872\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwhere\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/lancedb/lancedb/env/lib/python3.11/site-packages/lance/dataset.py:489\u001b[0m, in \u001b[0;36mLanceDataset.delete\u001b[0;34m(self, predicate)\u001b[0m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(predicate, pa\u001b[38;5;241m.\u001b[39mcompute\u001b[38;5;241m.\u001b[39mExpression):\n\u001b[1;32m 488\u001b[0m predicate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(predicate)\n\u001b[0;32m--> 489\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_ds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredicate\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Work/LanceDB/lancedb/docs/doc-venv/lib/python3.11/site-packages/lance/dataset.py:596\u001b[0m, in \u001b[0;36mLanceDataset.delete\u001b[0;34m(self, predicate)\u001b[0m\n\u001b[1;32m 594\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(predicate, pa\u001b[38;5;241m.\u001b[39mcompute\u001b[38;5;241m.\u001b[39mExpression):\n\u001b[1;32m 595\u001b[0m predicate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(predicate)\n\u001b[0;32m--> 596\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_ds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredicate\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mOSError\u001b[0m: LanceError(IO): Error during planning: column foo does not exist" "\u001b[0;31mOSError\u001b[0m: LanceError(IO): Error during planning: column foo does not exist, /Users/runner/work/lance/lance/rust/lance-core/src/error.rs:212:23"
] ]
} }
], ],
@@ -712,7 +720,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": null,
"id": "87d5bc21-847f-4c81-b56e-f6dbe5d05aac", "id": "87d5bc21-847f-4c81-b56e-f6dbe5d05aac",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -729,7 +737,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 44, "execution_count": null,
"id": "9cba4519-eb3a-4941-ab7e-873d762e750f", "id": "9cba4519-eb3a-4941-ab7e-873d762e750f",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -742,7 +750,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 46, "execution_count": null,
"id": "5bdc9801-d5ed-4871-92d0-88b27108e788", "id": "5bdc9801-d5ed-4871-92d0-88b27108e788",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -817,7 +825,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.4" "version": "3.11.7"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -58,6 +58,8 @@ pip install lancedb
::: lancedb.schema.vector ::: lancedb.schema.vector
::: lancedb.merge.LanceMergeInsertBuilder
## Integrations ## Integrations
### Pydantic ### Pydantic

View File

@@ -2,27 +2,26 @@
A vector search finds the approximate or exact nearest neighbors to a given query vector. A vector search finds the approximate or exact nearest neighbors to a given query vector.
* In a recommendation system or search engine, you can find similar records to - In a recommendation system or search engine, you can find similar records to
the one you searched. the one you searched.
* In LLM and other AI applications, - In LLM and other AI applications,
each data point can be represented by [embeddings generated from existing models](embeddings/index.md), each data point can be represented by [embeddings generated from existing models](embeddings/index.md),
following which the search returns the most relevant features. following which the search returns the most relevant features.
## Distance metrics ## Distance metrics
Distance metrics are a measure of the similarity between a pair of vectors. Distance metrics are a measure of the similarity between a pair of vectors.
Currently, LanceDB supports the following metrics: Currently, LanceDB supports the following metrics:
| Metric | Description | | Metric | Description |
| ----------- | ------------------------------------ | | -------- | --------------------------------------------------------------------------- |
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) | | `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity)| | `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) | | `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) |
## Exhaustive search (kNN) ## Exhaustive search (kNN)
If you do not create a vector index, LanceDB exhaustively scans the *entire* vector space If you do not create a vector index, LanceDB exhaustively scans the _entire_ vector space
and compute the distance to every vector in order to find the exact nearest neighbors. This is effectively a kNN search. and compute the distance to every vector in order to find the exact nearest neighbors. This is effectively a kNN search.
<!-- Setup Code <!-- Setup Code
@@ -38,22 +37,9 @@ data = [{"vector": row, "item": f"item {i}"}
db.create_table("my_vectors", data=data) db.create_table("my_vectors", data=data)
``` ```
--> -->
<!-- Setup Code
```javascript
const vectordb_setup = require('vectordb')
const db_setup = await vectordb_setup.connect('data/sample-lancedb')
let data = []
for (let i = 0; i < 10_000; i++) {
data.push({vector: Array(1536).fill(i), id: `${i}`, content: "", longId: `${i}`},)
}
await db_setup.createTable('my_vectors', data)
```
-->
=== "Python" === "Python"
```python ```python
import lancedb import lancedb
import numpy as np import numpy as np
@@ -70,17 +56,12 @@ await db_setup.createTable('my_vectors', data)
=== "JavaScript" === "JavaScript"
```javascript ```javascript
const vectordb = require('vectordb') --8<-- "docs/src/search_legacy.ts:import"
const db = await vectordb.connect('data/sample-lancedb')
const tbl = await db.openTable("my_vectors") --8<-- "docs/src/search_legacy.ts:search1"
const results_1 = await tbl.search(Array(1536).fill(1.2))
.limit(10)
.execute()
``` ```
By default, `l2` will be used as metric type. You can specify the metric type as By default, `l2` will be used as metric type. You can specify the metric type as
`cosine` or `dot` if required. `cosine` or `dot` if required.
=== "Python" === "Python"
@@ -92,20 +73,16 @@ By default, `l2` will be used as metric type. You can specify the metric type as
.to_list() .to_list()
``` ```
=== "JavaScript" === "JavaScript"
```javascript ```javascript
const results_2 = await tbl.search(Array(1536).fill(1.2)) --8<-- "docs/src/search_legacy.ts:search2"
.metricType("cosine")
.limit(10)
.execute()
``` ```
## Approximate nearest neighbor (ANN) search ## Approximate nearest neighbor (ANN) search
To perform scalable vector retrieval with acceptable latencies, it's common to build a vector index. To perform scalable vector retrieval with acceptable latencies, it's common to build a vector index.
While the exhaustive search is guaranteed to always return 100% recall, the approximate nature of While the exhaustive search is guaranteed to always return 100% recall, the approximate nature of
an ANN search means that using an index often involves a trade-off between recall and latency. an ANN search means that using an index often involves a trade-off between recall and latency.
See the [IVF_PQ index](./concepts/index_ivfpq.md.md) for a deeper description of how `IVF_PQ` See the [IVF_PQ index](./concepts/index_ivfpq.md.md) for a deeper description of how `IVF_PQ`
@@ -117,7 +94,9 @@ LanceDB returns vector search results via different formats commonly used in pyt
Let's create a LanceDB table with a nested schema: Let's create a LanceDB table with a nested schema:
=== "Python" === "Python"
```python ```python
from datetime import datetime from datetime import datetime
import lancedb import lancedb
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
@@ -153,7 +132,7 @@ Let's create a LanceDB table with a nested schema:
### As a PyArrow table ### As a PyArrow table
Using `to_arrow()` we can get the results back as a pyarrow Table. Using `to_arrow()` we can get the results back as a pyarrow Table.
This result table has the same columns as the LanceDB table, with This result table has the same columns as the LanceDB table, with
the addition of an `_distance` column for vector search or a `score` the addition of an `_distance` column for vector search or a `score`
column for full text search. column for full text search.
@@ -169,11 +148,11 @@ Let's create a LanceDB table with a nested schema:
tbl.search(np.random.randn(1536)).to_pandas() tbl.search(np.random.randn(1536)).to_pandas()
``` ```
While other formats like Arrow/Pydantic/Python dicts have a natural While other formats like Arrow/Pydantic/Python dicts have a natural
way to handle nested schemas, pandas can only store nested data as a way to handle nested schemas, pandas can only store nested data as a
python dict column, which makes it difficult to support nested references. python dict column, which makes it difficult to support nested references.
So for convenience, you can also tell LanceDB to flatten a nested schema So for convenience, you can also tell LanceDB to flatten a nested schema
when creating the pandas dataframe. when creating the pandas dataframe.
```python ```python
tbl.search(np.random.randn(1536)).to_pandas(flatten=True) tbl.search(np.random.randn(1536)).to_pandas(flatten=True)

41
docs/src/search_legacy.ts Normal file
View File

@@ -0,0 +1,41 @@
// --8<-- [start:import]
import * as lancedb from "vectordb";
// --8<-- [end:import]
import * as fs from "fs";
async function setup() {
fs.rmSync("data/sample-lancedb", { recursive: true, force: true });
const db = await lancedb.connect("data/sample-lancedb");
let data = [];
for (let i = 0; i < 10_000; i++) {
data.push({
vector: Array(1536).fill(i),
id: `${i}`,
content: "",
longId: `${i}`,
});
}
await db.createTable("my_vectors", data);
}
async () => {
await setup();
// --8<-- [start:search1]
const db = await lancedb.connect("data/sample-lancedb");
const tbl = await db.openTable("my_vectors");
const results_1 = await tbl.search(Array(1536).fill(1.2)).limit(10).execute();
// --8<-- [end:search1]
// --8<-- [start:search2]
const results_2 = await tbl
.search(Array(1536).fill(1.2))
.metricType(lancedb.MetricType.Cosine)
.limit(10)
.execute();
// --8<-- [end:search2]
console.log("search: done");
};

View File

@@ -8,7 +8,7 @@ option that performs the filter prior to vector search. This can be useful to na
the search space on a very large dataset to reduce query latency. the search space on a very large dataset to reduce query latency.
<!-- Setup Code <!-- Setup Code
```python ```python
import lancedb import lancedb
import numpy as np import numpy as np
uri = "data/sample-lancedb" uri = "data/sample-lancedb"
@@ -21,7 +21,7 @@ tbl = db.create_table("my_vectors", data=data)
``` ```
--> -->
<!-- Setup Code <!-- Setup Code
```javascript ```javascript
const vectordb = require('vectordb') const vectordb = require('vectordb')
const db = await vectordb.connect('data/sample-lancedb') const db = await vectordb.connect('data/sample-lancedb')
@@ -34,6 +34,7 @@ const tbl = await db.createTable('myVectors', data)
--> -->
=== "Python" === "Python"
```py ```py
result = ( result = (
tbl.search([0.5, 0.2]) tbl.search([0.5, 0.2])
@@ -44,12 +45,9 @@ const tbl = await db.createTable('myVectors', data)
``` ```
=== "JavaScript" === "JavaScript"
```javascript ```javascript
let result = await tbl.search(Array(1536).fill(0.5)) --8<-- "docs/src/sql_legacy.ts:search"
.limit(1)
.filter("id = 10")
.prefilter(true)
.execute()
``` ```
## SQL filters ## SQL filters
@@ -60,14 +58,14 @@ It can be used during vector search, update, and deletion operations.
Currently, Lance supports a growing list of SQL expressions. Currently, Lance supports a growing list of SQL expressions.
* ``>``, ``>=``, ``<``, ``<=``, ``=`` - `>`, `>=`, `<`, `<=`, `=`
* ``AND``, ``OR``, ``NOT`` - `AND`, `OR`, `NOT`
* ``IS NULL``, ``IS NOT NULL`` - `IS NULL`, `IS NOT NULL`
* ``IS TRUE``, ``IS NOT TRUE``, ``IS FALSE``, ``IS NOT FALSE`` - `IS TRUE`, `IS NOT TRUE`, `IS FALSE`, `IS NOT FALSE`
* ``IN`` - `IN`
* ``LIKE``, ``NOT LIKE`` - `LIKE`, `NOT LIKE`
* ``CAST`` - `CAST`
* ``regexp_match(column, pattern)`` - `regexp_match(column, pattern)`
For example, the following filter string is acceptable: For example, the following filter string is acceptable:
@@ -82,29 +80,27 @@ For example, the following filter string is acceptable:
=== "Javascript" === "Javascript"
```javascript ```javascript
await tbl.search(Array(1536).fill(0)) --8<-- "docs/src/sql_legacy.ts:vec_search"
.where("(item IN ('item 0', 'item 2')) AND (id > 10)")
.execute()
``` ```
If your column name contains special characters or is a [SQL Keyword](https://docs.rs/sqlparser/latest/sqlparser/keywords/index.html), If your column name contains special characters or is a [SQL Keyword](https://docs.rs/sqlparser/latest/sqlparser/keywords/index.html),
you can use backtick (`` ` ``) to escape it. For nested fields, each segment of the you can use backtick (`` ` ``) to escape it. For nested fields, each segment of the
path must be wrapped in backticks. path must be wrapped in backticks.
=== "SQL" === "SQL"
```sql ```sql
`CUBE` = 10 AND `column name with space` IS NOT NULL `CUBE` = 10 AND `column name with space` IS NOT NULL
AND `nested with space`.`inner with space` < 2 AND `nested with space`.`inner with space` < 2
``` ```
!!! warning !!!warning "Field names containing periods (`.`) are not supported."
Field names containing periods (``.``) are not supported.
Literals for dates, timestamps, and decimals can be written by writing the string Literals for dates, timestamps, and decimals can be written by writing the string
value after the type name. For example value after the type name. For example
=== "SQL" === "SQL"
```sql ```sql
date_col = date '2021-01-01' date_col = date '2021-01-01'
and timestamp_col = timestamp '2021-01-01 00:00:00' and timestamp_col = timestamp '2021-01-01 00:00:00'
@@ -114,49 +110,47 @@ value after the type name. For example
For timestamp columns, the precision can be specified as a number in the type For timestamp columns, the precision can be specified as a number in the type
parameter. Microsecond precision (6) is the default. parameter. Microsecond precision (6) is the default.
| SQL | Time unit | | SQL | Time unit |
|------------------|--------------| | -------------- | ------------ |
| ``timestamp(0)`` | Seconds | | `timestamp(0)` | Seconds |
| ``timestamp(3)`` | Milliseconds | | `timestamp(3)` | Milliseconds |
| ``timestamp(6)`` | Microseconds | | `timestamp(6)` | Microseconds |
| ``timestamp(9)`` | Nanoseconds | | `timestamp(9)` | Nanoseconds |
LanceDB internally stores data in [Apache Arrow](https://arrow.apache.org/) format. LanceDB internally stores data in [Apache Arrow](https://arrow.apache.org/) format.
The mapping from SQL types to Arrow types is: The mapping from SQL types to Arrow types is:
| SQL type | Arrow type | | SQL type | Arrow type |
|----------|------------| | --------------------------------------------------------- | ------------------ |
| ``boolean`` | ``Boolean`` | | `boolean` | `Boolean` |
| ``tinyint`` / ``tinyint unsigned`` | ``Int8`` / ``UInt8`` | | `tinyint` / `tinyint unsigned` | `Int8` / `UInt8` |
| ``smallint`` / ``smallint unsigned`` | ``Int16`` / ``UInt16`` | | `smallint` / `smallint unsigned` | `Int16` / `UInt16` |
| ``int`` or ``integer`` / ``int unsigned`` or ``integer unsigned`` | ``Int32`` / ``UInt32`` | | `int` or `integer` / `int unsigned` or `integer unsigned` | `Int32` / `UInt32` |
| ``bigint`` / ``bigint unsigned`` | ``Int64`` / ``UInt64`` | | `bigint` / `bigint unsigned` | `Int64` / `UInt64` |
| ``float`` | ``Float32`` | | `float` | `Float32` |
| ``double`` | ``Float64`` | | `double` | `Float64` |
| ``decimal(precision, scale)`` | ``Decimal128`` | | `decimal(precision, scale)` | `Decimal128` |
| ``date`` | ``Date32`` | | `date` | `Date32` |
| ``timestamp`` | ``Timestamp`` [^1] | | `timestamp` | `Timestamp` [^1] |
| ``string`` | ``Utf8`` | | `string` | `Utf8` |
| ``binary`` | ``Binary`` | | `binary` | `Binary` |
[^1]: See precision mapping in previous table. [^1]: See precision mapping in previous table.
## Filtering without Vector Search ## Filtering without Vector Search
You can also filter your data without search. You can also filter your data without search.
=== "Python" === "Python"
```python
tbl.search().where("id = 10").limit(10).to_arrow() ```python
``` tbl.search().where("id = 10").limit(10).to_arrow()
```
=== "JavaScript" === "JavaScript"
```javascript
await tbl.where('id = 10').limit(10).execute()
```
!!! warning ```javascript
If your table is large, this could potentially return a very large --8<---- "docs/src/sql_legacy.ts:sql_search"
amount of data. Please be sure to use a `limit` clause unless ```
you're sure you want to return the whole result set.
!!!warning "If your table is large, this could potentially return a very large amount of data. Please be sure to use a `limit` clause unless you're sure you want to return the whole result set."

38
docs/src/sql_legacy.ts Normal file
View File

@@ -0,0 +1,38 @@
import * as vectordb from "vectordb";
(async () => {
const db = await vectordb.connect("data/sample-lancedb");
let data = [];
for (let i = 0; i < 10_000; i++) {
data.push({
vector: Array(1536).fill(i),
id: i,
item: `item ${i}`,
strId: `${i}`,
});
}
const tbl = await db.createTable("myVectors", data);
// --8<-- [start:search]
let result = await tbl
.search(Array(1536).fill(0.5))
.limit(1)
.filter("id = 10")
.prefilter(true)
.execute();
// --8<-- [end:search]
// --8<-- [start:vec_search]
await tbl
.search(Array(1536).fill(0))
.where("(item IN ('item 0', 'item 2')) AND (id > 10)")
.execute();
// --8<-- [end:vec_search]
// --8<-- [start:sql_search]
await tbl.filter("id = 10").limit(10).execute();
// --8<-- [end:sql_search]
console.log("SQL search: done");
})();

View File

@@ -1,54 +0,0 @@
const glob = require("glob");
const fs = require("fs");
const path = require("path");
const globString = "../src/**/*.md";
const excludedGlobs = [
"../src/fts.md",
"../src/embedding.md",
"../src/examples/*.md",
"../src/guides/tables.md",
"../src/embeddings/*.md",
];
const nodePrefix = "javascript";
const nodeFile = ".js";
const nodeFolder = "node";
const asyncPrefix = "(async () => {\n";
const asyncSuffix = "})();";
function* yieldLines(lines, prefix, suffix) {
let inCodeBlock = false;
for (const line of lines) {
if (line.trim().startsWith(prefix + nodePrefix)) {
inCodeBlock = true;
} else if (inCodeBlock && line.trim().startsWith(suffix)) {
inCodeBlock = false;
yield "\n";
} else if (inCodeBlock) {
yield line;
}
}
}
const files = glob.sync(globString, { recursive: true });
const excludedFiles = glob.sync(excludedGlobs, { recursive: true });
for (const file of files.filter((file) => !excludedFiles.includes(file))) {
const lines = [];
const data = fs.readFileSync(file, "utf-8");
const fileLines = data.split("\n");
for (const line of yieldLines(fileLines, "```", "```")) {
lines.push(line);
}
if (lines.length > 0) {
const fileName = path.basename(file, ".md");
const outPath = path.join(nodeFolder, fileName, `${fileName}${nodeFile}`);
console.log(outPath)
fs.mkdirSync(path.dirname(outPath), { recursive: true });
fs.writeFileSync(outPath, asyncPrefix + "\n" + lines.join("\n") + asyncSuffix);
}
}

View File

@@ -14,6 +14,7 @@ excluded_globs = [
"../src/concepts/*.md", "../src/concepts/*.md",
"../src/ann_indexes.md", "../src/ann_indexes.md",
"../src/basic.md", "../src/basic.md",
"../src/hybrid_search.md",
] ]
python_prefix = "py" python_prefix = "py"
@@ -48,6 +49,7 @@ def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
if not skip_test: if not skip_test:
yield line[strip_length:] yield line[strip_length:]
for file in filter(lambda file: file not in excluded_files, files): for file in filter(lambda file: file not in excluded_files, files):
with open(file, "r") as f: with open(file, "r") as f:
lines = list(yield_lines(iter(f), "```", "```")) lines = list(yield_lines(iter(f), "```", "```"))

View File

@@ -1,13 +0,0 @@
{
"name": "lancedb-docs-test",
"version": "1.0.0",
"description": "",
"author": "",
"license": "ISC",
"dependencies": {
"fs": "^0.0.1-security",
"glob": "^10.2.7",
"path": "^0.12.7",
"vectordb": "https://gitpkg.now.sh/lancedb/lancedb/node?main"
}
}

17
docs/tsconfig.json Normal file
View File

@@ -0,0 +1,17 @@
{
"include": [
"src/*.ts",
],
"compilerOptions": {
"target": "es2022",
"module": "nodenext",
"declaration": true,
"outDir": "./dist",
"strict": true,
"allowJs": true,
"resolveJsonModule": true,
},
"exclude": [
"./dist/*",
]
}

74
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.4.4", "version": "0.4.7",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.4.4", "version": "0.4.7",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -53,11 +53,11 @@
"uuid": "^9.0.0" "uuid": "^9.0.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.4", "@lancedb/vectordb-darwin-arm64": "0.4.7",
"@lancedb/vectordb-darwin-x64": "0.4.4", "@lancedb/vectordb-darwin-x64": "0.4.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.4", "@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
"@lancedb/vectordb-linux-x64-gnu": "0.4.4", "@lancedb/vectordb-linux-x64-gnu": "0.4.7",
"@lancedb/vectordb-win32-x64-msvc": "0.4.4" "@lancedb/vectordb-win32-x64-msvc": "0.4.7"
} }
}, },
"node_modules/@75lb/deep-merge": { "node_modules/@75lb/deep-merge": {
@@ -328,6 +328,66 @@
"@jridgewell/sourcemap-codec": "^1.4.10" "@jridgewell/sourcemap-codec": "^1.4.10"
} }
}, },
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.7.tgz",
"integrity": "sha512-kACOIytgjBfX8NRwjPKe311XRN3lbSN13B7avT5htMd3kYm3AnnMag9tZhlwoO7lIuvGaXhy7mApygJrjhfJ4g==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.7.tgz",
"integrity": "sha512-vb74iK5uPWCwz5E60r3yWp/R/HSg54/Z9AZWYckYXqsPv4w/nfbkM5iZhfRqqR/9uE6JClWJKOtjbk7b8CFRFg==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.7.tgz",
"integrity": "sha512-jHp7THm6S9sB8RaCxGoZXLAwGAUHnawUUilB1K3mvQsRdfB2bBs0f7wDehW+PDhr+Iog4LshaWbcnoQEUJWR+Q==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.7.tgz",
"integrity": "sha512-LKbVe6Wrp/AGqCCjKliNDmYoeTNgY/wfb2DTLjrx41Jko/04ywLrJ6xSEAn3XD5RDCO5u3fyUdXHHHv5a3VAAQ==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.7.tgz",
"integrity": "sha512-C5ln4+wafeY1Sm4PeV0Ios9lUaQVVip5Mjl9XU7ngioSEMEuXI/XMVfIdVfDPppVNXPeQxg33wLA272uw88D1Q==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"win32"
]
},
"node_modules/@neon-rs/cli": { "node_modules/@neon-rs/cli": {
"version": "0.0.160", "version": "0.0.160",
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz", "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",

View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.4.4", "version": "0.4.7",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js", "main": "dist/index.js",
"types": "dist/index.d.ts", "types": "dist/index.d.ts",
"scripts": { "scripts": {
"tsc": "tsc -b", "tsc": "tsc -b",
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json && tsc -b", "build": "npm run tsc && cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
"build-release": "npm run build -- --release", "build-release": "npm run build -- --release",
"test": "npm run tsc && mocha -recursive dist/test", "test": "npm run tsc && mocha -recursive dist/test",
"integration-test": "npm run tsc && mocha -recursive dist/integration_test", "integration-test": "npm run tsc && mocha -recursive dist/integration_test",
@@ -17,7 +17,11 @@
}, },
"repository": { "repository": {
"type": "git", "type": "git",
"url": "https://github.com/lancedb/lancedb/node" "url": "https://github.com/lancedb/lancedb.git"
},
"homepage": "https://lancedb.github.io/lancedb/",
"bugs": {
"url": "https://github.com/lancedb/lancedb/issues"
}, },
"keywords": [ "keywords": [
"data-format", "data-format",
@@ -81,10 +85,10 @@
} }
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.4", "@lancedb/vectordb-darwin-arm64": "0.4.7",
"@lancedb/vectordb-darwin-x64": "0.4.4", "@lancedb/vectordb-darwin-x64": "0.4.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.4", "@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
"@lancedb/vectordb-linux-x64-gnu": "0.4.4", "@lancedb/vectordb-linux-x64-gnu": "0.4.7",
"@lancedb/vectordb-win32-x64-msvc": "0.4.4" "@lancedb/vectordb-win32-x64-msvc": "0.4.7"
} }
} }

View File

@@ -37,6 +37,7 @@ const {
tableCountRows, tableCountRows,
tableDelete, tableDelete,
tableUpdate, tableUpdate,
tableMergeInsert,
tableCleanupOldVersions, tableCleanupOldVersions,
tableCompactFiles, tableCompactFiles,
tableListIndices, tableListIndices,
@@ -163,6 +164,7 @@ export async function connect (
{ {
uri: '', uri: '',
awsCredentials: undefined, awsCredentials: undefined,
awsRegion: defaultAwsRegion,
apiKey: undefined, apiKey: undefined,
region: defaultAwsRegion region: defaultAwsRegion
}, },
@@ -174,7 +176,13 @@ export async function connect (
// Remote connection // Remote connection
return new RemoteConnection(opts) return new RemoteConnection(opts)
} }
const db = await databaseNew(opts.uri) const db = await databaseNew(
opts.uri,
opts.awsCredentials?.accessKeyId,
opts.awsCredentials?.secretKey,
opts.awsCredentials?.sessionToken,
opts.awsRegion
)
return new LocalConnection(db, opts) return new LocalConnection(db, opts)
} }
@@ -433,6 +441,38 @@ export interface Table<T = number[]> {
*/ */
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void> update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
/**
* Runs a "merge insert" operation on the table
*
* This operation can add rows, update rows, and remove rows all in a single
* transaction. It is a very generic tool that can be used to create
* behaviors like "insert if not exists", "update or insert (i.e. upsert)",
* or even replace a portion of existing data with new data (e.g. replace
* all data where month="january")
*
* The merge insert operation works by combining new data from a
* **source table** with existing data in a **target table** by using a
* join. There are three categories of records.
*
* "Matched" records are records that exist in both the source table and
* the target table. "Not matched" records exist only in the source table
* (e.g. these are new data) "Not matched by source" records exist only
* in the target table (this is old data)
*
* The MergeInsertArgs can be used to customize what should happen for
* each category of data.
*
* Please note that the data may appear to be reordered as part of this
* operation. This is because updated rows will be deleted from the
* dataset and then reinserted at the end with the new values.
*
* @param on a column to join on. This is how records from the source
* table and target table are matched.
* @param data the new data to insert
* @param args parameters controlling how the operation should behave
*/
mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>
/** /**
* List the indicies on this table. * List the indicies on this table.
*/ */
@@ -443,6 +483,8 @@ export interface Table<T = number[]> {
*/ */
indexStats: (indexUuid: string) => Promise<IndexStats> indexStats: (indexUuid: string) => Promise<IndexStats>
filter(value: string): Query<T>
schema: Promise<Schema> schema: Promise<Schema>
} }
@@ -474,6 +516,36 @@ export interface UpdateSqlArgs {
valuesSql: Record<string, string> valuesSql: Record<string, string>
} }
export interface MergeInsertArgs {
/**
* If true then rows that exist in both the source table (new data) and
* the target table (old data) will be updated, replacing the old row
* with the corresponding matching row.
*
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*/
whenMatchedUpdateAll?: boolean
/**
* If true then rows that exist only in the source table (new data)
* will be inserted into the target table.
*/
whenNotMatchedInsertAll?: boolean
/**
* If true then rows that exist only in the target table (old data)
* will be deleted.
*
* If this is a string then it will be treated as an SQL filter and
* only rows that both do not match any row in the source table and
* match the given filter will be deleted.
*
* This can be used to replace a selection of existing data with
* new data.
*/
whenNotMatchedBySourceDelete?: string | boolean
}
export interface VectorIndex { export interface VectorIndex {
columns: string[] columns: string[]
name: string name: string
@@ -812,6 +884,38 @@ export class LocalTable<T = number[]> implements Table<T> {
}) })
} }
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
let whenNotMatchedBySourceDelete = false
let whenNotMatchedBySourceDeleteFilt = null
if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) {
whenNotMatchedBySourceDelete = true
if (args.whenNotMatchedBySourceDelete !== true) {
whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete
}
}
const schema = await this.schema
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, { schema })
}
const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)
this._tbl = await tableMergeInsert.call(
this._tbl,
on,
whenMatchedUpdateAll,
whenNotMatchedInsertAll,
whenNotMatchedBySourceDelete,
whenNotMatchedBySourceDeleteFilt,
buffer
)
}
/** /**
* Clean up old versions of the table, freeing disk space. * Clean up old versions of the table, freeing disk space.
* *

View File

@@ -24,7 +24,8 @@ import {
type IndexStats, type IndexStats,
type UpdateArgs, type UpdateArgs,
type UpdateSqlArgs, type UpdateSqlArgs,
makeArrowTable makeArrowTable,
type MergeInsertArgs
} from '../index' } from '../index'
import { Query } from '../query' import { Query } from '../query'
@@ -270,6 +271,56 @@ export class RemoteTable<T = number[]> implements Table<T> {
return new RemoteQuery(query, this._client, this._name) //, this._embeddings_new) return new RemoteQuery(query, this._client, this._name) //, this._embeddings_new)
} }
filter (where: string): Query<T> {
throw new Error('Not implemented')
}
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, await this.schema)
}
const queryParams: any = {
on
}
if (args.whenMatchedUpdateAll ?? false) {
queryParams.when_matched_update_all = 'true'
} else {
queryParams.when_matched_update_all = 'false'
}
if (args.whenNotMatchedInsertAll ?? false) {
queryParams.when_not_matched_insert_all = 'true'
} else {
queryParams.when_not_matched_insert_all = 'false'
}
if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) {
queryParams.when_not_matched_by_source_delete = 'true'
if (typeof args.whenNotMatchedBySourceDelete === 'string') {
queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete
}
} else {
queryParams.when_not_matched_by_source_delete = 'false'
}
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/merge_insert/`,
buffer,
queryParams,
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
}
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> { async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable let tbl: ArrowTable
if (data instanceof ArrowTable) { if (data instanceof ArrowTable) {

View File

@@ -391,24 +391,6 @@ describe('LanceDB client', function () {
}) })
}).timeout(120000) }).timeout(120000)
it('fails to create a new table when the vector column is missing', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [
{
id: 1,
price: 10
}
]
const create = con.createTable('missing_vector', data)
await expect(create).to.be.rejectedWith(
Error,
"column 'vector' is missing"
)
})
it('use overwrite flag to overwrite existing table', async function () { it('use overwrite flag to overwrite existing table', async function () {
const dir = await track().mkdir('lancejs') const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir) const con = await lancedb.connect(dir)
@@ -549,6 +531,44 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2) assert.equal(await table.countRows(), 2)
}) })
it('can merge insert records into the table', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
const table = await con.createTable('my_table', data)
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true
})
assert.equal(await table.countRows(), 3)
assert.equal((await table.filter('age = 2').execute()).length, 1)
newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true
})
assert.equal(await table.countRows(), 4)
assert.equal((await table.filter('age = 3').execute()).length, 2)
newData = [{ id: 5, age: 4 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: 'age < 3'
})
assert.equal(await table.countRows(), 3)
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: true
})
assert.equal(await table.countRows(), 1)
})
it('can update records in the table', async function () { it('can update records in the table', async function () {
const uri = await createTestDB() const uri = await createTestDB()
const con = await lancedb.connect(uri) const con = await lancedb.connect(uri)

View File

@@ -10,14 +10,15 @@ crate-type = ["cdylib"]
[dependencies] [dependencies]
arrow-ipc.workspace = true arrow-ipc.workspace = true
futures.workspace = true
lance-linalg.workspace = true
lance.workspace = true
vectordb = { path = "../rust/vectordb" }
napi = { version = "2.14", default-features = false, features = [ napi = { version = "2.14", default-features = false, features = [
"napi7", "napi7",
"async" "async"
] } ] }
napi-derive = "2.14" napi-derive = "2.14"
vectordb = { path = "../rust/vectordb" }
lance.workspace = true
lance-linalg.workspace = true
[build-dependencies] [build-dependencies]
napi-build = "2.1" napi-build = "2.1"

View File

@@ -53,6 +53,16 @@ describe("Test creating index", () => {
const indexDir = path.join(tmpDir, "test.lance", "_indices"); const indexDir = path.join(tmpDir, "test.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1); expect(fs.readdirSync(indexDir)).toHaveLength(1);
// TODO: check index type. // TODO: check index type.
// Search without specifying the column
let query_vector = data.toArray()[5].vec.toJSON();
let rst = await tbl.query().nearestTo(query_vector).limit(2).toArrow();
expect(rst.numRows).toBe(2);
// Search with specifying the column
let rst2 = await tbl.search(query_vector, "vec").limit(2).toArrow();
expect(rst2.numRows).toBe(2);
expect(rst.toString()).toEqual(rst2.toString());
}); });
test("no vector column available", async () => { test("no vector column available", async () => {
@@ -71,6 +81,80 @@ describe("Test creating index", () => {
await tbl.createIndex("val").build(); await tbl.createIndex("val").build();
const indexDir = path.join(tmpDir, "no_vec.lance", "_indices"); const indexDir = path.join(tmpDir, "no_vec.lance", "_indices");
expect(fs.readdirSync(indexDir)).toHaveLength(1); expect(fs.readdirSync(indexDir)).toHaveLength(1);
for await (const r of tbl.query().filter("id > 1").select(["id"])) {
expect(r.numRows).toBe(1);
}
});
test("two columns with different dimensions", async () => {
const db = await connect(tmpDir);
const schema = new Schema([
new Field("id", new Int32(), true),
new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))),
new Field(
"vec2",
new FixedSizeList(64, new Field("item", new Float32()))
),
]);
const tbl = await db.createTable(
"two_vectors",
makeArrowTable(
Array(300)
.fill(1)
.map((_, i) => ({
id: i,
vec: Array(32)
.fill(1)
.map(() => Math.random()),
vec2: Array(64) // different dimension
.fill(1)
.map(() => Math.random()),
})),
{ schema }
)
);
// Only build index over v1
await expect(tbl.createIndex().build()).rejects.toThrow(
/.*More than one vector columns found.*/
);
tbl
.createIndex("vec")
.ivf_pq({ num_partitions: 2, num_sub_vectors: 2 })
.build();
const rst = await tbl
.query()
.nearestTo(
Array(32)
.fill(1)
.map(() => Math.random())
)
.limit(2)
.toArrow();
expect(rst.numRows).toBe(2);
// Search with specifying the column
await expect(
tbl
.search(
Array(64)
.fill(1)
.map(() => Math.random()),
"vec"
)
.limit(2)
.toArrow()
).rejects.toThrow(/.*does not match the dimension.*/);
const query64 = Array(64)
.fill(1)
.map(() => Math.random());
const rst64_1 = await tbl.query().nearestTo(query64).limit(2).toArrow();
const rst64_2 = await tbl.search(query64, "vec2").limit(2).toArrow();
expect(rst64_1.toString()).toEqual(rst64_2.toString());
expect(rst64_1.numRows).toBe(2);
}); });
test("create scalar index", async () => { test("create scalar index", async () => {

View File

@@ -91,7 +91,6 @@ impl IndexBuilder {
#[napi] #[napi]
pub async fn build(&self) -> napi::Result<()> { pub async fn build(&self) -> napi::Result<()> {
println!("nodejs::index.rs : build");
self.inner self.inner
.build() .build()
.await .await

47
nodejs/src/iterator.rs Normal file
View File

@@ -0,0 +1,47 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::StreamExt;
use lance::io::RecordBatchStream;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use vectordb::ipc::batches_to_ipc_file;
/** Typescript-style Async Iterator over RecordBatches */
#[napi]
pub struct RecordBatchIterator {
inner: Box<dyn RecordBatchStream + Unpin>,
}
#[napi]
impl RecordBatchIterator {
pub(crate) fn new(inner: Box<dyn RecordBatchStream + Unpin>) -> Self {
Self { inner }
}
#[napi]
pub async unsafe fn next(&mut self) -> napi::Result<Option<Buffer>> {
if let Some(rst) = self.inner.next().await {
let batch = rst.map_err(|e| {
napi::Error::from_reason(format!("Failed to get next batch from stream: {}", e))
})?;
batches_to_ipc_file(&[batch])
.map_err(|e| napi::Error::from_reason(format!("Failed to write IPC file: {}", e)))
.map(|buf| Some(Buffer::from(buf)))
} else {
// We are done with the stream.
Ok(None)
}
}
}

View File

@@ -17,6 +17,7 @@ use napi_derive::*;
mod connection; mod connection;
mod index; mod index;
mod iterator;
mod query; mod query;
mod table; mod table;

View File

@@ -16,7 +16,7 @@ use napi::bindgen_prelude::*;
use napi_derive::napi; use napi_derive::napi;
use vectordb::query::Query as LanceDBQuery; use vectordb::query::Query as LanceDBQuery;
use crate::table::Table; use crate::{iterator::RecordBatchIterator, table::Table};
#[napi] #[napi]
pub struct Query { pub struct Query {
@@ -32,17 +32,50 @@ impl Query {
} }
#[napi] #[napi]
pub fn vector(&mut self, vector: Float32Array) { pub fn column(&mut self, column: String) {
let inn = self.inner.clone().nearest_to(&vector); self.inner = self.inner.clone().column(&column);
self.inner = inn;
} }
#[napi] #[napi]
pub fn to_arrow(&self) -> napi::Result<()> { pub fn filter(&mut self, filter: String) {
// let buf = self.inner.to_arrow().map_err(|e| { self.inner = self.inner.clone().filter(filter);
// napi::Error::from_reason(format!("Failed to convert query to arrow: {}", e)) }
// })?;
// Ok(buf) #[napi]
todo!() pub fn select(&mut self, columns: Vec<String>) {
self.inner = self.inner.clone().select(&columns);
}
#[napi]
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
#[napi]
pub fn prefilter(&mut self, prefilter: bool) {
self.inner = self.inner.clone().prefilter(prefilter);
}
#[napi]
pub fn nearest_to(&mut self, vector: Float32Array) {
self.inner = self.inner.clone().nearest_to(&vector);
}
#[napi]
pub fn refine_factor(&mut self, refine_factor: u32) {
self.inner = self.inner.clone().refine_factor(refine_factor);
}
#[napi]
pub fn nprobes(&mut self, nprobe: u32) {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
#[napi]
pub async fn execute_stream(&self) -> napi::Result<RecordBatchIterator> {
let inner_stream = self.inner.execute_stream().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(Box::new(inner_stream)))
} }
} }

View File

@@ -54,9 +54,20 @@ export class IndexBuilder {
scalar(): void scalar(): void
build(): Promise<void> build(): Promise<void>
} }
/** Typescript-style Async Iterator over RecordBatches */
export class RecordBatchIterator {
next(): Promise<Buffer | null>
}
export class Query { export class Query {
vector(vector: Float32Array): void column(column: string): void
toArrow(): void filter(filter: string): void
select(columns: Array<string>): void
limit(limit: number): void
prefilter(prefilter: boolean): void
nearestTo(vector: Float32Array): void
refineFactor(refineFactor: number): void
nprobes(nprobe: number): void
executeStream(): Promise<RecordBatchIterator>
} }
export class Table { export class Table {
/** Return Schema as empty Arrow IPC file. */ /** Return Schema as empty Arrow IPC file. */

View File

@@ -295,12 +295,13 @@ if (!nativeBinding) {
throw new Error(`Failed to load native binding`) throw new Error(`Failed to load native binding`)
} }
const { Connection, IndexType, MetricType, IndexBuilder, Query, Table, WriteMode, connect } = nativeBinding const { Connection, IndexType, MetricType, IndexBuilder, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding
module.exports.Connection = Connection module.exports.Connection = Connection
module.exports.IndexType = IndexType module.exports.IndexType = IndexType
module.exports.MetricType = MetricType module.exports.MetricType = MetricType
module.exports.IndexBuilder = IndexBuilder module.exports.IndexBuilder = IndexBuilder
module.exports.RecordBatchIterator = RecordBatchIterator
module.exports.Query = Query module.exports.Query = Query
module.exports.Table = Table module.exports.Table = Table
module.exports.WriteMode = WriteMode module.exports.WriteMode = WriteMode

View File

@@ -12,46 +12,73 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import { RecordBatch } from "apache-arrow"; import { RecordBatch, tableFromIPC, Table as ArrowTable } from "apache-arrow";
import { Table } from "./table"; import {
RecordBatchIterator as NativeBatchIterator,
Query as NativeQuery,
Table as NativeTable,
} from "./native";
// TODO: re-eanble eslint once we have a real implementation
/* eslint-disable */
class RecordBatchIterator implements AsyncIterator<RecordBatch> { class RecordBatchIterator implements AsyncIterator<RecordBatch> {
next( private promised_inner?: Promise<NativeBatchIterator>;
...args: [] | [undefined] private inner?: NativeBatchIterator;
): Promise<IteratorResult<RecordBatch<any>, any>> {
throw new Error("Method not implemented."); constructor(
inner?: NativeBatchIterator,
promise?: Promise<NativeBatchIterator>
) {
// TODO: check promise reliably so we dont need to pass two arguments.
this.inner = inner;
this.promised_inner = promise;
} }
return?(value?: any): Promise<IteratorResult<RecordBatch<any>, any>> {
throw new Error("Method not implemented."); async next(): Promise<IteratorResult<RecordBatch<any>, any>> {
} if (this.inner === undefined) {
throw?(e?: any): Promise<IteratorResult<RecordBatch<any>, any>> { this.inner = await this.promised_inner;
throw new Error("Method not implemented."); }
if (this.inner === undefined) {
throw new Error("Invalid iterator state state");
}
const n = await this.inner.next();
if (n == null) {
return Promise.resolve({ done: true, value: null });
}
const tbl = tableFromIPC(n);
if (tbl.batches.length != 1) {
throw new Error("Expected only one batch");
}
return Promise.resolve({ done: false, value: tbl.batches[0] });
} }
} }
/* eslint-enable */ /* eslint-enable */
/** Query executor */ /** Query executor */
export class Query implements AsyncIterable<RecordBatch> { export class Query implements AsyncIterable<RecordBatch> {
private readonly tbl: Table; private readonly inner: NativeQuery;
private _filter?: string;
private _limit?: number;
// Vector search constructor(tbl: NativeTable) {
private _vector?: Float32Array; this.inner = tbl.query();
private _nprobes?: number; }
private _refine_factor?: number = 1;
constructor(tbl: Table) { /** Set the column to run query. */
this.tbl = tbl; column(column: string): Query {
this.inner.column(column);
return this;
} }
/** Set the filter predicate, only returns the results that satisfy the filter. /** Set the filter predicate, only returns the results that satisfy the filter.
* *
*/ */
filter(predicate: string): Query { filter(predicate: string): Query {
this._filter = predicate; this.inner.filter(predicate);
return this;
}
/**
* Select the columns to return. If not set, all columns are returned.
*/
select(columns: string[]): Query {
this.inner.select(columns);
return this; return this;
} }
@@ -59,35 +86,67 @@ export class Query implements AsyncIterable<RecordBatch> {
* Set the limit of rows to return. * Set the limit of rows to return.
*/ */
limit(limit: number): Query { limit(limit: number): Query {
this._limit = limit; this.inner.limit(limit);
return this;
}
prefilter(prefilter: boolean): Query {
this.inner.prefilter(prefilter);
return this; return this;
} }
/** /**
* Set the query vector. * Set the query vector.
*/ */
vector(vector: number[]): Query { nearestTo(vector: number[]): Query {
this._vector = Float32Array.from(vector); this.inner.nearestTo(Float32Array.from(vector));
return this; return this;
} }
/** /**
* Set the number of probes to use for the query. * Set the number of IVF partitions to use for the query.
*/ */
nprobes(nprobes: number): Query { nprobes(nprobes: number): Query {
this._nprobes = nprobes; this.inner.nprobes(nprobes);
return this; return this;
} }
/** /**
* Set the refine factor for the query. * Set the refine factor for the query.
*/ */
refine_factor(refine_factor: number): Query { refineFactor(refine_factor: number): Query {
this._refine_factor = refine_factor; this.inner.refineFactor(refine_factor);
return this; return this;
} }
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> { /**
throw new RecordBatchIterator(); * Execute the query and return the results as an AsyncIterator.
*/
async executeStream(): Promise<RecordBatchIterator> {
const inner = await this.inner.executeStream();
return new RecordBatchIterator(inner);
}
/** Collect the results as an Arrow Table. */
async toArrow(): Promise<ArrowTable> {
const batches = [];
for await (const batch of this) {
batches.push(batch);
}
return new ArrowTable(batches);
}
/** Returns a JSON Array of All results.
*
*/
async toArray(): Promise<any[]> {
const tbl = await this.toArrow();
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return tbl.toArray();
}
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
const promise = this.inner.executeStream();
return new RecordBatchIterator(undefined, promise);
} }
} }

View File

@@ -95,10 +95,58 @@ export class Table {
return builder; return builder;
} }
search(vector?: number[]): Query { /**
const q = new Query(this); * Create a generic {@link Query} Builder.
if (vector !== undefined) { *
q.vector(vector); * When appropriate, various indices and statistics based pruning will be used to
* accelerate the query.
*
* @example
*
* ### Run a SQL-style query
* ```typescript
* for await (const batch of table.query()
* .filter("id > 1").select(["id"]).limit(20)) {
* console.log(batch);
* }
* ```
*
* ### Run Top-10 vector similarity search
* ```typescript
* for await (const batch of table.query()
* .nearestTo([1, 2, 3])
* .refineFactor(5).nprobe(10)
* .limit(10)) {
* console.log(batch);
* }
*```
*
* ### Scan the full dataset
* ```typescript
* for await (const batch of table.query()) {
* console.log(batch);
* }
*
* ### Return the full dataset as Arrow Table
* ```typescript
* let arrowTbl = await table.query().nearestTo([1.0, 2.0, 0.5, 6.7]).toArrow();
* ```
*
* @returns {@link Query}
*/
query(): Query {
return new Query(this.inner);
}
/** Search the table with a given query vector.
*
* This is a convenience method for preparing an ANN {@link Query}.
*/
search(vector: number[], column?: string): Query {
const q = this.query();
q.nearestTo(vector);
if (column !== undefined) {
q.column(column);
} }
return q; return q;
} }

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.5.1 current_version = 0.5.2
commit = True commit = True
message = [python] Bump version: {current_version} → {new_version} message = [python] Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -16,9 +16,9 @@ from typing import Iterable, List, Union
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
from .util import safe_import_pandas from .util import safe_import
pd = safe_import_pandas() pd = safe_import("pandas")
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]] DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]

View File

@@ -16,9 +16,9 @@ import deprecation
from . import __version__ from . import __version__
from .exceptions import MissingColumnError, MissingValueError from .exceptions import MissingColumnError, MissingValueError
from .util import safe_import_pandas from .util import safe_import
pd = safe_import_pandas() pd = safe_import("pandas")
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer: def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:

View File

@@ -13,6 +13,7 @@
# ruff: noqa: F401 # ruff: noqa: F401
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
from .bedrock import BedRockText
from .cohere import CohereEmbeddingFunction from .cohere import CohereEmbeddingFunction
from .gemini_text import GeminiText from .gemini_text import GeminiText
from .instructor import InstructorEmbeddingFunction from .instructor import InstructorEmbeddingFunction

View File

@@ -0,0 +1,223 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from functools import cached_property
from typing import List, Union
import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION
from .base import TextEmbeddingFunction
from .registry import register
from .utils import TEXT
@register("bedrock-text")
class BedRockText(TextEmbeddingFunction):
"""
Parameters
----------
name: str, default "amazon.titan-embed-text-v1"
The model ID of the bedrock model to use. Supported models for are:
- amazon.titan-embed-text-v1
- cohere.embed-english-v3
- cohere.embed-multilingual-v3
region: str, default "us-east-1"
Optional name of the AWS Region in which the service should be called.
profile_name: str, default None
Optional name of the AWS profile to use for calling the Bedrock service.
If not specified, the default profile will be used.
assumed_role: str, default None
Optional ARN of an AWS IAM role to assume for calling the Bedrock service.
If not specified, the current active credentials will be used.
role_session_name: str, default "lancedb-embeddings"
Optional name of the AWS IAM role session to use for calling the Bedrock
service. If not specified, "lancedb-embeddings" name will be used.
Examples
--------
import lancedb
import pandas as pd
from lancedb.pydantic import LanceModel, Vector
model = get_registry().get("bedrock-text").create()
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("tmp_path")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas()
"""
name: str = "amazon.titan-embed-text-v1"
region: str = "us-east-1"
assumed_role: Union[str, None] = None
profile_name: Union[str, None] = None
role_session_name: str = "lancedb-embeddings"
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
class Config:
keep_untouched = (cached_property,)
def ndims(self):
# return len(self._generate_embedding("test"))
# TODO: fix hardcoding
if self.name == "amazon.titan-embed-text-v1":
return 1536
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
return 1024
else:
raise ValueError(f"Unknown model name: {self.name}")
def compute_query_embeddings(
self, query: str, *args, **kwargs
) -> List[List[float]]:
return self.compute_source_embeddings(query)
def compute_source_embeddings(
self, texts: TEXT, *args, **kwargs
) -> List[List[float]]:
texts = self.sanitize_input(texts)
return self.generate_embeddings(texts)
def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[List[float]]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
Returns
-------
list[list[float]]
The embeddings for the given texts
"""
results = []
for text in texts:
response = self._generate_embedding(text)
results.append(response)
return results
def _generate_embedding(self, text: str) -> List[float]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: str
The texts to embed
Returns
-------
list[float]
The embeddings for the given texts
"""
# format input body for provider
provider = self.name.split(".")[0]
_model_kwargs = {}
input_body = {**_model_kwargs}
if provider == "cohere":
if "input_type" not in input_body.keys():
input_body["input_type"] = "search_document"
input_body["texts"] = [text]
else:
# includes common provider == "amazon"
input_body["inputText"] = text
body = json.dumps(input_body)
try:
# invoke bedrock API
response = self.client.invoke_model(
body=body,
modelId=self.name,
accept="application/json",
contentType="application/json",
)
# format output based on provider
response_body = json.loads(response.get("body").read())
if provider == "cohere":
return response_body.get("embeddings")[0]
else:
# includes common provider == "amazon"
return response_body.get("embedding")
except Exception as e:
help_txt = """
boto3 client failed to invoke the bedrock API. In case of
AWS credentials error:
- Please check your AWS credentials and ensure that you have access.
You can set up aws credentials using `aws configure` command and
verify by running `aws sts get-caller-identity` in your terminal.
"""
raise ValueError(f"Error raised by boto3 client: {e}. \n {help_txt}")
@cached_property
def client(self):
"""Create a boto3 client for Amazon Bedrock service
Returns
-------
boto3.client
The boto3 client for Amazon Bedrock service
"""
botocore = self.safe_import("botocore")
boto3 = self.safe_import("boto3")
session_kwargs = {"region_name": self.region}
client_kwargs = {**session_kwargs}
if self.profile_name:
session_kwargs["profile_name"] = self.profile_name
retry_config = botocore.config.Config(
region_name=self.region,
retries={
"max_attempts": 0, # disable this as retries retries are handled
"mode": "standard",
},
)
session = (
boto3.Session(**session_kwargs) if self.profile_name else boto3.Session()
)
if self.assumed_role: # if not using default credentials
sts = session.client("sts")
response = sts.assume_role(
RoleArn=str(self.assumed_role),
RoleSessionName=self.role_session_name,
)
client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
client_kwargs["aws_secret_access_key"] = response["Credentials"][
"SecretAccessKey"
]
client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
service_name = "bedrock-runtime"
bedrock_client = session.client(
service_name=service_name, config=retry_config, **client_kwargs
)
return bedrock_client

View File

@@ -0,0 +1,130 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import numpy as np
from .base import TextEmbeddingFunction
from .registry import register
from .utils import weak_lru
@register("gte-text")
class GteEmbeddings(TextEmbeddingFunction):
"""
An embedding function that uses GTE-LARGE MLX format(for Apple silicon devices only)
as well as the standard cpu/gpu version from: https://huggingface.co/thenlper/gte-large.
For Apple users, you will need the mlx package insalled, which can be done with:
pip install mlx
Parameters
----------
name: str, default "thenlper/gte-large"
The name of the model to use.
device: str, default "cpu"
Sets the device type for the model.
normalize: str, default "True"
Controls normalize param in encode function for the transformer.
mlx: bool, default False
Controls which model to use. False for gte-large,True for the mlx version.
Examples
--------
import lancedb
import lancedb.embeddings.gte
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
import pandas as pd
model = get_registry().get("gte-text").create() # mlx=True for Apple silicon
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas()
"""
name: str = "thenlper/gte-large"
device: str = "cpu"
normalize: bool = True
mlx: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._ndims = None
if kwargs:
self.mlx = kwargs.get("mlx", False)
if self.mlx is True:
self.name = "gte-mlx"
@property
def embedding_model(self):
"""
Get the embedding model specified by the flag,
name and device. This is cached so that the model is only loaded
once per process.
"""
return self.get_embedding_model()
def ndims(self):
if self.mlx is True:
self._ndims = self.embedding_model.dims
if self._ndims is None:
self._ndims = len(self.generate_embeddings("foo")[0])
return self._ndims
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
"""
Get the embeddings for the given texts.
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
if self.mlx is True:
return self.embedding_model.run(list(texts)).tolist()
return self.embedding_model.encode(
list(texts),
convert_to_numpy=True,
normalize_embeddings=self.normalize,
).tolist()
@weak_lru(maxsize=1)
def get_embedding_model(self):
"""
Get the embedding model specified by the flag,
name and device. This is cached so that the model is only loaded
once per process.
"""
if self.mlx is True:
from .gte_mlx_model import Model
return Model()
else:
sentence_transformers = self.safe_import(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(
self.name, device=self.device
)

View File

@@ -0,0 +1,154 @@
import json
from typing import List, Optional
import numpy as np
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import BertTokenizer
try:
import mlx.core as mx
import mlx.nn as nn
except ImportError:
raise ImportError("You need to install MLX to use this model use - pip install mlx")
def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
last_hidden = mx.multiply(last_hidden_state, attention_mask[..., None])
return last_hidden.sum(axis=1) / attention_mask.sum(axis=1)[..., None]
class ModelConfig(BaseModel):
dim: int = 1024
num_attention_heads: int = 16
num_hidden_layers: int = 24
vocab_size: int = 30522
attention_probs_dropout_prob: float = 0.1
hidden_dropout_prob: float = 0.1
layer_norm_eps: float = 1e-12
max_position_embeddings: int = 512
class TransformerEncoderLayer(nn.Module):
"""
A transformer encoder layer with (the original BERT) post-normalization.
"""
def __init__(
self,
dims: int,
num_heads: int,
mlp_dims: Optional[int] = None,
layer_norm_eps: float = 1e-12,
):
super().__init__()
mlp_dims = mlp_dims or dims * 4
self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
self.linear1 = nn.Linear(dims, mlp_dims)
self.linear2 = nn.Linear(mlp_dims, dims)
self.gelu = nn.GELU()
def __call__(self, x, mask):
attention_out = self.attention(x, x, x, mask)
add_and_norm = self.ln1(x + attention_out)
ff = self.linear1(add_and_norm)
ff_gelu = self.gelu(ff)
ff_out = self.linear2(ff_gelu)
x = self.ln2(ff_out + add_and_norm)
return x
class TransformerEncoder(nn.Module):
def __init__(
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
):
super().__init__()
self.layers = [
TransformerEncoderLayer(dims, num_heads, mlp_dims)
for i in range(num_layers)
]
def __call__(self, x, mask):
for layer in self.layers:
x = layer(x, mask)
return x
class BertEmbeddings(nn.Module):
def __init__(self, config: ModelConfig):
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.token_type_embeddings = nn.Embedding(2, config.dim)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.dim
)
self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
words = self.word_embeddings(input_ids)
position = self.position_embeddings(
mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
)
token_types = self.token_type_embeddings(token_type_ids)
embeddings = position + words + token_types
return self.norm(embeddings)
class Bert(nn.Module):
def __init__(self, config: ModelConfig):
self.embeddings = BertEmbeddings(config)
self.encoder = TransformerEncoder(
num_layers=config.num_hidden_layers,
dims=config.dim,
num_heads=config.num_attention_heads,
)
self.pooler = nn.Linear(config.dim, config.dim)
def __call__(
self,
input_ids: mx.array,
token_type_ids: mx.array,
attention_mask: mx.array = None,
) -> tuple[mx.array, mx.array]:
x = self.embeddings(input_ids, token_type_ids)
if attention_mask is not None:
# convert 0's to -infs, 1's to 0's, and make it broadcastable
attention_mask = mx.log(attention_mask)
attention_mask = mx.expand_dims(attention_mask, (1, 2))
y = self.encoder(x, attention_mask)
return y, mx.tanh(self.pooler(y[:, 0]))
class Model:
def __init__(self) -> None:
# get converted embedding model
model_path = snapshot_download(repo_id="vegaluisjose/mlx-rag")
with open(f"{model_path}/config.json") as f:
model_config = ModelConfig(**json.load(f))
self.dims = model_config.dim
self.model = Bert(model_config)
self.model.load_weights(f"{model_path}/model.npz")
self.tokenizer = BertTokenizer.from_pretrained("thenlper/gte-large")
self.embeddings = []
def run(self, input_text: List[str]) -> mx.array:
tokens = self.tokenizer(input_text, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
last_hidden_state, _ = self.model(**tokens)
embeddings = average_pool(
last_hidden_state, tokens["attention_mask"].astype(mx.float32)
)
self.embeddings = (
embeddings / mx.linalg.norm(embeddings, ord=2, axis=1)[..., None]
)
return np.array(embeddings.astype(mx.float32))

View File

@@ -26,10 +26,10 @@ import pyarrow as pa
from lance.vector import vec_to_table from lance.vector import vec_to_table
from retry import retry from retry import retry
from ..util import safe_import_pandas from ..util import safe_import
from ..utils.general import LOGGER from ..utils.general import LOGGER
pd = safe_import_pandas() pd = safe_import("pandas")
DATA = Union[pa.Table, "pd.DataFrame"] DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]

103
python/lancedb/merge.py Normal file
View File

@@ -0,0 +1,103 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from .common import DATA
class LanceMergeInsertBuilder(object):
"""Builder for a LanceDB merge insert operation
See [`merge_insert`][lancedb.table.Table.merge_insert] for
more context
"""
def __init__(self, table: "Table", on: List[str]): # noqa: F821
# Do not put a docstring here. This method should be hidden
# from API docs. Users should use merge_insert to create
# this object.
self._table = table
self._on = on
self._when_matched_update_all = False
self._when_not_matched_insert_all = False
self._when_not_matched_by_source_delete = False
self._when_not_matched_by_source_condition = None
def when_matched_update_all(self) -> LanceMergeInsertBuilder:
"""
Rows that exist in both the source table (new data) and
the target table (old data) will be updated, replacing
the old row with the corresponding matching row.
If there are multiple matches then the behavior is undefined.
Currently this causes multiple copies of the row to be created
but that behavior is subject to change.
"""
self._when_matched_update_all = True
return self
def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
"""
Rows that exist only in the source table (new data) should
be inserted into the target table.
"""
self._when_not_matched_insert_all = True
return self
def when_not_matched_by_source_delete(
self, condition: Optional[str] = None
) -> LanceMergeInsertBuilder:
"""
Rows that exist only in the target table (old data) will be
deleted. An optional condition can be provided to limit what
data is deleted.
Parameters
----------
condition: Optional[str], default None
If None then all such rows will be deleted. Otherwise the
condition will be used as an SQL filter to limit what rows
are deleted.
"""
self._when_not_matched_by_source_delete = True
if condition is not None:
self._when_not_matched_by_source_condition = condition
return self
def execute(
self,
new_data: DATA,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
"""
Executes the merge insert operation
Nothing is returned but the [`Table`][lancedb.table.Table] is updated
Parameters
----------
new_data: DATA
New records which will be matched against the existing records
to potentially insert or update into the table. This parameter
can be anything you use for [`add`][lancedb.table.Table.add]
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)

View File

@@ -14,8 +14,9 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union
import deprecation import deprecation
import numpy as np import numpy as np
@@ -23,8 +24,10 @@ import pyarrow as pa
import pydantic import pydantic
from . import __version__ from . import __version__
from .common import VECTOR_COLUMN_NAME from .common import VEC, VECTOR_COLUMN_NAME
from .util import safe_import_pandas from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import safe_import
if TYPE_CHECKING: if TYPE_CHECKING:
import PIL import PIL
@@ -33,7 +36,7 @@ if TYPE_CHECKING:
from .pydantic import LanceModel from .pydantic import LanceModel
from .table import Table from .table import Table
pd = safe_import_pandas() pd = safe_import("pandas")
class Query(pydantic.BaseModel): class Query(pydantic.BaseModel):
@@ -99,6 +102,8 @@ class Query(pydantic.BaseModel):
# Refine factor. # Refine factor.
refine_factor: Optional[int] = None refine_factor: Optional[int] = None
with_row_id: bool = False
class LanceQueryBuilder(ABC): class LanceQueryBuilder(ABC):
"""Build LanceDB query based on specific query type: """Build LanceDB query based on specific query type:
@@ -109,19 +114,26 @@ class LanceQueryBuilder(ABC):
def create( def create(
cls, cls,
table: "Table", table: "Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]], query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str, query_type: str,
vector_column_name: str, vector_column_name: str,
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
if query is None: if query is None:
return LanceEmptyQueryBuilder(table) return LanceEmptyQueryBuilder(table)
# convert "auto" query_type to "vector" or "fts" if query_type == "hybrid":
# and convert the query to vector if needed # hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name)
# convert "auto" query_type to "vector", "fts"
# or "hybrid" and convert the query to vector if needed
query, query_type = cls._resolve_query( query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name table, query, query_type, vector_column_name
) )
if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name)
if isinstance(query, str): if isinstance(query, str):
# fts # fts
return LanceFtsQueryBuilder(table, query) return LanceFtsQueryBuilder(table, query)
@@ -144,17 +156,13 @@ class LanceQueryBuilder(ABC):
raise TypeError(f"'fts' queries must be a string: {type(query)}") raise TypeError(f"'fts' queries must be a string: {type(query)}")
return query, query_type return query, query_type
elif query_type == "vector": elif query_type == "vector":
if not isinstance(query, (list, np.ndarray)): query = cls._query_to_vector(table, query, vector_column_name)
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings_with_retry(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
return query, query_type return query, query_type
elif query_type == "auto": elif query_type == "auto":
if isinstance(query, (list, np.ndarray)): if isinstance(query, (list, np.ndarray)):
return query, "vector" return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else: else:
conf = table.embedding_functions.get(vector_column_name) conf = table.embedding_functions.get(vector_column_name)
if conf is not None: if conf is not None:
@@ -167,11 +175,23 @@ class LanceQueryBuilder(ABC):
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}" f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
) )
@classmethod
def _query_to_vector(cls, table, query, vector_column_name):
if isinstance(query, (list, np.ndarray)):
return query
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
return conf.function.compute_query_embeddings_with_retry(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
def __init__(self, table: "Table"): def __init__(self, table: "Table"):
self._table = table self._table = table
self._limit = 10 self._limit = 10
self._columns = None self._columns = None
self._where = None self._where = None
self._with_row_id = False
@deprecation.deprecated( @deprecation.deprecated(
deprecated_in="0.3.1", deprecated_in="0.3.1",
@@ -341,6 +361,22 @@ class LanceQueryBuilder(ABC):
self._prefilter = prefilter self._prefilter = prefilter
return self return self
def with_row_id(self, with_row_id: bool) -> LanceQueryBuilder:
"""Set whether to return row ids.
Parameters
----------
with_row_id: bool
If True, return _rowid column in the results.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._with_row_id = with_row_id
return self
class LanceVectorQueryBuilder(LanceQueryBuilder): class LanceVectorQueryBuilder(LanceQueryBuilder):
""" """
@@ -459,6 +495,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
nprobes=self._nprobes, nprobes=self._nprobes,
refine_factor=self._refine_factor, refine_factor=self._refine_factor,
vector_column=self._vector_column, vector_column=self._vector_column,
with_row_id=self._with_row_id,
) )
return self._table._execute_query(query) return self._table._execute_query(query)
@@ -568,6 +605,10 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
ds = lance.write_dataset(output_tbl, tmp) ds = lance.write_dataset(output_tbl, tmp)
output_tbl = ds.to_table(filter=self._where) output_tbl = ds.to_table(filter=self._where)
if self._with_row_id:
# Need to set this to uint explicitly as vector results are in uint64
row_ids = pa.array(row_ids, type=pa.uint64())
output_tbl = output_tbl.append_column("_rowid", row_ids)
return output_tbl return output_tbl
@@ -579,3 +620,258 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
filter=self._where, filter=self._where,
limit=self._limit, limit=self._limit,
) )
class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str):
super().__init__(table)
self._validate_fts_index()
self._query = query
vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
self._norm = "score"
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
def _validate_fts_index(self):
if self._table._get_fts_index_path() is None:
raise ValueError(
"Please create a full-text search index " "to perform hybrid search."
)
def _validate_query(self, query):
# Temp hack to support vectorized queries for hybrid search
if isinstance(query, str):
return query, query
elif isinstance(query, tuple):
if len(query) != 2:
raise ValueError(
"The query must be a tuple of (vector_query, fts_query)."
)
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
raise ValueError(f"The vector query must be one of {VEC}.")
if not isinstance(query[1], str):
raise ValueError("The fts query must be a string.")
return query[0], query[1]
else:
raise ValueError(
"The query must be either a string or a tuple of (vector, string)."
)
def to_arrow(self) -> pa.Table:
with ThreadPoolExecutor() as executor:
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
vector_future = executor.submit(
self._vector_query.with_row_id(True).to_arrow
)
fts_results = fts_future.result()
vector_results = vector_future.result()
# convert to ranks first if needed
if self._norm == "rank":
vector_results = self._rank(vector_results, "_distance")
fts_results = self._rank(fts_results, "score")
# normalize the scores to be between 0 and 1, 0 being most relevant
vector_results = self._normalize_scores(vector_results, "_distance")
# In fts higher scores represent relevance. Not inverting them here as
# rerankers might need to preserve this score to support `return_score="all"`
fts_results = self._normalize_scores(fts_results, "score")
results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
if not isinstance(results, pa.Table): # Enforce type
raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
)
if not self._with_row_id:
results = results.drop(["_rowid"])
return results
def _rank(self, results: pa.Table, column: str, ascending: bool = True):
if len(results) == 0:
return results
# Get the _score column from results
scores = results.column(column).to_numpy()
sort_indices = np.argsort(scores)
if not ascending:
sort_indices = sort_indices[::-1]
ranks = np.empty_like(sort_indices)
ranks[sort_indices] = np.arange(len(scores)) + 1
# replace the _score column with the ranks
_score_idx = results.column_names.index(column)
results = results.set_column(
_score_idx, column, pa.array(ranks, type=pa.float32())
)
return results
def _normalize_scores(self, results: pa.Table, column: str, invert=False):
if len(results) == 0:
return results
# Get the _score column from results
scores = results.column(column).to_numpy()
# normalize the scores by subtracting the min and dividing by the max
max, min = np.max(scores), np.min(scores)
if np.isclose(max, min):
rng = max
else:
rng = max - min
scores = (scores - min) / rng
if invert:
scores = 1 - scores
# replace the _score column with the ranks
_score_idx = results.column_names.index(column)
results = results.set_column(
_score_idx, column, pa.array(scores, type=pa.float32())
)
return results
def rerank(
self,
normalize="score",
reranker: Reranker = LinearCombinationReranker(weight=0.7, fill=1.0),
) -> LanceHybridQueryBuilder:
"""
Rerank the hybrid search results using the specified reranker. The reranker
must be an instance of Reranker class.
Parameters
----------
normalize: str, default "score"
The method to normalize the scores. Can be "rank" or "score". If "rank",
the scores are converted to ranks and then normalized. If "score", the
scores are normalized directly.
reranker: Reranker, default LinearCombinationReranker(weight=0.7, fill=1.0)
The reranker to use. Must be an instance of Reranker class.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
if normalize not in ["rank", "score"]:
raise ValueError("normalize must be 'rank' or 'score'.")
if reranker and not isinstance(reranker, Reranker):
raise ValueError("reranker must be an instance of Reranker class.")
self._norm = normalize
self._reranker = reranker
return self
def limit(self, limit: int) -> LanceHybridQueryBuilder:
"""
Set the maximum number of results to return for both vector and fts search
components.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.limit(limit)
self._fts_query.limit(limit)
return self
def select(self, columns: list) -> LanceHybridQueryBuilder:
"""
Set the columns to return for both vector and fts search.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.select(columns)
self._fts_query.select(columns)
return self
def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
"""
Set the where clause for both vector and fts search.
Parameters
----------
where: str
The where clause which is a valid SQL where clause. See
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
for valid SQL expressions.
prefilter: bool, default False
If True, apply the filter before vector search, otherwise the
filter is applied on the result of vector search.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.where(where, prefilter=prefilter)
self._fts_query.where(where)
return self
def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
"""
Set the distance metric to use for vector search.
Parameters
----------
metric: "L2" or "cosine"
The distance metric to use. By default "L2" is used.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.metric(metric)
return self
def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
"""
Set the number of probes to use for vector search.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
Parameters
----------
nprobes: int
The number of probes to use.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.nprobes(nprobes)
return self
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
"""
Refine the vector search results by reading extra elements and
re-ranking them in memory.
Parameters
----------
refine_factor: int
The refine factor to use.
Returns
-------
LanceHybridQueryBuilder
The LanceHybridQueryBuilder object.
"""
self._vector_query.refine_factor(refine_factor)
return self

View File

@@ -13,6 +13,8 @@
import functools import functools
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin from urllib.parse import urljoin
@@ -20,6 +22,8 @@ import attrs
import pyarrow as pa import pyarrow as pa
import requests import requests
from pydantic import BaseModel from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from lancedb.common import Credential from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult from lancedb.remote import VectorQuery, VectorQueryResult
@@ -57,6 +61,10 @@ class RestfulLanceDBClient:
@functools.cached_property @functools.cached_property
def session(self) -> requests.Session: def session(self) -> requests.Session:
sess = requests.Session() sess = requests.Session()
retry_adapter_instance = retry_adapter(retry_adapter_options())
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
adapter_class = LanceDBClientHTTPAdapterFactory() adapter_class = LanceDBClientHTTPAdapterFactory()
sess.mount("https://", adapter_class()) sess.mount("https://", adapter_class())
return sess return sess
@@ -109,7 +117,7 @@ class RestfulLanceDBClient:
urljoin(self.url, uri), urljoin(self.url, uri),
params=params, params=params,
headers=self.headers, headers=self.headers,
timeout=(10.0, 300.0), timeout=(120.0, 300.0),
) as resp: ) as resp:
self._check_status(resp) self._check_status(resp)
return resp.json() return resp.json()
@@ -151,7 +159,7 @@ class RestfulLanceDBClient:
urljoin(self.url, uri), urljoin(self.url, uri),
headers=headers, headers=headers,
params=params, params=params,
timeout=(10.0, 300.0), timeout=(120.0, 300.0),
**req_kwargs, **req_kwargs,
) as resp: ) as resp:
self._check_status(resp) self._check_status(resp)
@@ -170,3 +178,72 @@ class RestfulLanceDBClient:
"""Query a table.""" """Query a table."""
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc) tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl) return VectorQueryResult(tbl)
def mount_retry_adapter_for_table(self, table_name: str) -> None:
"""
Adds an http adapter to session that will retry retryable requests to the table.
"""
retry_options = retry_adapter_options(methods=["GET", "POST"])
retry_adapter_instance = retry_adapter(retry_options)
session = self.session
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
retry_adapter_instance,
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
retry_adapter_instance,
)
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
return {
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
"backoff_factor": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
),
"backoff_jitter": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
),
"statuses": [
int(i.strip())
for i in os.environ.get(
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
).split(",")
],
"methods": methods,
}
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
total_retries = options["retries"]
connect_retries = options["connect_retries"]
read_retries = options["read_retries"]
backoff_factor = options["backoff_factor"]
backoff_jitter = options["backoff_jitter"]
statuses = options["statuses"]
methods = frozenset(options["methods"])
logging.debug(
f"Setting up retry adapter with {total_retries} retries," # noqa G003
+ f"connect retries {connect_retries}, read retries {read_retries},"
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
+ f"methods {methods}"
)
return HTTPAdapter(
max_retries=Retry(
total=total_retries,
connect=connect_retries,
read=read_retries,
backoff_factor=backoff_factor,
backoff_jitter=backoff_jitter,
status_forcelist=statuses,
allowed_methods=methods,
)
)

View File

@@ -95,6 +95,8 @@ class RemoteDBConnection(DBConnection):
""" """
from .table import RemoteTable from .table import RemoteTable
self._client.mount_retry_adapter_for_table(name)
# check if table exists # check if table exists
try: try:
self._client.post(f"/v1/table/{name}/describe/") self._client.post(f"/v1/table/{name}/describe/")

View File

@@ -19,6 +19,7 @@ import pyarrow as pa
from lance import json_to_schema from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from ..query import LanceVectorQueryBuilder from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
@@ -244,6 +245,47 @@ class RemoteTable(Table):
result = self._conn._client.query(self._name, query) result = self._conn._client.query(self._name, query)
return result.to_arrow() return result.to_arrow()
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
data = _sanitize_data(
new_data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params[
"when_not_matched_by_source_delete_filt"
] = merge._when_not_matched_by_source_condition
self._conn._client.post(
f"/v1/table/{self._name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
def delete(self, predicate: str): def delete(self, predicate: str):
"""Delete rows from the table. """Delete rows from the table.
@@ -355,6 +397,18 @@ class RemoteTable(Table):
payload = {"predicate": where, "updates": updates} payload = {"predicate": where, "updates": updates}
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload) self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
raise NotImplementedError(
"cleanup_old_versions() is not supported on the LanceDB cloud"
)
def compact_files(self, *_):
"""compact_files() is not supported on the LanceDB cloud"""
raise NotImplementedError(
"compact_files() is not supported on the LanceDB cloud"
)
def add_index(tbl: pa.Table, i: int) -> pa.Table: def add_index(tbl: pa.Table, i: int) -> pa.Table:
return tbl.add_column( return tbl.add_column(

View File

@@ -0,0 +1,11 @@
from .base import Reranker
from .cohere import CohereReranker
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
__all__ = [
"Reranker",
"CrossEncoderReranker",
"CohereReranker",
"LinearCombinationReranker",
]

View File

@@ -0,0 +1,109 @@
import typing
from abc import ABC, abstractmethod
import numpy as np
import pyarrow as pa
if typing.TYPE_CHECKING:
import lancedb
class Reranker(ABC):
def __init__(self, return_score: str = "relevance"):
"""
Interface for a reranker. A reranker is used to rerank the results from a
vector and FTS search. This is useful for combining the results from both
search methods.
Parameters
----------
return_score : str, default "relevance"
opntions are "relevance" or "all"
The type of score to return. If "relevance", will return only the relevance
score. If "all", will return all scores from the vector and FTS search along
with the relevance score.
"""
if return_score not in ["relevance", "all"]:
raise ValueError("score must be either 'relevance' or 'all'")
self.score = return_score
@abstractmethod
def rerank_hybrid(
query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table,
fts_results: pa.Table,
):
"""
Rerank function receives the individual results from the vector and FTS search
results. You can choose to use any of the results to generate the final results,
allowing maximum flexibility. This is mandatory to implement
Parameters
----------
query_builder : "lancedb.HybridQueryBuilder"
The query builder object that was used to generate the results
vector_results : pa.Table
The results from the vector search
fts_results : pa.Table
The results from the FTS search
"""
pass
def rerank_vector(
query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
):
"""
Rerank function receives the individual results from the vector search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.VectorQueryBuilder"
The query builder object that was used to generate the results
vector_results : pa.Table
The results from the vector search
"""
raise NotImplementedError("Vector Reranking is not implemented")
def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
"""
Rerank function receives the individual results from the FTS search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.FTSQueryBuilder"
The query builder object that was used to generate the results
fts_results : pa.Table
The results from the FTS search
"""
raise NotImplementedError("FTS Reranking is not implemented")
def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
"""
Merge the results from the vector and FTS search. This is a vanilla merging
function that just concatenates the results and removes the duplicates.
NOTE: This doesn't take score into account. It'll keep the instance that was
encountered first. This is designed for rerankers that don't use the score.
In case you want to use the score, or support `return_scores="all"` you'll
have to implement your own merging function.
Parameters
----------
vector_results : pa.Table
The results from the vector search
fts_results : pa.Table
The results from the FTS search
"""
combined = pa.concat_tables([vector_results, fts_results], promote=True)
row_id = combined.column("_rowid")
# deduplicate
mask = np.full((combined.shape[0]), False)
_, mask_indices = np.unique(np.array(row_id), return_index=True)
mask[mask_indices] = True
combined = combined.filter(mask=mask)
return combined

View File

@@ -0,0 +1,85 @@
import os
import typing
from functools import cached_property
from typing import Union
import pyarrow as pa
from ..util import safe_import
from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CohereReranker(Reranker):
"""
Reranks the results using the Cohere Rerank API.
https://docs.cohere.com/docs/rerank-guide
Parameters
----------
model_name : str, default "rerank-english-v2.0"
The name of the cross encoder model to use. Available cohere models are:
- rerank-english-v2.0
- rerank-multilingual-v2.0
column : str, default "text"
The name of the column to use as input to the cross encoder model.
top_n : str, default None
The number of results to return. If None, will return all results.
"""
def __init__(
self,
model_name: str = "rerank-english-v2.0",
column: str = "text",
top_n: Union[int, None] = None,
return_score="relevance",
api_key: Union[str, None] = None,
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.top_n = top_n
self.api_key = api_key
@cached_property
def _client(self):
cohere = safe_import("cohere")
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
raise ValueError(
"COHERE_API_KEY not set. Either set it in your environment or \
pass it as `api_key` argument to the CohereReranker."
)
return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
results = self._client.rerank(
query=query_builder._query,
documents=docs,
top_n=self.top_n,
model=self.model_name,
) # returns list (text, idx, relevance) attributes sorted descending by score
indices, scores = list(
zip(*[(result.index, result.relevance_score) for result in results])
) # tuples
combined_results = combined_results.take(list(indices))
# add the scores
combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for cohere reranker"
)
return combined_results

View File

@@ -0,0 +1,78 @@
import typing
from functools import cached_property
from typing import Union
import pyarrow as pa
from ..util import safe_import
from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CrossEncoderReranker(Reranker):
"""
Reranks the results using a cross encoder model. The cross encoder model is
used to score the query and each result. The results are then sorted by the score.
Parameters
----------
model : str, default "cross-encoder/ms-marco-TinyBERT-L-6"
The name of the cross encoder model to use. See the sentence transformers
documentation for a list of available models.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
device : str, default None
The device to use for the cross encoder model. If None, will use "cuda"
if available, otherwise "cpu".
"""
def __init__(
self,
model_name: str = "cross-encoder/ms-marco-TinyBERT-L-6",
column: str = "text",
device: Union[str, None] = None,
return_score="relevance",
):
super().__init__(return_score)
torch = safe_import("torch")
self.model_name = model_name
self.column = column
self.device = device
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@cached_property
def model(self):
sbert = safe_import("sentence_transformers")
cross_encoder = sbert.CrossEncoder(self.model_name)
return cross_encoder
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
passages = combined_results[self.column].to_pylist()
cross_inp = [[query_builder._query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp)
combined_results = combined_results.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())
)
# sort the results by _score
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for CrossEncoderReranker"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results

View File

@@ -0,0 +1,117 @@
from typing import List
import pyarrow as pa
from .base import Reranker
class LinearCombinationReranker(Reranker):
"""
Reranks the results using a linear combination of the scores from the
vector and FTS search. For missing scores, fill with `fill` value.
Parameters
----------
weight : float, default 0.7
The weight to give to the vector score. Must be between 0 and 1.
fill : float, default 1.0
The score to give to results that are only in one of the two result sets.
This is treated as penalty, so a higher value means a lower score.
TODO: We should just hardcode this--
its pretty confusing as we invert scores to calculate final score
return_score : str, default "relevance"
opntions are "relevance" or "all"
The type of score to return. If "relevance", will return only the relevance
score. If "all", will return all scores from the vector and FTS search along
with the relevance score.
"""
def __init__(
self, weight: float = 0.7, fill: float = 1.0, return_score="relevance"
):
if weight < 0 or weight > 1:
raise ValueError("weight must be between 0 and 1.")
super().__init__(return_score)
self.weight = weight
self.fill = fill
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder", # noqa: F821
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results, self.fill)
return combined_results
def merge_results(
self, vector_results: pa.Table, fts_results: pa.Table, fill: float
):
# If both are empty then just return an empty table
if len(vector_results) == 0 and len(fts_results) == 0:
return vector_results
# If one is empty then return the other
if len(vector_results) == 0:
return fts_results
if len(fts_results) == 0:
return vector_results
# sort both input tables on _rowid
combined_list = []
vector_list = vector_results.sort_by("_rowid").to_pylist()
fts_list = fts_results.sort_by("_rowid").to_pylist()
i, j = 0, 0
while i < len(vector_list):
if j >= len(fts_list):
for vi in vector_list[i:]:
vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
combined_list.append(vi)
break
vi = vector_list[i]
fj = fts_list[j]
# invert the fts score from relevance to distance
inverted_fts_score = self._invert_score(fj["score"])
if vi["_rowid"] == fj["_rowid"]:
vi["_relevance_score"] = self._combine_score(
vi["_distance"], inverted_fts_score
)
vi["score"] = fj["score"] # keep the original score
combined_list.append(vi)
i += 1
j += 1
elif vector_list[i]["_rowid"] < fts_list[j]["_rowid"]:
vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
combined_list.append(vi)
i += 1
else:
fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
combined_list.append(fj)
j += 1
if j < len(fts_list) - 1:
for fj in fts_list[j:]:
fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
combined_list.append(fj)
relevance_score_schema = pa.schema(
[
pa.field("_relevance_score", pa.float32()),
]
)
combined_schema = pa.unify_schemas(
[vector_results.schema, fts_results.schema, relevance_score_schema]
)
tbl = pa.Table.from_pylist(combined_list, schema=combined_schema).sort_by(
[("_relevance_score", "descending")]
)
if self.score == "relevance":
tbl = tbl.drop_columns(["score", "_distance"])
return tbl
def _combine_score(self, score1, score2):
# these scores represent distance
return 1 - (self.weight * score1 + (1 - self.weight) * score2)
def _invert_score(self, scores: List[float]):
# Invert the scores between relevance and distance
return 1 - scores

View File

@@ -16,7 +16,7 @@ from __future__ import annotations
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
import lance import lance
import numpy as np import numpy as np
@@ -28,13 +28,13 @@ from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query from .query import LanceQueryBuilder, Query
from .util import ( from .util import (
fs_from_uri, fs_from_uri,
join_uri, join_uri,
safe_import_pandas, safe_import,
safe_import_polars,
value_to_sql, value_to_sql,
) )
from .utils.events import register_event from .utils.events import register_event
@@ -48,8 +48,8 @@ if TYPE_CHECKING:
from .db import LanceDBConnection from .db import LanceDBConnection
pd = safe_import_pandas() pd = safe_import("pandas")
pl = safe_import_polars() pl = safe_import("polars")
def _sanitize_data( def _sanitize_data(
@@ -335,10 +335,70 @@ class Table(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
"""
Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
that can be used to create a "merge insert" operation
This operation can add rows, update rows, and remove rows all in a single
transaction. It is a very generic tool that can be used to create
behaviors like "insert if not exists", "update or insert (i.e. upsert)",
or even replace a portion of existing data with new data (e.g. replace
all data where month="january")
The merge insert operation works by combining new data from a
**source table** with existing data in a **target table** by using a
join. There are three categories of records.
"Matched" records are records that exist in both the source table and
the target table. "Not matched" records exist only in the source table
(e.g. these are new data) "Not matched by source" records exist only
in the target table (this is old data)
The builder returned by this method can be used to customize what
should happen for each category of data.
Please note that the data may appear to be reordered as part of this
operation. This is because updated rows will be deleted from the
dataset and then reinserted at the end with the new values.
Parameters
----------
on: Union[str, Iterable[str]]
A column (or columns) to join on. This is how records from the
source table and target table are matched. Typically this is some
kind of key or id column.
Examples
--------
>>> import lancedb
>>> data = pa.table({"a": [2, 1, 3], "b": ["a", "b", "c"]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
>>> # Perform a "upsert" operation
>>> table.merge_insert("a") \\
... .when_matched_update_all() \\
... .when_not_matched_insert_all() \\
... .execute(new_data)
>>> # The order of new rows is non-deterministic since we use
>>> # a hash-join as part of this operation and so we sort here
>>> table.to_arrow().sort_by("a").to_pandas()
a b
0 1 b
1 2 x
2 3 y
3 4 z
"""
on = [on] if isinstance(on, str) else list(on.iter())
return LanceMergeInsertBuilder(self, on)
@abstractmethod @abstractmethod
def search( def search(
self, self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME, vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto", query_type: str = "auto",
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
@@ -380,6 +440,8 @@ class Table(ABC):
the table the table
vector_column_name: str vector_column_name: str
The name of the vector column to search. The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
*default "vector"* *default "vector"*
query_type: str query_type: str
*default "auto"*. *default "auto"*.
@@ -415,6 +477,16 @@ class Table(ABC):
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
pass pass
@abstractmethod
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
pass
@abstractmethod @abstractmethod
def delete(self, where: str): def delete(self, where: str):
"""Delete rows from the table. """Delete rows from the table.
@@ -522,6 +594,52 @@ class Table(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def cleanup_old_versions(
self,
older_than: Optional[timedelta] = None,
*,
delete_unverified: bool = False,
) -> CleanupStats:
"""
Clean up old versions of the table, freeing disk space.
Note: This function is not available in LanceDb Cloud (since LanceDb
Cloud manages cleanup for you automatically)
Parameters
----------
older_than: timedelta, default None
The minimum age of the version to delete. If None, then this defaults
to two weeks.
delete_unverified: bool, default False
Because they may be part of an in-progress transaction, files newer
than 7 days old are not deleted by default. If you are sure that
there are no in-progress transactions, then you can set this to True
to delete all files older than `older_than`.
Returns
-------
CleanupStats
The stats of the cleanup operation, including how many bytes were
freed.
"""
@abstractmethod
def compact_files(self, *args, **kwargs):
"""
Run the compaction process on the table.
Note: This function is not available in LanceDb Cloud (since LanceDb
Cloud manages compaction for you automatically)
This can be run after making several small appends to optimize the table
for faster reads.
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
For most cases, the default should be fine.
"""
class LanceTable(Table): class LanceTable(Table):
""" """
@@ -924,7 +1042,7 @@ class LanceTable(Table):
def search( def search(
self, self,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME, vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto", query_type: str = "auto",
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
@@ -1194,8 +1312,34 @@ class LanceTable(Table):
"nprobes": query.nprobes, "nprobes": query.nprobes,
"refine_factor": query.refine_factor, "refine_factor": query.refine_factor,
}, },
with_row_id=query.with_row_id,
) )
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
new_data = _sanitize_data(
new_data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
ds = self.to_lance()
builder = ds.merge_insert(merge._on)
if merge._when_matched_update_all:
builder.when_matched_update_all()
if merge._when_not_matched_insert_all:
builder.when_not_matched_insert_all()
if merge._when_not_matched_by_source_delete:
cond = merge._when_not_matched_by_source_condition
builder.when_not_matched_by_source_delete(cond)
builder.execute(new_data)
def cleanup_old_versions( def cleanup_old_versions(
self, self,
older_than: Optional[timedelta] = None, older_than: Optional[timedelta] = None,
@@ -1233,8 +1377,9 @@ class LanceTable(Table):
This can be run after making several small appends to optimize the table This can be run after making several small appends to optimize the table
for faster reads. for faster reads.
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`. Arguments are passed onto `lance.dataset.DatasetOptimizer.compact_files`.
For most cases, the default should be fine. (see Lance documentation for more details) For most cases, the default
should be fine.
""" """
return self.to_lance().optimize.compact_files(*args, **kwargs) return self.to_lance().optimize.compact_files(*args, **kwargs)

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import os import os
import pathlib import pathlib
from datetime import date, datetime from datetime import date, datetime
@@ -114,22 +115,23 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
return "/".join([p.rstrip("/") for p in [base, *parts]]) return "/".join([p.rstrip("/") for p in [base, *parts]])
def safe_import_pandas(): def safe_import(module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try: try:
import pandas as pd return importlib.import_module(module)
return pd
except ImportError: except ImportError:
return None raise ImportError(f"Please install {mitigation or module}")
def safe_import_polars():
try:
import polars as pl
return pl
except ImportError:
return None
@singledispatch @singledispatch

View File

@@ -1,9 +1,9 @@
[project] [project]
name = "lancedb" name = "lancedb"
version = "0.5.1" version = "0.5.2"
dependencies = [ dependencies = [
"deprecation", "deprecation",
"pylance==0.9.9", "pylance==0.9.12",
"ratelimiter~=1.0", "ratelimiter~=1.0",
"retry>=0.9.2", "retry>=0.9.2",
"tqdm>=4.27.0", "tqdm>=4.27.0",
@@ -52,7 +52,8 @@ tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "d
dev = ["ruff", "pre-commit"] dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"] clip = ["torch", "pillow", "open-clip"]
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"] embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]
[project.scripts] [project.scripts]
lancedb = "lancedb.cli.cli:cli" lancedb = "lancedb.cli.cli:cli"
@@ -65,7 +66,8 @@ build-backend = "setuptools.build_meta"
select = ["F", "E", "W", "I", "G", "TCH", "PERF"] select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "--strict-markers" addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
markers = [ markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')", "slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio" "asyncio"

View File

@@ -10,6 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import io import io
import os import os
@@ -22,6 +23,11 @@ import lancedb
from lancedb.embeddings import get_registry from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
# These are integration tests for embedding functions. # These are integration tests for embedding functions.
# They are slow because they require downloading models # They are slow because they require downloading models
# or connection to external api # or connection to external api
@@ -202,3 +208,61 @@ def test_gemini_embedding(tmp_path):
tbl.add(df) tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims() assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world" assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@pytest.mark.skipif(
_mlx is None,
reason="mlx tests only required for apple users.",
)
@pytest.mark.slow
def test_gte_embedding(tmp_path):
import lancedb.embeddings.gte
model = get_registry().get("gte-text").create()
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
def aws_setup():
try:
import boto3
sts = boto3.client("sts")
sts.get_caller_identity()
return True
except Exception:
return False
@pytest.mark.slow
@pytest.mark.skipif(
not aws_setup(), reason="AWS credentials not set or libraries not installed"
)
def test_bedrock_embedding(tmp_path):
for name in [
"amazon.titan-embed-text-v1",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]:
model = get_registry().get("bedrock-text").create(max_retries=0, name=name)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()

View File

@@ -29,6 +29,9 @@ class FakeLanceDBClient:
def post(self, path: str): def post(self, path: str):
pass pass
def mount_retry_adapter_for_table(self, table_name: str):
pass
def test_remote_db(): def test_remote_db():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake") conn = lancedb.connect("db://client-will-be-injected", api_key="fake")

View File

@@ -0,0 +1,168 @@
import os
import numpy as np
import pytest
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
from lancedb.table import LanceTable
def get_test_table(tmp_path):
db = lancedb.connect(tmp_path)
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
# Initialize the table using the schema
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
# Need to test with a bunch of phrases to make sure sorting is consistent
phrases = [
"great kid don't get cocky",
"now that's a name I haven't heard in a long time",
"if you strike me down I shall become more powerful than you imagine",
"I find your lack of faith disturbing",
"I've got a bad feeling about this",
"never tell me the odds",
"I am your father",
"somebody has to save our skins",
"New strategy R2 let the wookiee win",
"Arrrrggghhhhhhh",
"I see a mansard roof through the trees",
"I see a salty message written in the eves",
"the ground beneath my feet",
"the hot garbage and concrete",
"and now the tops of buildings",
"everybody with a worried mind could never forgive the sight",
"of wicked snakes inside a place you thought was dignified",
"I don't wanna live like this",
"but I don't wanna die",
"The templars want control",
"the brotherhood of assassins want freedom",
"if only they could both see the world as it really is",
"there would be peace",
"but the war goes on",
"altair's legacy was a warning",
"Kratos had a son",
"he was a god",
"the god of war",
"but his son was mortal",
"there hasn't been a good battlefield game since 2142",
"I wish they would make another one",
"campains are not as good as they used to be",
"Multiplayer and open world games have destroyed the single player experience",
"Maybe the future is console games",
"I don't know",
]
# Add the phrases and vectors to the table
table.add([{"text": p} for p in phrases])
# Create a fts index
table.create_fts_index("text")
return table, MyTable
## These tests are pretty loose, we should also check for correctness
def test_linear_combination(tmp_path):
table, schema = get_test_table(tmp_path)
# The default reranker
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score")
.to_pydantic(schema)
)
result2 = ( # noqa
table.search("Our father who art in heaven.", query_type="hybrid")
.rerank(normalize="rank")
.to_pydantic(schema)
)
result3 = table.search(
"Our father who art in heaven..", query_type="hybrid"
).to_pydantic(schema)
assert result1 == result3 # 2 & 3 should be the same as they use score as score
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
.rerank(normalize="score")
.to_arrow()
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_cohere_reranker(tmp_path):
pytest.importorskip("cohere")
table, schema = get_test_table(tmp_path)
# The default reranker
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CohereReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank", reranker=CohereReranker())
.to_pydantic(schema)
)
assert result1 == result2
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
.rerank(reranker=CohereReranker())
.to_arrow()
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
def test_cross_encoder_reranker(tmp_path):
pytest.importorskip("sentence_transformers")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CrossEncoderReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank", reranker=CrossEncoderReranker())
.to_pydantic(schema)
)
assert result1 == result2
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
.rerank(reranker=CrossEncoderReranker())
.to_arrow()
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)

View File

@@ -493,6 +493,62 @@ def test_update_types(db):
assert actual == expected assert actual == expected
def test_merge_insert(db):
table = LanceTable.create(
db,
"my_table",
data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}),
)
assert len(table) == 3
version = table.version
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
# upsert
table.merge_insert(
"a"
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
# These `sort_by` calls can be removed once lance#1892
# is merged (it fixes the ordering)
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
# insert-if-not-exists
table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
# replace-range
table.merge_insert(
"a"
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete(
"a > 2"
).execute(new_data)
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
# replace-range no condition
table.merge_insert(
"a"
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete().execute(
new_data
)
expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
assert table.to_arrow().sort_by("a") == expected
def test_create_with_embedding_function(db): def test_create_with_embedding_function(db):
class MyTable(LanceModel): class MyTable(LanceModel):
text: str text: str
@@ -682,3 +738,57 @@ def test_count_rows(db):
assert len(table) == 2 assert len(table) == 2
assert table.count_rows() == 2 assert table.count_rows() == 2
assert table.count_rows(filter="text='bar'") == 1 assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db):
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
# Probably not being parsed right by the fts
db = MockDB("~/lancedb_")
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
class MyTable(LanceModel):
text: str = emb.SourceField()
vector: Vector(emb.ndims()) = emb.VectorField()
# Initialize the table using the schema
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
# Create a list of 10 unique english phrases
phrases = [
"great kid don't get cocky",
"now that's a name I haven't heard in a long time",
"if you strike me down I shall become more powerful than you imagine",
"I find your lack of faith disturbing",
"I've got a bad feeling about this",
"never tell me the odds",
"I am your father",
"somebody has to save our skins",
"New strategy R2 let the wookiee win",
"Arrrrggghhhhhhh",
]
# Add the phrases and vectors to the table
table.add([{"text": p} for p in phrases])
# Create a fts index
table.create_fts_index("text")
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score")
.to_pydantic(MyTable)
)
result2 = ( # noqa
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank")
.to_pydantic(MyTable)
)
result3 = table.search(
"Our father who art in heaven", query_type="hybrid"
).to_pydantic(MyTable)
assert result1 == result3

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "vectordb-node" name = "vectordb-node"
version = "0.4.4" version = "0.4.7"
description = "Serverless, low-latency vector database for AI applications" description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0" license = "Apache-2.0"
edition = "2018" edition = "2018"

View File

@@ -1,4 +1,4 @@
// Copyright 2023 Lance Developers. // Copyright 2024 Lance Developers.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@@ -19,19 +19,8 @@ use arrow_array::RecordBatch;
use arrow_ipc::reader::FileReader; use arrow_ipc::reader::FileReader;
use arrow_ipc::writer::FileWriter; use arrow_ipc::writer::FileWriter;
use arrow_schema::SchemaRef; use arrow_schema::SchemaRef;
use vectordb::table::VECTOR_COLUMN_NAME;
use crate::error::{MissingColumnSnafu, Result}; use crate::error::Result;
use snafu::prelude::*;
fn validate_vector_column(record_batch: &RecordBatch) -> Result<()> {
record_batch
.column_by_name(VECTOR_COLUMN_NAME)
.map(|_| ())
.context(MissingColumnSnafu {
name: VECTOR_COLUMN_NAME,
})
}
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> { pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
let mut batches: Vec<RecordBatch> = Vec::new(); let mut batches: Vec<RecordBatch> = Vec::new();
@@ -39,7 +28,6 @@ pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBa
let schema = file_reader.schema(); let schema = file_reader.schema();
for b in file_reader { for b in file_reader {
let record_batch = b?; let record_batch = b?;
validate_vector_column(&record_batch)?;
batches.push(record_batch); batches.push(record_batch);
} }
Ok((batches, schema)) Ok((batches, schema))

View File

@@ -19,6 +19,7 @@ use neon::{
}; };
use crate::{error::ResultExt, runtime, table::JsTable}; use crate::{error::ResultExt, runtime, table::JsTable};
use vectordb::Table;
pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
@@ -35,7 +36,9 @@ pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsP
let idx_result = table let idx_result = table
.as_native() .as_native()
.unwrap() .unwrap()
.create_scalar_index(&column, replace) .create_index(&[&column])
.replace(replace)
.build()
.await; .await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {

View File

@@ -24,7 +24,7 @@ use tokio::runtime::Runtime;
use vectordb::connection::Database; use vectordb::connection::Database;
use vectordb::table::ReadParams; use vectordb::table::ReadParams;
use vectordb::Connection; use vectordb::{ConnectOptions, Connection};
use crate::error::ResultExt; use crate::error::ResultExt;
use crate::query::JsQuery; use crate::query::JsQuery;
@@ -82,13 +82,26 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> { fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
let path = cx.argument::<JsString>(0)?.value(&mut cx); let path = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = get_aws_creds(&mut cx, 1)?;
let region = get_aws_region(&mut cx, 4)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let channel = cx.channel(); let channel = cx.channel();
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let mut conn_options = ConnectOptions::new(&path);
if let Some(region) = region {
conn_options = conn_options.region(&region);
}
if let Some(aws_creds) = aws_creds {
conn_options = conn_options.aws_creds(AwsCredential {
key_id: aws_creds.key_id,
secret_key: aws_creds.secret_key,
token: aws_creds.token,
});
}
rt.spawn(async move { rt.spawn(async move {
let database = Database::connect(&path).await; let database = Database::connect_with_options(&conn_options).await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
let db = JsDatabase { let db = JsDatabase {
@@ -127,7 +140,7 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
fn get_aws_creds( fn get_aws_creds(
cx: &mut FunctionContext, cx: &mut FunctionContext,
arg_starting_location: i32, arg_starting_location: i32,
) -> NeonResult<Option<AwsCredentialProvider>> { ) -> NeonResult<Option<AwsCredential>> {
let secret_key_id = cx let secret_key_id = cx
.argument_opt(arg_starting_location) .argument_opt(arg_starting_location)
.filter(|arg| arg.is_a::<JsString, _>(cx)) .filter(|arg| arg.is_a::<JsString, _>(cx))
@@ -147,18 +160,26 @@ fn get_aws_creds(
.map(|v| v.value(cx)); .map(|v| v.value(cx));
match (secret_key_id, secret_key, temp_token) { match (secret_key_id, secret_key, temp_token) {
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new( (Some(key_id), Some(key), optional_token) => Ok(Some(AwsCredential {
StaticCredentialProvider::new(AwsCredential { key_id,
key_id, secret_key: key,
secret_key: key, token: optional_token,
token: optional_token, })),
}),
))),
(None, None, None) => Ok(None), (None, None, None) => Ok(None),
_ => cx.throw_error("Invalid credentials configuration"), _ => cx.throw_error("Invalid credentials configuration"),
} }
} }
fn get_aws_credential_provider(
cx: &mut FunctionContext,
arg_starting_location: i32,
) -> NeonResult<Option<AwsCredentialProvider>> {
Ok(get_aws_creds(cx, arg_starting_location)?.map(|aws_cred| {
Arc::new(StaticCredentialProvider::new(aws_cred))
as Arc<dyn CredentialProvider<Credential = AwsCredential>>
}))
}
/// Get AWS region arguments from the context /// Get AWS region arguments from the context
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> { fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
let region = cx let region = cx
@@ -179,7 +200,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?; .downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx); let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = get_aws_creds(&mut cx, 1)?; let aws_creds = get_aws_credential_provider(&mut cx, 1)?;
let aws_region = get_aws_region(&mut cx, 4)?; let aws_region = get_aws_region(&mut cx, 4)?;
@@ -239,6 +260,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("tableCountRows", JsTable::js_count_rows)?; cx.export_function("tableCountRows", JsTable::js_count_rows)?;
cx.export_function("tableDelete", JsTable::js_delete)?; cx.export_function("tableDelete", JsTable::js_delete)?;
cx.export_function("tableUpdate", JsTable::js_update)?; cx.export_function("tableUpdate", JsTable::js_update)?;
cx.export_function("tableMergeInsert", JsTable::js_merge_insert)?;
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?; cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
cx.export_function("tableCompactFiles", JsTable::js_compact)?; cx.export_function("tableCompactFiles", JsTable::js_compact)?;
cx.export_function("tableListIndices", JsTable::js_list_indices)?; cx.export_function("tableListIndices", JsTable::js_list_indices)?;

View File

@@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::ops::Deref;
use arrow_array::{RecordBatch, RecordBatchIterator}; use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::optimize::CompactionOptions; use lance::dataset::optimize::CompactionOptions;
use lance::dataset::{WriteMode, WriteParams}; use lance::dataset::{WriteMode, WriteParams};
use lance::io::ObjectStoreParams; use lance::io::ObjectStoreParams;
use vectordb::table::OptimizeAction;
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer}; use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
use neon::prelude::*; use neon::prelude::*;
@@ -23,7 +26,7 @@ use neon::types::buffer::TypedArray;
use vectordb::TableRef; use vectordb::TableRef;
use crate::error::ResultExt; use crate::error::ResultExt;
use crate::{convert, get_aws_creds, get_aws_region, runtime, JsDatabase}; use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
pub(crate) struct JsTable { pub(crate) struct JsTable {
pub table: TableRef, pub table: TableRef,
@@ -63,7 +66,7 @@ impl JsTable {
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let database = db.database.clone(); let database = db.database.clone();
let aws_creds = get_aws_creds(&mut cx, 3)?; let aws_creds = get_aws_credential_provider(&mut cx, 3)?;
let aws_region = get_aws_region(&mut cx, 6)?; let aws_region = get_aws_region(&mut cx, 6)?;
let params = WriteParams { let params = WriteParams {
@@ -105,7 +108,7 @@ impl JsTable {
"overwrite" => WriteMode::Overwrite, "overwrite" => WriteMode::Overwrite,
s => return cx.throw_error(format!("invalid write mode {}", s)), s => return cx.throw_error(format!("invalid write mode {}", s)),
}; };
let aws_creds = get_aws_creds(&mut cx, 2)?; let aws_creds = get_aws_credential_provider(&mut cx, 2)?;
let aws_region = get_aws_region(&mut cx, 5)?; let aws_region = get_aws_region(&mut cx, 5)?;
let params = WriteParams { let params = WriteParams {
@@ -165,6 +168,53 @@ impl JsTable {
Ok(promise) Ok(promise)
} }
pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let table = js_table.table.clone();
let key = cx.argument::<JsString>(0)?.value(&mut cx);
let mut builder = table.merge_insert(&[&key]);
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
builder.when_matched_update_all();
}
if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
builder.when_not_matched_insert_all();
}
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
if let Some(filter) = cx.argument_opt(4) {
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_not_matched_by_source_delete(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_not_matched_by_source_delete(Some(filter));
}
} else {
builder.when_not_matched_by_source_delete(None);
}
}
let buffer = cx.argument::<JsBuffer>(5)?;
let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
rt.spawn(async move {
let new_data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let merge_insert_result = builder.execute(Box::new(new_data)).await;
deferred.settle_with(&channel, move |mut cx| {
merge_insert_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
})
});
Ok(promise)
}
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let table = js_table.table.clone(); let table = js_table.table.clone();
@@ -245,27 +295,30 @@ impl JsTable {
.map(|val| val.value(&mut cx) as i64) .map(|val| val.value(&mut cx) as i64)
.unwrap_or_else(|| 2 * 7 * 24 * 60); // 2 weeks .unwrap_or_else(|| 2 * 7 * 24 * 60); // 2 weeks
let older_than = chrono::Duration::minutes(older_than); let older_than = chrono::Duration::minutes(older_than);
let delete_unverified: bool = cx let delete_unverified: Option<bool> = Some(
.argument_opt(1) cx.argument_opt(1)
.and_then(|val| val.downcast::<JsBoolean, _>(&mut cx).ok()) .and_then(|val| val.downcast::<JsBoolean, _>(&mut cx).ok())
.map(|val| val.value(&mut cx)) .map(|val| val.value(&mut cx))
.unwrap_or_default(); .unwrap_or_default(),
);
rt.spawn(async move { rt.spawn(async move {
let stats = table let stats = table
.as_native() .optimize(OptimizeAction::Prune {
.unwrap() older_than,
.cleanup_old_versions(older_than, Some(delete_unverified)) delete_unverified,
})
.await; .await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
let stats = stats.or_throw(&mut cx)?; let stats = stats.or_throw(&mut cx)?;
let prune_stats = stats.prune.as_ref().expect("Prune stats missing");
let output_metrics = JsObject::new(&mut cx); let output_metrics = JsObject::new(&mut cx);
let bytes_removed = cx.number(stats.bytes_removed as f64); let bytes_removed = cx.number(prune_stats.bytes_removed as f64);
output_metrics.set(&mut cx, "bytesRemoved", bytes_removed)?; output_metrics.set(&mut cx, "bytesRemoved", bytes_removed)?;
let old_versions = cx.number(stats.old_versions as f64); let old_versions = cx.number(prune_stats.old_versions as f64);
output_metrics.set(&mut cx, "oldVersions", old_versions)?; output_metrics.set(&mut cx, "oldVersions", old_versions)?;
let output_table = cx.boxed(JsTable::from(table)); let output_table = cx.boxed(JsTable::from(table));
@@ -317,13 +370,15 @@ impl JsTable {
rt.spawn(async move { rt.spawn(async move {
let stats = table let stats = table
.as_native() .optimize(OptimizeAction::Compact {
.unwrap() options,
.compact_files(options, None) remap_options: None,
})
.await; .await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
let stats = stats.or_throw(&mut cx)?; let stats = stats.or_throw(&mut cx)?;
let stats = stats.compaction.as_ref().expect("Compact stats missing");
let output_metrics = JsObject::new(&mut cx); let output_metrics = JsObject::new(&mut cx);
let fragments_removed = cx.number(stats.fragments_removed as f64); let fragments_removed = cx.number(stats.fragments_removed as f64);

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "vectordb" name = "vectordb"
version = "0.4.4" version = "0.4.7"
edition = "2021" edition = "2021"
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license = "Apache-2.0" license = "Apache-2.0"

View File

@@ -0,0 +1,168 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow_array::types::Float32Type;
use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use vectordb::Connection;
use vectordb::{connect, Result, Table, TableRef};
#[tokio::main]
async fn main() -> Result<()> {
if std::path::Path::new("data").exists() {
std::fs::remove_dir_all("data").unwrap();
}
// --8<-- [start:connect]
let uri = "data/sample-lancedb";
let db = connect(uri).await?;
// --8<-- [end:connect]
// --8<-- [start:list_names]
println!("{:?}", db.table_names().await?);
// --8<-- [end:list_names]
let tbl = create_table(db.clone()).await?;
create_index(tbl.as_ref()).await?;
let batches = search(tbl.as_ref()).await?;
println!("{:?}", batches);
create_empty_table(db.clone()).await.unwrap();
// --8<-- [start:delete]
tbl.delete("id > 24").await.unwrap();
// --8<-- [end:delete]
// --8<-- [start:drop_table]
db.drop_table("my_table").await.unwrap();
// --8<-- [end:drop_table]
Ok(())
}
#[allow(dead_code)]
async fn open_with_existing_tbl() -> Result<()> {
let uri = "data/sample-lancedb";
let db = connect(uri).await?;
// --8<-- [start:open_with_existing_file]
let _ = db
.open_table_with_params("my_table", Default::default())
.await
.unwrap();
// --8<-- [end:open_with_existing_file]
Ok(())
}
async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
// --8<-- [start:create_table]
const TOTAL: usize = 1000;
const DIM: usize = 128;
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
DIM as i32,
),
true,
),
]));
// Create a RecordBatch stream.
let batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
DIM as i32,
),
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
let tbl = db
.create_table("my_table", Box::new(batches), None)
.await
.unwrap();
// --8<-- [end:create_table]
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
Arc::new(
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
(0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
DIM as i32,
),
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
// --8<-- [start:add]
tbl.add(Box::new(new_batches), None).await.unwrap();
// --8<-- [end:add]
Ok(tbl)
}
async fn create_empty_table(db: Arc<dyn Connection>) -> Result<TableRef> {
// --8<-- [start:create_empty_table]
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("item", DataType::Utf8, true),
]));
let batches = RecordBatchIterator::new(vec![], schema.clone());
db.create_table("empty_table", Box::new(batches), None)
.await
// --8<-- [end:create_empty_table]
}
async fn create_index(table: &dyn Table) -> Result<()> {
// --8<-- [start:create_index]
table
.create_index(&["vector"])
.ivf_pq()
.num_partitions(8)
.build()
.await
// --8<-- [end:create_index]
}
async fn search(table: &dyn Table) -> Result<Vec<RecordBatch>> {
// --8<-- [start:search]
Ok(table
.search(&[1.0; 128])
.limit(2)
.execute_stream()
.await?
.try_collect::<Vec<_>>()
.await?)
// --8<-- [end:search]
}

View File

@@ -21,8 +21,10 @@ use std::sync::Arc;
use arrow_array::RecordBatchReader; use arrow_array::RecordBatchReader;
use lance::dataset::WriteParams; use lance::dataset::WriteParams;
use lance::io::{ObjectStore, WrappingObjectStore}; use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
use object_store::local::LocalFileSystem; use object_store::{
aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
};
use snafu::prelude::*; use snafu::prelude::*;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
@@ -68,6 +70,98 @@ pub trait Connection: Send + Sync {
async fn drop_table(&self, name: &str) -> Result<()>; async fn drop_table(&self, name: &str) -> Result<()>;
} }
#[derive(Debug)]
pub struct ConnectOptions {
/// Database URI
///
/// # Accpeted URI formats
///
/// - `/path/to/database` - local database on file system.
/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
/// - `db://dbname` - Lance Cloud
pub uri: String,
/// Lance Cloud API key
pub api_key: Option<String>,
/// Lance Cloud region
pub region: Option<String>,
/// Lance Cloud host override
pub host_override: Option<String>,
/// User provided AWS credentials
pub aws_creds: Option<AwsCredential>,
/// The maximum number of indices to cache in memory. Defaults to 256.
pub index_cache_size: u32,
}
impl ConnectOptions {
/// Create a new [`ConnectOptions`] with the given database URI.
pub fn new(uri: &str) -> Self {
Self {
uri: uri.to_string(),
api_key: None,
region: None,
host_override: None,
aws_creds: None,
index_cache_size: 256,
}
}
pub fn api_key(mut self, api_key: &str) -> Self {
self.api_key = Some(api_key.to_string());
self
}
pub fn region(mut self, region: &str) -> Self {
self.region = Some(region.to_string());
self
}
pub fn host_override(mut self, host_override: &str) -> Self {
self.host_override = Some(host_override.to_string());
self
}
/// [`AwsCredential`] to use when connecting to S3.
///
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
self.aws_creds = Some(aws_creds);
self
}
pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
self.index_cache_size = index_cache_size;
self
}
}
/// Connect to a LanceDB database.
///
/// # Arguments
///
/// - `uri` - URI where the database is located, can be a local file or a supported remote cloud storage
///
/// ## Accepted URI formats
///
/// - `/path/to/database` - local database on file system.
/// - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
/// - `db://dbname` - Lance Cloud
///
pub async fn connect(uri: &str) -> Result<Arc<dyn Connection>> {
let options = ConnectOptions::new(uri);
connect_with_options(&options).await
}
/// Connect with [`ConnectOptions`].
///
/// # Arguments
/// - `options` - [`ConnectOptions`] to connect to the database.
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Arc<dyn Connection>> {
let db = Database::connect(&options.uri).await?;
Ok(Arc::new(db))
}
pub struct Database { pub struct Database {
object_store: ObjectStore, object_store: ObjectStore,
query_string: Option<String>, query_string: Option<String>,
@@ -95,6 +189,12 @@ impl Database {
/// ///
/// * A [Database] object. /// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Database> { pub async fn connect(uri: &str) -> Result<Database> {
let options = ConnectOptions::new(uri);
Self::connect_with_options(&options).await
}
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
let uri = &options.uri;
let parse_res = url::Url::parse(uri); let parse_res = url::Url::parse(uri);
match parse_res { match parse_res {
@@ -146,7 +246,23 @@ impl Database {
}; };
let plain_uri = url.to_string(); let plain_uri = url.to_string();
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?; let os_params: ObjectStoreParams = if let Some(aws_creds) = &options.aws_creds {
let credential_provider: Arc<
dyn CredentialProvider<Credential = AwsCredential>,
> = Arc::new(StaticCredentialProvider::new(AwsCredential {
key_id: aws_creds.key_id.clone(),
secret_key: aws_creds.secret_key.clone(),
token: aws_creds.token.clone(),
}));
ObjectStoreParams::with_aws_credentials(
Some(credential_provider),
options.region.clone(),
)
} else {
ObjectStoreParams::default()
};
let (object_store, base_path) =
ObjectStore::from_uri_and_params(&plain_uri, &os_params).await?;
if object_store.is_local() { if object_store.is_local() {
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?; Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
} }

View File

@@ -14,6 +14,7 @@
use std::{cmp::max, sync::Arc}; use std::{cmp::max, sync::Arc};
use lance::index::scalar::ScalarIndexParams;
use lance_index::{DatasetIndexExt, IndexType}; use lance_index::{DatasetIndexExt, IndexType};
pub use lance_linalg::distance::MetricType; pub use lance_linalg::distance::MetricType;
@@ -232,10 +233,14 @@ impl IndexBuilder {
let mut dataset = tbl.clone_inner_dataset(); let mut dataset = tbl.clone_inner_dataset();
match params { match params {
IndexParams::Scalar { replace } => { IndexParams::Scalar { replace } => {
self.table dataset
.as_native() .create_index(
.unwrap() &[&column],
.create_scalar_index(column, replace) IndexType::Scalar,
None,
&ScalarIndexParams::default(),
replace,
)
.await? .await?
} }
IndexParams::IvfPq { IndexParams::IvfPq {

View File

@@ -16,10 +16,10 @@
use std::io::Cursor; use std::io::Cursor;
use arrow_array::RecordBatchReader; use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_ipc::reader::StreamReader; use arrow_ipc::{reader::StreamReader, writer::FileWriter};
use crate::Result; use crate::{Error, Result};
/// Convert a Arrow IPC file to a batch reader /// Convert a Arrow IPC file to a batch reader
pub fn ipc_file_to_batches(buf: Vec<u8>) -> Result<impl RecordBatchReader> { pub fn ipc_file_to_batches(buf: Vec<u8>) -> Result<impl RecordBatchReader> {
@@ -28,6 +28,22 @@ pub fn ipc_file_to_batches(buf: Vec<u8>) -> Result<impl RecordBatchReader> {
Ok(reader) Ok(reader)
} }
/// Convert record batches to Arrow IPC file
pub fn batches_to_ipc_file(batches: &[RecordBatch]) -> Result<Vec<u8>> {
if batches.is_empty() {
return Err(Error::Store {
message: "No batches to write".to_string(),
});
}
let schema = batches[0].schema();
let mut writer = FileWriter::try_new(vec![], &schema)?;
for batch in batches {
writer.write(batch)?;
}
writer.finish()?;
Ok(writer.into_inner()?)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View File

@@ -33,23 +33,37 @@
//! LanceDB runs in process, to use it in your Rust project, put the following in your `Cargo.toml`: //! LanceDB runs in process, to use it in your Rust project, put the following in your `Cargo.toml`:
//! //!
//! ```ignore //! ```ignore
//! [dependencies] //! cargo install vectordb
//! vectordb = "0.4"
//! arrow-schema = "50"
//! arrow-array = "50"
//! ``` //! ```
//! //!
//! ### Quick Start //! ### Quick Start
//! //!
//! <div class="warning">Rust API is not stable yet.</div> //! <div class="warning">Rust API is not stable yet, please expect breaking changes.</div>
//! //!
//! #### Connect to a database. //! #### Connect to a database.
//! //!
//! ```rust //! ```rust
//! use vectordb::connection::Database; //! use vectordb::connect;
//! # use arrow_schema::{Field, Schema}; //! # use arrow_schema::{Field, Schema};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let db = Database::connect("data/sample-lancedb").await.unwrap(); //! let db = connect("data/sample-lancedb").await.unwrap();
//! # });
//! ```
//!
//! LanceDB accepts the different form of database path:
//!
//! - `/path/to/database` - local database on file system.
//! - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud object store
//! - `db://dbname` - Lance Cloud
//!
//! You can also use [`ConnectOptions`] to configure the connectoin to the database.
//!
//! ```rust
//! use vectordb::{connect_with_options, ConnectOptions};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let options = ConnectOptions::new("data/sample-lancedb")
//! .index_cache_size(1024);
//! let db = connect_with_options(&options).await.unwrap();
//! # }); //! # });
//! ``` //! ```
//! //!
@@ -57,6 +71,8 @@
//! It treats [`FixedSizeList<Float16/Float32>`](https://docs.rs/arrow/latest/arrow/array/struct.FixedSizeListArray.html) //! It treats [`FixedSizeList<Float16/Float32>`](https://docs.rs/arrow/latest/arrow/array/struct.FixedSizeListArray.html)
//! columns as vector columns. //! columns as vector columns.
//! //!
//! For more details, please refer to [LanceDB documentation](https://lancedb.github.io/lancedb/).
//!
//! #### Create a table //! #### Create a table
//! //!
//! To create a Table, you need to provide a [`arrow_schema::Schema`] and a [`arrow_array::RecordBatch`] stream. //! To create a Table, you need to provide a [`arrow_schema::Schema`] and a [`arrow_array::RecordBatch`] stream.
@@ -67,10 +83,11 @@
//! use arrow_array::{RecordBatch, RecordBatchIterator}; //! use arrow_array::{RecordBatch, RecordBatchIterator};
//! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type}; //! # use arrow_array::{FixedSizeListArray, Float32Array, Int32Array, types::Float32Type};
//! # use vectordb::connection::{Database, Connection}; //! # use vectordb::connection::{Database, Connection};
//! # use vectordb::connect;
//! //!
//! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap(); //! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap(); //! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! let schema = Arc::new(Schema::new(vec![ //! let schema = Arc::new(Schema::new(vec![
//! Field::new("id", DataType::Int32, false), //! Field::new("id", DataType::Int32, false),
//! Field::new("vector", DataType::FixedSizeList( //! Field::new("vector", DataType::FixedSizeList(
@@ -80,9 +97,9 @@
//! let batches = RecordBatchIterator::new(vec![ //! let batches = RecordBatchIterator::new(vec![
//! RecordBatch::try_new(schema.clone(), //! RecordBatch::try_new(schema.clone(),
//! vec![ //! vec![
//! Arc::new(Int32Array::from_iter_values(0..10)), //! Arc::new(Int32Array::from_iter_values(0..1000)),
//! Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>( //! Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
//! (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)), //! (0..1000).map(|_| Some(vec![Some(1.0); 128])), 128)),
//! ]).unwrap() //! ]).unwrap()
//! ].into_iter().map(Ok), //! ].into_iter().map(Ok),
//! schema.clone()); //! schema.clone());
@@ -94,13 +111,13 @@
//! //!
//! ```no_run //! ```no_run
//! # use std::sync::Arc; //! # use std::sync::Arc;
//! # use vectordb::connection::{Database, Connection}; //! # use vectordb::connect;
//! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch, //! # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
//! # RecordBatchIterator, Int32Array}; //! # RecordBatchIterator, Int32Array};
//! # use arrow_schema::{Schema, Field, DataType}; //! # use arrow_schema::{Schema, Field, DataType};
//! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! # let tmpdir = tempfile::tempdir().unwrap(); //! # let tmpdir = tempfile::tempdir().unwrap();
//! # let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap(); //! # let db = connect(tmpdir.path().to_str().unwrap()).await.unwrap();
//! # let tbl = db.open_table("idx_test").await.unwrap(); //! # let tbl = db.open_table("idx_test").await.unwrap();
//! tbl.create_index(&["vector"]) //! tbl.create_index(&["vector"])
//! .ivf_pq() //! .ivf_pq()
@@ -138,7 +155,7 @@
//! # ].into_iter().map(Ok), //! # ].into_iter().map(Ok),
//! # schema.clone()); //! # schema.clone());
//! # db.create_table("my_table", Box::new(batches), None).await.unwrap(); //! # db.create_table("my_table", Box::new(batches), None).await.unwrap();
//! let table = db.open_table("my_table").await.unwrap(); //! # let table = db.open_table("my_table").await.unwrap();
//! let results = table //! let results = table
//! .search(&[1.0; 128]) //! .search(&[1.0; 128])
//! .execute_stream() //! .execute_stream()
@@ -166,4 +183,6 @@ pub use connection::{Connection, Database};
pub use error::{Error, Result}; pub use error::{Error, Result};
pub use table::{Table, TableRef}; pub use table::{Table, TableRef};
/// Connect to a database
pub use connection::{connect, connect_with_options, ConnectOptions};
pub use lance::dataset::WriteMode; pub use lance::dataset::WriteMode;

View File

@@ -22,6 +22,7 @@ use lance_linalg::distance::MetricType;
use crate::error::Result; use crate::error::Result;
use crate::utils::default_vector_column; use crate::utils::default_vector_column;
use crate::Error;
const DEFAULT_TOP_K: usize = 10; const DEFAULT_TOP_K: usize = 10;
@@ -93,6 +94,19 @@ impl Query {
let arrow_schema = Schema::from(self.dataset.schema()); let arrow_schema = Schema::from(self.dataset.schema());
default_vector_column(&arrow_schema, Some(query.len() as i32))? default_vector_column(&arrow_schema, Some(query.len() as i32))?
}; };
let field = self.dataset.schema().field(&column).ok_or(Error::Store {
message: format!("Column {} not found in dataset schema", column),
})?;
if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query.len() as i32)
{
return Err(Error::Store {
message: format!(
"Vector column '{}' does not match the dimension of the query vector: dim={}",
column,
query.len(),
),
});
}
scanner.nearest(&column, query, self.limit.unwrap_or(DEFAULT_TOP_K))?; scanner.nearest(&column, query, self.limit.unwrap_or(DEFAULT_TOP_K))?;
} else { } else {
// If there is no vector query, it's ok to not have a limit // If there is no vector query, it's ok to not have a limit

View File

@@ -19,6 +19,7 @@ use std::sync::{Arc, Mutex};
use arrow_array::RecordBatchReader; use arrow_array::RecordBatchReader;
use arrow_schema::{Schema, SchemaRef}; use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use chrono::Duration; use chrono::Duration;
use lance::dataset::builder::DatasetBuilder; use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats; use lance::dataset::cleanup::RemovalStats;
@@ -27,9 +28,10 @@ use lance::dataset::optimize::{
}; };
pub use lance::dataset::ReadParams; pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WriteParams}; use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::index::scalar::ScalarIndexParams; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore; use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt, IndexType}; use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::index::vector::{VectorIndex, VectorIndexStatistics}; use crate::index::vector::{VectorIndex, VectorIndexStatistics};
@@ -38,7 +40,50 @@ use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam}; use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode; use crate::WriteMode;
pub const VECTOR_COLUMN_NAME: &str = "vector"; use self::merge::{MergeInsert, MergeInsertBuilder};
pub mod merge;
/// Optimize the dataset.
///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to
/// optimize different parts of the table on disk.
///
/// By default, it optimizes everything, as [`OptimizeAction::All`].
pub enum OptimizeAction {
/// Run optimization on every, with default options.
All,
/// Compact files in the dataset
Compact {
options: CompactionOptions,
remap_options: Option<Arc<dyn IndexRemapperOptions>>,
},
/// Prune old version of datasets.
Prune {
/// The duration of time to keep versions of the dataset.
older_than: Duration,
/// Because they may be part of an in-progress transaction, files newer than 7 days old are not deleted by default.
/// If you are sure that there are no in-progress transactions, then you can set this to True to delete all files older than `older_than`.
delete_unverified: Option<bool>,
},
/// Optimize index.
Index(OptimizeOptions),
}
impl Default for OptimizeAction {
fn default() -> Self {
Self::All
}
}
/// Statistics about the optimization.
pub struct OptimizeStats {
/// Stats of the file compaction.
pub compaction: Option<CompactionMetrics>,
/// Stats of the version pruning
pub prune: Option<RemovalStats>,
}
/// A Table is a collection of strong typed Rows. /// A Table is a collection of strong typed Rows.
/// ///
@@ -131,6 +176,71 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// ``` /// ```
fn create_index(&self, column: &[&str]) -> IndexBuilder; fn create_index(&self, column: &[&str]) -> IndexBuilder;
/// Create a builder for a merge insert operation
///
/// This operation can add rows, update rows, and remove rows all in a single
/// transaction. It is a very generic tool that can be used to create
/// behaviors like "insert if not exists", "update or insert (i.e. upsert)",
/// or even replace a portion of existing data with new data (e.g. replace
/// all data where month="january")
///
/// The merge insert operation works by combining new data from a
/// **source table** with existing data in a **target table** by using a
/// join. There are three categories of records.
///
/// "Matched" records are records that exist in both the source table and
/// the target table. "Not matched" records exist only in the source table
/// (e.g. these are new data) "Not matched by source" records exist only
/// in the target table (this is old data)
///
/// The builder returned by this method can be used to customize what
/// should happen for each category of data.
///
/// Please note that the data may appear to be reordered as part of this
/// operation. This is because updated rows will be deleted from the
/// dataset and then reinserted at the end with the new values.
///
/// # Arguments
///
/// * `on` One or more columns to join on. This is how records from the
/// source table and target table are matched. Typically this is some
/// kind of key or id column.
///
/// # Examples
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("idx_test").await.unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let new_data = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// // Perform an upsert operation
/// let mut merge_insert = tbl.merge_insert(&["id"]);
/// merge_insert.when_matched_update_all()
/// .when_not_matched_insert_all();
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
/// # });
/// ```
fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder;
/// Search the table with a given query vector. /// Search the table with a given query vector.
/// ///
/// This is a convenience method for preparing an ANN query. /// This is a convenience method for preparing an ANN query.
@@ -194,6 +304,14 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// # }); /// # });
/// ``` /// ```
fn query(&self) -> Query; fn query(&self) -> Query;
/// Optimize the on-disk data and indices for better performance.
///
/// <section class="warning">Experimental API</section>
///
/// Modeled after ``VACCUM`` in PostgreSQL.
/// Not all implementations support explicit optimization.
async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats>;
} }
/// Reference to a Table pointer. /// Reference to a Table pointer.
@@ -396,17 +514,8 @@ impl NativeTable {
self.dataset.lock().expect("lock poison").version().version self.dataset.lock().expect("lock poison").version().version
} }
/// Create a scalar index on the table async fn optimize_indices(&self, options: &OptimizeOptions) -> Result<()> {
pub async fn create_scalar_index(&self, column: &str, replace: bool) -> Result<()> { info!("LanceDB: optimizing indices: {:?}", options);
let mut dataset = self.clone_inner_dataset();
let params = ScalarIndexParams::default();
dataset
.create_index(&[column], IndexType::Scalar, None, &params, replace)
.await?;
Ok(())
}
pub async fn optimize_indices(&mut self, options: &OptimizeOptions) -> Result<()> {
let mut dataset = self.clone_inner_dataset(); let mut dataset = self.clone_inner_dataset();
dataset.optimize_indices(options).await?; dataset.optimize_indices(options).await?;
@@ -463,7 +572,7 @@ impl NativeTable {
/// ///
/// This calls into [lance::dataset::Dataset::cleanup_old_versions] and /// This calls into [lance::dataset::Dataset::cleanup_old_versions] and
/// returns the result. /// returns the result.
pub async fn cleanup_old_versions( async fn cleanup_old_versions(
&self, &self,
older_than: Duration, older_than: Duration,
delete_unverified: Option<bool>, delete_unverified: Option<bool>,
@@ -480,7 +589,7 @@ impl NativeTable {
/// for faster reads. /// for faster reads.
/// ///
/// This calls into [lance::dataset::optimize::compact_files]. /// This calls into [lance::dataset::optimize::compact_files].
pub async fn compact_files( async fn compact_files(
&self, &self,
options: CompactionOptions, options: CompactionOptions,
remap_options: Option<Arc<dyn IndexRemapperOptions>>, remap_options: Option<Arc<dyn IndexRemapperOptions>>,
@@ -555,6 +664,42 @@ impl NativeTable {
} }
} }
#[async_trait]
impl MergeInsert for NativeTable {
async fn do_merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let dataset = Arc::new(self.clone_inner_dataset());
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
if params.when_matched_update_all {
builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
} else {
builder.when_matched(lance::dataset::WhenMatched::DoNothing);
}
if params.when_not_matched_insert_all {
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
} else {
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
}
if params.when_not_matched_by_source_delete {
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
} else {
WhenNotMatchedBySource::Delete
};
builder.when_not_matched_by_source(behavior);
} else {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
}
let job = builder.try_build()?;
let new_dataset = job.execute_reader(new_data).await?;
self.reset_dataset((*new_dataset).clone());
Ok(())
}
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl Table for NativeTable { impl Table for NativeTable {
fn as_any(&self) -> &dyn std::any::Any { fn as_any(&self) -> &dyn std::any::Any {
@@ -599,6 +744,11 @@ impl Table for NativeTable {
Ok(()) Ok(())
} }
fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder {
let on = Vec::from_iter(on.iter().map(|key| key.to_string()));
MergeInsertBuilder::new(Arc::new(self.clone()), on)
}
fn create_index(&self, columns: &[&str]) -> IndexBuilder { fn create_index(&self, columns: &[&str]) -> IndexBuilder {
IndexBuilder::new(Arc::new(self.clone()), columns) IndexBuilder::new(Arc::new(self.clone()), columns)
} }
@@ -614,6 +764,52 @@ impl Table for NativeTable {
self.reset_dataset(dataset); self.reset_dataset(dataset);
Ok(()) Ok(())
} }
async fn optimize(&self, action: OptimizeAction) -> Result<OptimizeStats> {
let mut stats = OptimizeStats {
compaction: None,
prune: None,
};
match action {
OptimizeAction::All => {
stats.compaction = self
.optimize(OptimizeAction::Compact {
options: CompactionOptions::default(),
remap_options: None,
})
.await?
.compaction;
stats.prune = self
.optimize(OptimizeAction::Prune {
older_than: Duration::days(7),
delete_unverified: None,
})
.await?
.prune;
self.optimize(OptimizeAction::Index(OptimizeOptions::default()))
.await?;
}
OptimizeAction::Compact {
options,
remap_options,
} => {
stats.compaction = Some(self.compact_files(options, remap_options).await?);
}
OptimizeAction::Prune {
older_than,
delete_unverified,
} => {
stats.prune = Some(
self.cleanup_old_versions(older_than, delete_unverified)
.await?,
);
}
OptimizeAction::Index(options) => {
self.optimize_indices(&options).await?;
}
}
Ok(stats)
}
} }
#[cfg(test)] #[cfg(test)]
@@ -718,6 +914,38 @@ mod tests {
assert_eq!(table.name, "test"); assert_eq!(table.name, "test");
} }
#[tokio::test]
async fn test_merge_insert() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
// Create a dataset with i=0..10
let batches = make_test_batches_with_offset(0);
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
// Create new data with i=5..15
let new_batches = Box::new(make_test_batches_with_offset(5));
// Perform a "insert if not exists"
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_not_matched_insert_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// Only 5 rows should actually be inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
// Create new data with i=15..25 (no id matches)
let new_batches = Box::new(make_test_batches_with_offset(15));
// Perform a "bulk update" (should not affect anything)
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// No new rows should have been inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
}
#[tokio::test] #[tokio::test]
async fn test_add_overwrite() { async fn test_add_overwrite() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
@@ -1064,17 +1292,25 @@ mod tests {
assert!(wrapper.called()); assert!(wrapper.called());
} }
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { fn make_test_batches_with_offset(
offset: i32,
) -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new( RecordBatchIterator::new(
vec![RecordBatch::try_new( vec![RecordBatch::try_new(
schema.clone(), schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))], vec![Arc::new(Int32Array::from_iter_values(
offset..(offset + 10),
))],
)], )],
schema, schema,
) )
} }
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
make_test_batches_with_offset(0)
}
#[tokio::test] #[tokio::test]
async fn test_create_index() { async fn test_create_index() {
use arrow_array::RecordBatch; use arrow_array::RecordBatch;

View File

@@ -0,0 +1,95 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
use crate::Result;
#[async_trait]
pub(super) trait MergeInsert: Send + Sync {
async fn do_merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()>;
}
/// A builder used to create and run a merge insert operation
///
/// See [`super::Table::merge_insert`] for more context
pub struct MergeInsertBuilder {
table: Arc<dyn MergeInsert>,
pub(super) on: Vec<String>,
pub(super) when_matched_update_all: bool,
pub(super) when_not_matched_insert_all: bool,
pub(super) when_not_matched_by_source_delete: bool,
pub(super) when_not_matched_by_source_delete_filt: Option<String>,
}
impl MergeInsertBuilder {
pub(super) fn new(table: Arc<dyn MergeInsert>, on: Vec<String>) -> Self {
Self {
table,
on,
when_matched_update_all: false,
when_not_matched_insert_all: false,
when_not_matched_by_source_delete: false,
when_not_matched_by_source_delete_filt: None,
}
}
/// Rows that exist in both the source table (new data) and
/// the target table (old data) will be updated, replacing
/// the old row with the corresponding matching row.
///
/// If there are multiple matches then the behavior is undefined.
/// Currently this causes multiple copies of the row to be created
/// but that behavior is subject to change.
pub fn when_matched_update_all(&mut self) -> &mut Self {
self.when_matched_update_all = true;
self
}
/// Rows that exist only in the source table (new data) should
/// be inserted into the target table.
pub fn when_not_matched_insert_all(&mut self) -> &mut Self {
self.when_not_matched_insert_all = true;
self
}
/// Rows that exist only in the target table (old data) will be
/// deleted. An optional condition can be provided to limit what
/// data is deleted.
///
/// # Arguments
///
/// * `condition` - If None then all such rows will be deleted.
/// Otherwise the condition will be used as an SQL filter to
/// limit what rows are deleted.
pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
self.when_not_matched_by_source_delete = true;
self.when_not_matched_by_source_delete_filt = filter;
self
}
/// Executes the merge insert operation
///
/// Nothing is returned but the [`super::Table`] is updated
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
self.table.clone().do_merge_insert(self, new_data).await
}
}