mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
60 Commits
api-docs-f
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82936c77ef | ||
|
|
dddcddcaf9 | ||
|
|
a9727eb318 | ||
|
|
48d55bf952 | ||
|
|
d2e71c8b08 | ||
|
|
f53aace89c | ||
|
|
d982ee934a | ||
|
|
57605a2d86 | ||
|
|
738511c5f2 | ||
|
|
0b0f42537e | ||
|
|
e412194008 | ||
|
|
a9088224c5 | ||
|
|
688c57a0d8 | ||
|
|
12a98deded | ||
|
|
e4bb042918 | ||
|
|
04e1662681 | ||
|
|
ce2242e06d | ||
|
|
778339388a | ||
|
|
7f8637a0b4 | ||
|
|
09cd08222d | ||
|
|
a248d7feec | ||
|
|
cc9473a94a | ||
|
|
d77e95a4f4 | ||
|
|
62f053ac92 | ||
|
|
34e10caad2 | ||
|
|
f5726e2d0c | ||
|
|
12b4fb42fc | ||
|
|
1328cd46f1 | ||
|
|
0c940ed9f8 | ||
|
|
5f59e51583 | ||
|
|
8d0ea29f89 | ||
|
|
b9468bb980 | ||
|
|
a42df158a3 | ||
|
|
9df6905d86 | ||
|
|
3ffed89793 | ||
|
|
f150768739 | ||
|
|
b432ecf2f6 | ||
|
|
d1a7257810 | ||
|
|
5c5e23bbb9 | ||
|
|
e5796a4836 | ||
|
|
b9c5323265 | ||
|
|
e41a52863a | ||
|
|
13acc8a480 | ||
|
|
22b9eceb12 | ||
|
|
5f62302614 | ||
|
|
d84e0d1db8 | ||
|
|
ac94b2a420 | ||
|
|
b49bc113c4 | ||
|
|
77b5b1cf0e | ||
|
|
e910809de0 | ||
|
|
90b5b55126 | ||
|
|
488e4f8452 | ||
|
|
ba6f949515 | ||
|
|
3dd8522bc9 | ||
|
|
e01ef63488 | ||
|
|
a6cf24b359 | ||
|
|
9a07c9aad8 | ||
|
|
d405798952 | ||
|
|
e8a8b92b2a | ||
|
|
66362c6506 |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.4.4
|
current_version = 0.4.8
|
||||||
commit = True
|
commit = True
|
||||||
message = Bump version: {current_version} → {new_version}
|
message = Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
34
.cargo/config.toml
Normal file
34
.cargo/config.toml
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
[profile.release]
|
||||||
|
lto = "fat"
|
||||||
|
codegen-units = 1
|
||||||
|
|
||||||
|
[profile.release-with-debug]
|
||||||
|
inherits = "release"
|
||||||
|
debug = true
|
||||||
|
# Prioritize compile time over runtime performance
|
||||||
|
codegen-units = 16
|
||||||
|
lto = "thin"
|
||||||
|
|
||||||
|
[target.'cfg(all())']
|
||||||
|
rustflags = [
|
||||||
|
"-Wclippy::all",
|
||||||
|
"-Wclippy::style",
|
||||||
|
"-Wclippy::fallible_impl_from",
|
||||||
|
"-Wclippy::manual_let_else",
|
||||||
|
"-Wclippy::redundant_pub_crate",
|
||||||
|
"-Wclippy::string_add_assign",
|
||||||
|
"-Wclippy::string_add",
|
||||||
|
"-Wclippy::string_lit_as_bytes",
|
||||||
|
"-Wclippy::string_to_string",
|
||||||
|
"-Wclippy::use_self",
|
||||||
|
"-Dclippy::cargo",
|
||||||
|
"-Dclippy::dbg_macro",
|
||||||
|
# not too much we can do to avoid multiple crate versions
|
||||||
|
"-Aclippy::multiple-crate-versions",
|
||||||
|
]
|
||||||
|
|
||||||
|
[target.x86_64-unknown-linux-gnu]
|
||||||
|
rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"]
|
||||||
|
|
||||||
|
[target.aarch64-apple-darwin]
|
||||||
|
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]
|
||||||
2
.github/workflows/cargo-publish.yml
vendored
2
.github/workflows/cargo-publish.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
|||||||
# Only runs on tags that matches the make-release action
|
# Only runs on tags that matches the make-release action
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- uses: Swatinem/rust-cache@v2
|
- uses: Swatinem/rust-cache@v2
|
||||||
with:
|
with:
|
||||||
workspaces: rust
|
workspaces: rust
|
||||||
|
|||||||
9
.github/workflows/docs.yml
vendored
9
.github/workflows/docs.yml
vendored
@@ -27,9 +27,9 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
cache: "pip"
|
cache: "pip"
|
||||||
@@ -42,7 +42,7 @@ jobs:
|
|||||||
- name: Set up node
|
- name: Set up node
|
||||||
uses: actions/setup-node@v3
|
uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: ${{ matrix.node-version }}
|
node-version: 20
|
||||||
cache: 'npm'
|
cache: 'npm'
|
||||||
cache-dependency-path: node/package-lock.json
|
cache-dependency-path: node/package-lock.json
|
||||||
- uses: Swatinem/rust-cache@v2
|
- uses: Swatinem/rust-cache@v2
|
||||||
@@ -62,8 +62,9 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
npx typedoc --plugin typedoc-plugin-markdown --out ../docs/src/javascript src/index.ts
|
npx typedoc --plugin typedoc-plugin-markdown --out ../docs/src/javascript src/index.ts
|
||||||
- name: Build docs
|
- name: Build docs
|
||||||
|
working-directory: docs
|
||||||
run: |
|
run: |
|
||||||
PYTHONPATH=. mkdocs build -f docs/mkdocs.yml
|
PYTHONPATH=. mkdocs build
|
||||||
- name: Setup Pages
|
- name: Setup Pages
|
||||||
uses: actions/configure-pages@v2
|
uses: actions/configure-pages@v2
|
||||||
- name: Upload artifact
|
- name: Upload artifact
|
||||||
|
|||||||
53
.github/workflows/docs_test.yml
vendored
53
.github/workflows/docs_test.yml
vendored
@@ -18,24 +18,20 @@ on:
|
|||||||
env:
|
env:
|
||||||
# Disable full debug symbol generation to speed up CI build and keep memory down
|
# Disable full debug symbol generation to speed up CI build and keep memory down
|
||||||
# "1" means line tables only, which is useful for panic tracebacks.
|
# "1" means line tables only, which is useful for panic tracebacks.
|
||||||
RUSTFLAGS: "-C debuginfo=1"
|
RUSTFLAGS: "-C debuginfo=1 -C target-cpu=native -C target-feature=+f16c,+avx2,+fma"
|
||||||
RUST_BACKTRACE: "1"
|
RUST_BACKTRACE: "1"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-python:
|
test-python:
|
||||||
name: Test doc python code
|
name: Test doc python code
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: "ubuntu-latest"
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
python-minor-version: [ "11" ]
|
|
||||||
os: ["ubuntu-22.04"]
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: 3.${{ matrix.python-minor-version }}
|
python-version: 3.11
|
||||||
cache: "pip"
|
cache: "pip"
|
||||||
cache-dependency-path: "docs/test/requirements.txt"
|
cache-dependency-path: "docs/test/requirements.txt"
|
||||||
- name: Build Python
|
- name: Build Python
|
||||||
@@ -52,45 +48,42 @@ jobs:
|
|||||||
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
|
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
|
||||||
test-node:
|
test-node:
|
||||||
name: Test doc nodejs code
|
name: Test doc nodejs code
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: "ubuntu-latest"
|
||||||
|
timeout-minutes: 45
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
fail-fast: false
|
||||||
node-version: [ "18" ]
|
|
||||||
os: ["ubuntu-22.04"]
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Node
|
- name: Set up Node
|
||||||
uses: actions/setup-node@v3
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: ${{ matrix.node-version }}
|
node-version: 20
|
||||||
- name: Install dependecies needed for ubuntu
|
- name: Install dependecies needed for ubuntu
|
||||||
if: ${{ matrix.os == 'ubuntu-22.04' }}
|
|
||||||
run: |
|
run: |
|
||||||
sudo apt install -y protobuf-compiler libssl-dev
|
sudo apt install -y protobuf-compiler libssl-dev
|
||||||
- name: Install node dependencies
|
|
||||||
run: |
|
|
||||||
cd docs/test
|
|
||||||
npm install
|
|
||||||
- name: Rust cache
|
- name: Rust cache
|
||||||
uses: swatinem/rust-cache@v2
|
uses: swatinem/rust-cache@v2
|
||||||
- name: Install LanceDB
|
- name: Install node dependencies
|
||||||
run: |
|
run: |
|
||||||
cd docs/test/node_modules/vectordb
|
sudo swapoff -a
|
||||||
|
sudo fallocate -l 8G /swapfile
|
||||||
|
sudo chmod 600 /swapfile
|
||||||
|
sudo mkswap /swapfile
|
||||||
|
sudo swapon /swapfile
|
||||||
|
sudo swapon --show
|
||||||
|
cd node
|
||||||
npm ci
|
npm ci
|
||||||
npm run build-release
|
npm run build-release
|
||||||
npm run tsc
|
cd ../docs
|
||||||
- name: Create test files
|
npm install
|
||||||
run: |
|
|
||||||
cd docs/test
|
|
||||||
node md_testing.js
|
|
||||||
- name: Test
|
- name: Test
|
||||||
env:
|
env:
|
||||||
LANCEDB_URI: ${{ secrets.LANCEDB_URI }}
|
LANCEDB_URI: ${{ secrets.LANCEDB_URI }}
|
||||||
LANCEDB_DEV_API_KEY: ${{ secrets.LANCEDB_DEV_API_KEY }}
|
LANCEDB_DEV_API_KEY: ${{ secrets.LANCEDB_DEV_API_KEY }}
|
||||||
run: |
|
run: |
|
||||||
cd docs/test/node
|
cd docs
|
||||||
for d in *; do cd "$d"; echo "$d".js; node "$d".js; cd ..; done
|
npm t
|
||||||
|
|||||||
8
.github/workflows/make-release-commit.yml
vendored
8
.github/workflows/make-release-commit.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check out main
|
- name: Check out main
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: main
|
ref: main
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
@@ -37,10 +37,10 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
git config user.name 'Lance Release'
|
git config user.name 'Lance Release'
|
||||||
git config user.email 'lance-dev@lancedb.com'
|
git config user.email 'lance-dev@lancedb.com'
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python 3.11
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.11"
|
||||||
- name: Bump version, create tag and commit
|
- name: Bump version, create tag and commit
|
||||||
run: |
|
run: |
|
||||||
pip install bump2version
|
pip install bump2version
|
||||||
|
|||||||
8
.github/workflows/node.yml
vendored
8
.github/workflows/node.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: node
|
working-directory: node
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
@@ -57,7 +57,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: node
|
working-directory: node
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
@@ -89,7 +89,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: node
|
working-directory: node
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
@@ -128,7 +128,7 @@ jobs:
|
|||||||
# this one is for dynamodb
|
# this one is for dynamodb
|
||||||
DYNAMODB_ENDPOINT: http://localhost:4566
|
DYNAMODB_ENDPOINT: http://localhost:4566
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
|
|||||||
8
.github/workflows/nodejs.yml
vendored
8
.github/workflows/nodejs.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: nodejs
|
working-directory: nodejs
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
@@ -61,7 +61,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: nodejs
|
working-directory: nodejs
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
@@ -84,13 +84,13 @@ jobs:
|
|||||||
run: npm run test
|
run: npm run test
|
||||||
macos:
|
macos:
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
runs-on: "macos-13"
|
runs-on: "macos-14"
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
working-directory: nodejs
|
working-directory: nodejs
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
|
|||||||
12
.github/workflows/npm-publish.yml
vendored
12
.github/workflows/npm-publish.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
working-directory: node
|
working-directory: node
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 20
|
node-version: 20
|
||||||
@@ -45,13 +45,13 @@ jobs:
|
|||||||
runner: macos-13
|
runner: macos-13
|
||||||
- arch: aarch64-apple-darwin
|
- arch: aarch64-apple-darwin
|
||||||
# xlarge is implicitly arm64.
|
# xlarge is implicitly arm64.
|
||||||
runner: macos-13-xlarge
|
runner: macos-14
|
||||||
runs-on: ${{ matrix.config.runner }}
|
runs-on: ${{ matrix.config.runner }}
|
||||||
# Only runs on tags that matches the make-release action
|
# Only runs on tags that matches the make-release action
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Install system dependencies
|
- name: Install system dependencies
|
||||||
run: brew install protobuf
|
run: brew install protobuf
|
||||||
- name: Install npm dependencies
|
- name: Install npm dependencies
|
||||||
@@ -83,7 +83,7 @@ jobs:
|
|||||||
runner: buildjet-4vcpu-ubuntu-2204-arm
|
runner: buildjet-4vcpu-ubuntu-2204-arm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Build Linux Artifacts
|
- name: Build Linux Artifacts
|
||||||
run: |
|
run: |
|
||||||
bash ci/build_linux_artifacts.sh ${{ matrix.config.arch }}
|
bash ci/build_linux_artifacts.sh ${{ matrix.config.arch }}
|
||||||
@@ -104,7 +104,7 @@ jobs:
|
|||||||
target: [x86_64-pc-windows-msvc]
|
target: [x86_64-pc-windows-msvc]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Install Protoc v21.12
|
- name: Install Protoc v21.12
|
||||||
working-directory: C:\
|
working-directory: C:\
|
||||||
run: |
|
run: |
|
||||||
@@ -154,7 +154,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: main
|
ref: main
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|||||||
4
.github/workflows/pypi-publish.yml
vendored
4
.github/workflows/pypi-publish.yml
vendored
@@ -14,9 +14,9 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: python
|
working-directory: python
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.8"
|
python-version: "3.8"
|
||||||
- name: Build distribution
|
- name: Build distribution
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check out main
|
- name: Check out main
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: main
|
ref: main
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
@@ -37,10 +37,10 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
git config user.name 'Lance Release'
|
git config user.name 'Lance Release'
|
||||||
git config user.email 'lance-dev@lancedb.com'
|
git config user.email 'lance-dev@lancedb.com'
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.11"
|
||||||
- name: Bump version, create tag and commit
|
- name: Bump version, create tag and commit
|
||||||
working-directory: python
|
working-directory: python
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
16
.github/workflows/python.yml
vendored
16
.github/workflows/python.yml
vendored
@@ -18,19 +18,19 @@ jobs:
|
|||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
python-minor-version: [ "8", "9", "10", "11" ]
|
python-minor-version: [ "8", "11" ]
|
||||||
runs-on: "ubuntu-22.04"
|
runs-on: "ubuntu-22.04"
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
working-directory: python
|
working-directory: python
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: 3.${{ matrix.python-minor-version }}
|
python-version: 3.${{ matrix.python-minor-version }}
|
||||||
- name: Install lancedb
|
- name: Install lancedb
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
- name: x86 Mac
|
- name: x86 Mac
|
||||||
runner: macos-13
|
runner: macos-13
|
||||||
- name: Arm Mac
|
- name: Arm Mac
|
||||||
runner: macos-13-xlarge
|
runner: macos-14
|
||||||
- name: x86 Windows
|
- name: x86 Windows
|
||||||
runner: windows-latest
|
runner: windows-latest
|
||||||
runs-on: "${{ matrix.config.runner }}"
|
runs-on: "${{ matrix.config.runner }}"
|
||||||
@@ -64,12 +64,12 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: python
|
working-directory: python
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
- name: Install lancedb
|
- name: Install lancedb
|
||||||
@@ -87,12 +87,12 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: python
|
working-directory: python
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: 3.9
|
python-version: 3.9
|
||||||
- name: Install lancedb
|
- name: Install lancedb
|
||||||
|
|||||||
12
.github/workflows/rust.yml
vendored
12
.github/workflows/rust.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: rust
|
working-directory: rust
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
@@ -55,7 +55,7 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: rust
|
working-directory: rust
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
@@ -70,18 +70,20 @@ jobs:
|
|||||||
run: cargo build --all-features
|
run: cargo build --all-features
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: cargo test --all-features
|
run: cargo test --all-features
|
||||||
|
- name: Run examples
|
||||||
|
run: cargo run --example simple
|
||||||
macos:
|
macos:
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
mac-runner: [ "macos-13", "macos-13-xlarge" ]
|
mac-runner: [ "macos-13", "macos-14" ]
|
||||||
runs-on: "${{ matrix.mac-runner }}"
|
runs-on: "${{ matrix.mac-runner }}"
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
working-directory: rust
|
working-directory: rust
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
@@ -99,7 +101,7 @@ jobs:
|
|||||||
windows:
|
windows:
|
||||||
runs-on: windows-2022
|
runs-on: windows-2022
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- uses: Swatinem/rust-cache@v2
|
- uses: Swatinem/rust-cache@v2
|
||||||
with:
|
with:
|
||||||
workspaces: rust
|
workspaces: rust
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
ref: main
|
ref: main
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|||||||
29
Cargo.toml
29
Cargo.toml
@@ -6,24 +6,27 @@ resolver = "2"
|
|||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Lance Devs <dev@lancedb.com>"]
|
authors = ["LanceDB Devs <dev@lancedb.com>"]
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
repository = "https://github.com/lancedb/lancedb"
|
repository = "https://github.com/lancedb/lancedb"
|
||||||
|
description = "Serverless, low-latency vector database for AI applications"
|
||||||
|
keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||||
|
categories = ["database-implementations"]
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { "version" = "=0.9.9", "features" = ["dynamodb"] }
|
lance = { "version" = "=0.9.15", "features" = ["dynamodb"] }
|
||||||
lance-index = { "version" = "=0.9.9" }
|
lance-index = { "version" = "=0.9.15" }
|
||||||
lance-linalg = { "version" = "=0.9.9" }
|
lance-linalg = { "version" = "=0.9.15" }
|
||||||
lance-testing = { "version" = "=0.9.9" }
|
lance-testing = { "version" = "=0.9.15" }
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "49.0.0", optional = false }
|
arrow = { version = "50.0", optional = false }
|
||||||
arrow-array = "49.0"
|
arrow-array = "50.0"
|
||||||
arrow-data = "49.0"
|
arrow-data = "50.0"
|
||||||
arrow-ipc = "49.0"
|
arrow-ipc = "50.0"
|
||||||
arrow-ord = "49.0"
|
arrow-ord = "50.0"
|
||||||
arrow-schema = "49.0"
|
arrow-schema = "50.0"
|
||||||
arrow-arith = "49.0"
|
arrow-arith = "50.0"
|
||||||
arrow-cast = "49.0"
|
arrow-cast = "50.0"
|
||||||
async-trait = "0"
|
async-trait = "0"
|
||||||
chrono = "0.4.23"
|
chrono = "0.4.23"
|
||||||
half = { "version" = "=2.3.1", default-features = false, features = [
|
half = { "version" = "=2.3.1", default-features = false, features = [
|
||||||
|
|||||||
13
README.md
13
README.md
@@ -51,12 +51,19 @@ npm install vectordb
|
|||||||
const lancedb = require('vectordb');
|
const lancedb = require('vectordb');
|
||||||
const db = await lancedb.connect('data/sample-lancedb');
|
const db = await lancedb.connect('data/sample-lancedb');
|
||||||
|
|
||||||
const table = await db.createTable('vectors',
|
const table = await db.createTable({
|
||||||
[{ id: 1, vector: [0.1, 0.2], item: "foo", price: 10 },
|
name: 'vectors',
|
||||||
{ id: 2, vector: [1.1, 1.2], item: "bar", price: 50 }])
|
data: [
|
||||||
|
{ id: 1, vector: [0.1, 0.2], item: "foo", price: 10 },
|
||||||
|
{ id: 2, vector: [1.1, 1.2], item: "bar", price: 50 }
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
const query = table.search([0.1, 0.3]).limit(2);
|
const query = table.search([0.1, 0.3]).limit(2);
|
||||||
const results = await query.execute();
|
const results = await query.execute();
|
||||||
|
|
||||||
|
// You can also search for rows by specific criteria without involving a vector search.
|
||||||
|
const rowsByCriteria = await table.search(undefined).where("price >= 10").execute();
|
||||||
```
|
```
|
||||||
|
|
||||||
**Python**
|
**Python**
|
||||||
|
|||||||
@@ -33,3 +33,12 @@ You can run a local server to test the docs prior to deployment by navigating to
|
|||||||
cd docs
|
cd docs
|
||||||
mkdocs serve
|
mkdocs serve
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Run doctest for typescript example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd lancedb/docs
|
||||||
|
npm i
|
||||||
|
npm run build
|
||||||
|
npm run all
|
||||||
|
```
|
||||||
|
|||||||
@@ -67,7 +67,9 @@ markdown_extensions:
|
|||||||
line_spans: __span
|
line_spans: __span
|
||||||
pygments_lang_class: true
|
pygments_lang_class: true
|
||||||
- pymdownx.inlinehilite
|
- pymdownx.inlinehilite
|
||||||
- pymdownx.snippets
|
- pymdownx.snippets:
|
||||||
|
base_path: ..
|
||||||
|
dedent_subsections: true
|
||||||
- pymdownx.superfences
|
- pymdownx.superfences
|
||||||
- pymdownx.tabbed:
|
- pymdownx.tabbed:
|
||||||
alternate_style: true
|
alternate_style: true
|
||||||
@@ -88,6 +90,7 @@ nav:
|
|||||||
- Building an ANN index: ann_indexes.md
|
- Building an ANN index: ann_indexes.md
|
||||||
- Vector Search: search.md
|
- Vector Search: search.md
|
||||||
- Full-text search: fts.md
|
- Full-text search: fts.md
|
||||||
|
- Hybrid search: hybrid_search.md
|
||||||
- Filtering: sql.md
|
- Filtering: sql.md
|
||||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||||
- Configuring Storage: guides/storage.md
|
- Configuring Storage: guides/storage.md
|
||||||
@@ -130,6 +133,7 @@ nav:
|
|||||||
- ⚙️ API reference:
|
- ⚙️ API reference:
|
||||||
- 🐍 Python: python/python.md
|
- 🐍 Python: python/python.md
|
||||||
- 👾 JavaScript: javascript/modules.md
|
- 👾 JavaScript: javascript/modules.md
|
||||||
|
- 🦀 Rust: https://docs.rs/vectordb/latest/vectordb/
|
||||||
- ☁️ LanceDB Cloud:
|
- ☁️ LanceDB Cloud:
|
||||||
- Overview: cloud/index.md
|
- Overview: cloud/index.md
|
||||||
- API reference:
|
- API reference:
|
||||||
@@ -148,6 +152,7 @@ nav:
|
|||||||
- Building an ANN index: ann_indexes.md
|
- Building an ANN index: ann_indexes.md
|
||||||
- Vector Search: search.md
|
- Vector Search: search.md
|
||||||
- Full-text search: fts.md
|
- Full-text search: fts.md
|
||||||
|
- Hybrid search: hybrid_search.md
|
||||||
- Filtering: sql.md
|
- Filtering: sql.md
|
||||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||||
- Configuring Storage: guides/storage.md
|
- Configuring Storage: guides/storage.md
|
||||||
@@ -195,6 +200,9 @@ extra_css:
|
|||||||
- styles/global.css
|
- styles/global.css
|
||||||
- styles/extra.css
|
- styles/extra.css
|
||||||
|
|
||||||
|
extra_javascript:
|
||||||
|
- "extra_js/init_ask_ai_widget.js"
|
||||||
|
|
||||||
extra:
|
extra:
|
||||||
analytics:
|
analytics:
|
||||||
provider: google
|
provider: google
|
||||||
|
|||||||
132
docs/package-lock.json
generated
Normal file
132
docs/package-lock.json
generated
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
{
|
||||||
|
"name": "lancedb-docs-test",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"lockfileVersion": 3,
|
||||||
|
"requires": true,
|
||||||
|
"packages": {
|
||||||
|
"": {
|
||||||
|
"name": "lancedb-docs-test",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"license": "Apache 2",
|
||||||
|
"dependencies": {
|
||||||
|
"apache-arrow": "file:../node/node_modules/apache-arrow",
|
||||||
|
"vectordb": "file:../node"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/node": "^20.11.8",
|
||||||
|
"typescript": "^5.3.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"../node": {
|
||||||
|
"name": "vectordb",
|
||||||
|
"version": "0.4.6",
|
||||||
|
"cpu": [
|
||||||
|
"x64",
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"license": "Apache-2.0",
|
||||||
|
"os": [
|
||||||
|
"darwin",
|
||||||
|
"linux",
|
||||||
|
"win32"
|
||||||
|
],
|
||||||
|
"dependencies": {
|
||||||
|
"@apache-arrow/ts": "^14.0.2",
|
||||||
|
"@neon-rs/load": "^0.0.74",
|
||||||
|
"apache-arrow": "^14.0.2",
|
||||||
|
"axios": "^1.4.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@neon-rs/cli": "^0.0.160",
|
||||||
|
"@types/chai": "^4.3.4",
|
||||||
|
"@types/chai-as-promised": "^7.1.5",
|
||||||
|
"@types/mocha": "^10.0.1",
|
||||||
|
"@types/node": "^18.16.2",
|
||||||
|
"@types/sinon": "^10.0.15",
|
||||||
|
"@types/temp": "^0.9.1",
|
||||||
|
"@types/uuid": "^9.0.3",
|
||||||
|
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
||||||
|
"cargo-cp-artifact": "^0.1",
|
||||||
|
"chai": "^4.3.7",
|
||||||
|
"chai-as-promised": "^7.1.1",
|
||||||
|
"eslint": "^8.39.0",
|
||||||
|
"eslint-config-standard-with-typescript": "^34.0.1",
|
||||||
|
"eslint-plugin-import": "^2.26.0",
|
||||||
|
"eslint-plugin-n": "^15.7.0",
|
||||||
|
"eslint-plugin-promise": "^6.1.1",
|
||||||
|
"mocha": "^10.2.0",
|
||||||
|
"openai": "^4.24.1",
|
||||||
|
"sinon": "^15.1.0",
|
||||||
|
"temp": "^0.9.4",
|
||||||
|
"ts-node": "^10.9.1",
|
||||||
|
"ts-node-dev": "^2.0.0",
|
||||||
|
"typedoc": "^0.24.7",
|
||||||
|
"typedoc-plugin-markdown": "^3.15.3",
|
||||||
|
"typescript": "*",
|
||||||
|
"uuid": "^9.0.0"
|
||||||
|
},
|
||||||
|
"optionalDependencies": {
|
||||||
|
"@lancedb/vectordb-darwin-arm64": "0.4.6",
|
||||||
|
"@lancedb/vectordb-darwin-x64": "0.4.6",
|
||||||
|
"@lancedb/vectordb-linux-arm64-gnu": "0.4.6",
|
||||||
|
"@lancedb/vectordb-linux-x64-gnu": "0.4.6",
|
||||||
|
"@lancedb/vectordb-win32-x64-msvc": "0.4.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"../node/node_modules/apache-arrow": {
|
||||||
|
"version": "14.0.2",
|
||||||
|
"license": "Apache-2.0",
|
||||||
|
"dependencies": {
|
||||||
|
"@types/command-line-args": "5.2.0",
|
||||||
|
"@types/command-line-usage": "5.0.2",
|
||||||
|
"@types/node": "20.3.0",
|
||||||
|
"@types/pad-left": "2.1.1",
|
||||||
|
"command-line-args": "5.2.1",
|
||||||
|
"command-line-usage": "7.0.1",
|
||||||
|
"flatbuffers": "23.5.26",
|
||||||
|
"json-bignum": "^0.0.3",
|
||||||
|
"pad-left": "^2.1.0",
|
||||||
|
"tslib": "^2.5.3"
|
||||||
|
},
|
||||||
|
"bin": {
|
||||||
|
"arrow2csv": "bin/arrow2csv.js"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@types/node": {
|
||||||
|
"version": "20.11.8",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.11.8.tgz",
|
||||||
|
"integrity": "sha512-i7omyekpPTNdv4Jb/Rgqg0RU8YqLcNsI12quKSDkRXNfx7Wxdm6HhK1awT3xTgEkgxPn3bvnSpiEAc7a7Lpyow==",
|
||||||
|
"dev": true,
|
||||||
|
"dependencies": {
|
||||||
|
"undici-types": "~5.26.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/apache-arrow": {
|
||||||
|
"resolved": "../node/node_modules/apache-arrow",
|
||||||
|
"link": true
|
||||||
|
},
|
||||||
|
"node_modules/typescript": {
|
||||||
|
"version": "5.3.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.3.3.tgz",
|
||||||
|
"integrity": "sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==",
|
||||||
|
"dev": true,
|
||||||
|
"bin": {
|
||||||
|
"tsc": "bin/tsc",
|
||||||
|
"tsserver": "bin/tsserver"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=14.17"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/undici-types": {
|
||||||
|
"version": "5.26.5",
|
||||||
|
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
|
||||||
|
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==",
|
||||||
|
"dev": true
|
||||||
|
},
|
||||||
|
"node_modules/vectordb": {
|
||||||
|
"resolved": "../node",
|
||||||
|
"link": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
20
docs/package.json
Normal file
20
docs/package.json
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
{
|
||||||
|
"name": "lancedb-docs-test",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "auto-generated tests from doc",
|
||||||
|
"author": "dev@lancedb.com",
|
||||||
|
"license": "Apache 2",
|
||||||
|
"dependencies": {
|
||||||
|
"apache-arrow": "file:../node/node_modules/apache-arrow",
|
||||||
|
"vectordb": "file:../node"
|
||||||
|
},
|
||||||
|
"scripts": {
|
||||||
|
"build": "tsc -b && cd ../node && npm run build-release",
|
||||||
|
"example": "npm run build && node",
|
||||||
|
"test": "npm run build && ls dist/*.js | xargs -n 1 node"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/node": "^20.11.8",
|
||||||
|
"typescript": "^5.3.3"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ for brute-force scanning of the entire vector space.
|
|||||||
A vector index is faster but less accurate than exhaustive search (kNN or flat search).
|
A vector index is faster but less accurate than exhaustive search (kNN or flat search).
|
||||||
LanceDB provides many parameters to fine-tune the index's size, the speed of queries, and the accuracy of results.
|
LanceDB provides many parameters to fine-tune the index's size, the speed of queries, and the accuracy of results.
|
||||||
|
|
||||||
Currently, LanceDB does *not* automatically create the ANN index.
|
Currently, LanceDB does _not_ automatically create the ANN index.
|
||||||
LanceDB has optimized code for kNN as well. For many use-cases, datasets under 100K vectors won't require index creation at all.
|
LanceDB has optimized code for kNN as well. For many use-cases, datasets under 100K vectors won't require index creation at all.
|
||||||
If you can live with <100ms latency, skipping index creation is a simpler workflow while guaranteeing 100% recall.
|
If you can live with <100ms latency, skipping index creation is a simpler workflow while guaranteeing 100% recall.
|
||||||
|
|
||||||
@@ -17,16 +17,17 @@ In the future we will look to automatically create and configure the ANN index a
|
|||||||
|
|
||||||
Lance can support multiple index types, the most widely used one is `IVF_PQ`.
|
Lance can support multiple index types, the most widely used one is `IVF_PQ`.
|
||||||
|
|
||||||
* `IVF_PQ`: use **Inverted File Index (IVF)** to first divide the dataset into `N` partitions,
|
- `IVF_PQ`: use **Inverted File Index (IVF)** to first divide the dataset into `N` partitions,
|
||||||
and then use **Product Quantization** to compress vectors in each partition.
|
and then use **Product Quantization** to compress vectors in each partition.
|
||||||
* `DiskANN` (**Experimental**): organize the vector as a on-disk graph, where the vertices approximately
|
- `DiskANN` (**Experimental**): organize the vector as a on-disk graph, where the vertices approximately
|
||||||
represent the nearest neighbors of each vector.
|
represent the nearest neighbors of each vector.
|
||||||
|
|
||||||
## Creating an IVF_PQ Index
|
## Creating an IVF_PQ Index
|
||||||
|
|
||||||
Lance supports `IVF_PQ` index type by default.
|
Lance supports `IVF_PQ` index type by default.
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
Creating indexes is done via the [create_index](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.create_index) method.
|
Creating indexes is done via the [create_index](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.create_index) method.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -46,25 +47,20 @@ Lance supports `IVF_PQ` index type by default.
|
|||||||
tbl.create_index(num_partitions=256, num_sub_vectors=96)
|
tbl.create_index(num_partitions=256, num_sub_vectors=96)
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Typescript"
|
||||||
```javascript
|
|
||||||
const vectordb = require('vectordb')
|
|
||||||
const db = await vectordb.connect('data/sample-lancedb')
|
|
||||||
|
|
||||||
let data = []
|
```typescript
|
||||||
for (let i = 0; i < 10_000; i++) {
|
--8<--- "docs/src/ann_indexes.ts:import"
|
||||||
data.push({vector: Array(1536).fill(i), id: `${i}`, content: "", longId: `${i}`},)
|
|
||||||
}
|
--8<-- "docs/src/ann_indexes.ts:ingest"
|
||||||
const table = await db.createTable('my_vectors', data)
|
|
||||||
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 256, num_sub_vectors: 96 })
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- **metric** (default: "L2"): The distance metric to use. By default it uses euclidean distance "`L2`".
|
- **metric** (default: "L2"): The distance metric to use. By default it uses euclidean distance "`L2`".
|
||||||
We also support "cosine" and "dot" distance as well.
|
We also support "cosine" and "dot" distance as well.
|
||||||
- **num_partitions** (default: 256): The number of partitions of the index.
|
- **num_partitions** (default: 256): The number of partitions of the index.
|
||||||
- **num_sub_vectors** (default: 96): The number of sub-vectors (M) that will be created during Product Quantization (PQ).
|
- **num_sub_vectors** (default: 96): The number of sub-vectors (M) that will be created during Product Quantization (PQ).
|
||||||
For D dimensional vector, it will be divided into `M` of `D/M` sub-vectors, each of which is presented by
|
For D dimensional vector, it will be divided into `M` of `D/M` sub-vectors, each of which is presented by
|
||||||
a single PQ code.
|
a single PQ code.
|
||||||
|
|
||||||
<figure markdown>
|
<figure markdown>
|
||||||

|

|
||||||
@@ -78,7 +74,7 @@ Using GPU for index creation requires [PyTorch>2.0](https://pytorch.org/) being
|
|||||||
|
|
||||||
You can specify the GPU device to train IVF partitions via
|
You can specify the GPU device to train IVF partitions via
|
||||||
|
|
||||||
- **accelerator**: Specify to ``cuda`` or ``mps`` (on Apple Silicon) to enable GPU training.
|
- **accelerator**: Specify to `cuda` or `mps` (on Apple Silicon) to enable GPU training.
|
||||||
|
|
||||||
=== "Linux"
|
=== "Linux"
|
||||||
|
|
||||||
@@ -106,10 +102,9 @@ You can specify the GPU device to train IVF partitions via
|
|||||||
|
|
||||||
Trouble shootings:
|
Trouble shootings:
|
||||||
|
|
||||||
If you see ``AssertionError: Torch not compiled with CUDA enabled``, you need to [install
|
If you see `AssertionError: Torch not compiled with CUDA enabled`, you need to [install
|
||||||
PyTorch with CUDA support](https://pytorch.org/get-started/locally/).
|
PyTorch with CUDA support](https://pytorch.org/get-started/locally/).
|
||||||
|
|
||||||
|
|
||||||
## Querying an ANN Index
|
## Querying an ANN Index
|
||||||
|
|
||||||
Querying vector indexes is done via the [search](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.search) function.
|
Querying vector indexes is done via the [search](https://lancedb.github.io/lancedb/python/#lancedb.table.LanceTable.search) function.
|
||||||
@@ -127,6 +122,7 @@ There are a couple of parameters that can be used to fine-tune the search:
|
|||||||
Note: refine_factor is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored.
|
Note: refine_factor is only applicable if an ANN index is present. If specified on a table without an ANN index, it is ignored.
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tbl.search(np.random.random((1536))) \
|
tbl.search(np.random.random((1536))) \
|
||||||
.limit(2) \
|
.limit(2) \
|
||||||
@@ -134,41 +130,35 @@ There are a couple of parameters that can be used to fine-tune the search:
|
|||||||
.refine_factor(10) \
|
.refine_factor(10) \
|
||||||
.to_pandas()
|
.to_pandas()
|
||||||
```
|
```
|
||||||
```
|
|
||||||
|
```text
|
||||||
vector item _distance
|
vector item _distance
|
||||||
0 [0.44949695, 0.8444449, 0.06281311, 0.23338133... item 1141 103.575333
|
0 [0.44949695, 0.8444449, 0.06281311, 0.23338133... item 1141 103.575333
|
||||||
1 [0.48587373, 0.269207, 0.15095535, 0.65531915,... item 3953 108.393867
|
1 [0.48587373, 0.269207, 0.15095535, 0.65531915,... item 3953 108.393867
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Typescript"
|
||||||
```javascript
|
|
||||||
const results_1 = await table
|
```typescript
|
||||||
.search(Array(1536).fill(1.2))
|
--8<-- "docs/src/ann_indexes.ts:search1"
|
||||||
.limit(2)
|
|
||||||
.nprobes(20)
|
|
||||||
.refineFactor(10)
|
|
||||||
.execute()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The search will return the data requested in addition to the distance of each item.
|
The search will return the data requested in addition to the distance of each item.
|
||||||
|
|
||||||
|
|
||||||
### Filtering (where clause)
|
### Filtering (where clause)
|
||||||
|
|
||||||
You can further filter the elements returned by a search using a where clause.
|
You can further filter the elements returned by a search using a where clause.
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tbl.search(np.random.random((1536))).where("item != 'item 1141'").to_pandas()
|
tbl.search(np.random.random((1536))).where("item != 'item 1141'").to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Typescript"
|
||||||
|
|
||||||
```javascript
|
```javascript
|
||||||
const results_2 = await table
|
--8<-- "docs/src/ann_indexes.ts:search2"
|
||||||
.search(Array(1536).fill(1.2))
|
|
||||||
.where("id != '1141'")
|
|
||||||
.limit(2)
|
|
||||||
.execute()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Projections (select clause)
|
### Projections (select clause)
|
||||||
@@ -176,23 +166,23 @@ You can further filter the elements returned by a search using a where clause.
|
|||||||
You can select the columns returned by the query using a select clause.
|
You can select the columns returned by the query using a select clause.
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tbl.search(np.random.random((1536))).select(["vector"]).to_pandas()
|
tbl.search(np.random.random((1536))).select(["vector"]).to_pandas()
|
||||||
```
|
```
|
||||||
```
|
|
||||||
vector _distance
|
|
||||||
|
```text
|
||||||
|
vector _distance
|
||||||
0 [0.30928212, 0.022668175, 0.1756372, 0.4911822... 93.971092
|
0 [0.30928212, 0.022668175, 0.1756372, 0.4911822... 93.971092
|
||||||
1 [0.2525465, 0.01723831, 0.261568, 0.002007689,... 95.173485
|
1 [0.2525465, 0.01723831, 0.261568, 0.002007689,... 95.173485
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Typescript"
|
||||||
```javascript
|
|
||||||
const results_3 = await table
|
```typescript
|
||||||
.search(Array(1536).fill(1.2))
|
--8<-- "docs/src/ann_indexes.ts:search3"
|
||||||
.select(["id"])
|
|
||||||
.limit(2)
|
|
||||||
.execute()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## FAQ
|
## FAQ
|
||||||
|
|||||||
53
docs/src/ann_indexes.ts
Normal file
53
docs/src/ann_indexes.ts
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
// --8<-- [start:import]
|
||||||
|
import * as vectordb from "vectordb";
|
||||||
|
// --8<-- [end:import]
|
||||||
|
|
||||||
|
(async () => {
|
||||||
|
// --8<-- [start:ingest]
|
||||||
|
const db = await vectordb.connect("data/sample-lancedb");
|
||||||
|
|
||||||
|
let data = [];
|
||||||
|
for (let i = 0; i < 10_000; i++) {
|
||||||
|
data.push({
|
||||||
|
vector: Array(1536).fill(i),
|
||||||
|
id: `${i}`,
|
||||||
|
content: "",
|
||||||
|
longId: `${i}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
const table = await db.createTable("my_vectors", data);
|
||||||
|
await table.createIndex({
|
||||||
|
type: "ivf_pq",
|
||||||
|
column: "vector",
|
||||||
|
num_partitions: 16,
|
||||||
|
num_sub_vectors: 48,
|
||||||
|
});
|
||||||
|
// --8<-- [end:ingest]
|
||||||
|
|
||||||
|
// --8<-- [start:search1]
|
||||||
|
const results_1 = await table
|
||||||
|
.search(Array(1536).fill(1.2))
|
||||||
|
.limit(2)
|
||||||
|
.nprobes(20)
|
||||||
|
.refineFactor(10)
|
||||||
|
.execute();
|
||||||
|
// --8<-- [end:search1]
|
||||||
|
|
||||||
|
// --8<-- [start:search2]
|
||||||
|
const results_2 = await table
|
||||||
|
.search(Array(1536).fill(1.2))
|
||||||
|
.where("id != '1141'")
|
||||||
|
.limit(2)
|
||||||
|
.execute();
|
||||||
|
// --8<-- [end:search2]
|
||||||
|
|
||||||
|
// --8<-- [start:search3]
|
||||||
|
const results_3 = await table
|
||||||
|
.search(Array(1536).fill(1.2))
|
||||||
|
.select(["id"])
|
||||||
|
.limit(2)
|
||||||
|
.execute();
|
||||||
|
// --8<-- [end:search3]
|
||||||
|
|
||||||
|
console.log("Ann indexes: done");
|
||||||
|
})();
|
||||||
Binary file not shown.
|
Before Width: | Height: | Size: 266 KiB After Width: | Height: | Size: 107 KiB |
@@ -11,43 +11,78 @@
|
|||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
pip install lancedb
|
pip install lancedb
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Typescript"
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
npm install vectordb
|
npm install vectordb
|
||||||
```
|
```
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
!!! warning "Rust SDK is experimental, might introduce breaking changes in the near future"
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cargo add vectordb
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "To use the vectordb create, you first need to install protobuf."
|
||||||
|
|
||||||
|
=== "macOS"
|
||||||
|
|
||||||
|
```shell
|
||||||
|
brew install protobuf
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Ubuntu/Debian"
|
||||||
|
|
||||||
|
```shell
|
||||||
|
sudo apt install -y protobuf-compiler libssl-dev
|
||||||
|
```
|
||||||
|
|
||||||
|
!!! info "Please also make sure you're using the same version of Arrow as in the [vectordb crate](https://github.com/lancedb/lancedb/blob/main/Cargo.toml)"
|
||||||
|
|
||||||
## How to connect to a database
|
## How to connect to a database
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import lancedb
|
import lancedb
|
||||||
uri = "data/sample-lancedb"
|
uri = "data/sample-lancedb"
|
||||||
db = lancedb.connect(uri)
|
db = lancedb.connect(uri)
|
||||||
```
|
```
|
||||||
|
|
||||||
LanceDB will create the directory if it doesn't exist (including parent directories).
|
=== "Typescript"
|
||||||
|
|
||||||
If you need a reminder of the uri, use the `db.uri` property.
|
```typescript
|
||||||
|
--8<-- "docs/src/basic_legacy.ts:import"
|
||||||
|
|
||||||
=== "Javascript"
|
--8<-- "docs/src/basic_legacy.ts:open_db"
|
||||||
```javascript
|
```
|
||||||
const lancedb = require("vectordb");
|
|
||||||
|
|
||||||
const uri = "data/sample-lancedb";
|
=== "Rust"
|
||||||
const db = await lancedb.connect(uri);
|
|
||||||
```
|
|
||||||
|
|
||||||
LanceDB will create the directory if it doesn't exist (including parent directories).
|
```rust
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:connect"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
If you need a reminder of the uri, you can call `db.uri()`.
|
!!! info "See [examples/simple.rs](https://github.com/lancedb/lancedb/tree/main/rust/vectordb/examples/simple.rs) for a full working example."
|
||||||
|
|
||||||
|
LanceDB will create the directory if it doesn't exist (including parent directories).
|
||||||
|
|
||||||
|
If you need a reminder of the uri, you can call `db.uri()`.
|
||||||
|
|
||||||
## How to create a table
|
## How to create a table
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tbl = db.create_table("my_table",
|
tbl = db.create_table("my_table",
|
||||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||||
@@ -59,6 +94,7 @@
|
|||||||
to the `create_table` method.
|
to the `create_table` method.
|
||||||
|
|
||||||
You can also pass in a pandas DataFrame directly:
|
You can also pass in a pandas DataFrame directly:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
df = pd.DataFrame([{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
df = pd.DataFrame([{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||||
@@ -66,19 +102,26 @@
|
|||||||
tbl = db.create_table("table_from_df", data=df)
|
tbl = db.create_table("table_from_df", data=df)
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Typescript"
|
||||||
```javascript
|
|
||||||
const tb = await db.createTable(
|
```typescript
|
||||||
"myTable",
|
--8<-- "docs/src/basic_legacy.ts:create_table"
|
||||||
[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
|
||||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}]
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
If the table already exists, LanceDB will raise an error by default.
|
If the table already exists, LanceDB will raise an error by default.
|
||||||
If you want to overwrite the table, you can pass in `mode="overwrite"`
|
If you want to overwrite the table, you can pass in `mode="overwrite"`
|
||||||
to the `createTable` function.
|
to the `createTable` function.
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use arrow_schema::{DataType, Schema, Field};
|
||||||
|
use arrow_array::{RecordBatch, RecordBatchIterator};
|
||||||
|
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:create_table"
|
||||||
|
```
|
||||||
|
|
||||||
|
If the table already exists, LanceDB will raise an error by default.
|
||||||
|
|
||||||
!!! info "Under the hood, LanceDB is converting the input data into an Apache Arrow table and persisting it to disk in [Lance format](https://www.github.com/lancedb/lance)."
|
!!! info "Under the hood, LanceDB is converting the input data into an Apache Arrow table and persisting it to disk in [Lance format](https://www.github.com/lancedb/lance)."
|
||||||
|
|
||||||
@@ -88,76 +131,145 @@ Sometimes you may not have the data to insert into the table at creation time.
|
|||||||
In this case, you can create an empty table and specify the schema.
|
In this case, you can create an empty table and specify the schema.
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
|
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), list_size=2))])
|
||||||
tbl = db.create_table("empty_table", schema=schema)
|
tbl = db.create_table("empty_table", schema=schema)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
=== "Typescript"
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
--8<-- "docs/src/basic_legacy.ts:create_empty_table"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
```rust
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:create_empty_table"
|
||||||
|
```
|
||||||
|
|
||||||
## How to open an existing table
|
## How to open an existing table
|
||||||
|
|
||||||
Once created, you can open a table using the following code:
|
Once created, you can open a table using the following code:
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
```python
|
|
||||||
tbl = db.open_table("my_table")
|
|
||||||
```
|
|
||||||
|
|
||||||
If you forget the name of your table, you can always get a listing of all table names:
|
```python
|
||||||
|
tbl = db.open_table("my_table")
|
||||||
|
```
|
||||||
|
|
||||||
```python
|
=== "Typescript"
|
||||||
print(db.table_names())
|
|
||||||
```
|
```typescript
|
||||||
|
const tbl = await db.openTable("myTable");
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
```rust
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:open_with_existing_file"
|
||||||
|
```
|
||||||
|
|
||||||
|
If you forget the name of your table, you can always get a listing of all table names:
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
```python
|
||||||
|
print(db.table_names())
|
||||||
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Javascript"
|
||||||
```javascript
|
|
||||||
const tbl = await db.openTable("myTable");
|
|
||||||
```
|
|
||||||
|
|
||||||
If you forget the name of your table, you can always get a listing of all table names:
|
```javascript
|
||||||
|
console.log(await db.tableNames());
|
||||||
|
```
|
||||||
|
|
||||||
```javascript
|
=== "Rust"
|
||||||
console.log(await db.tableNames());
|
|
||||||
```
|
```rust
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:list_names"
|
||||||
|
```
|
||||||
|
|
||||||
## How to add data to a table
|
## How to add data to a table
|
||||||
|
|
||||||
After a table has been created, you can always add more data to it using
|
After a table has been created, you can always add more data to it using
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
```python
|
|
||||||
|
|
||||||
# Option 1: Add a list of dicts to a table
|
```python
|
||||||
data = [{"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
|
|
||||||
{"vector": [9.5, 56.2], "item": "buzz", "price": 200.0}]
|
|
||||||
tbl.add(data)
|
|
||||||
|
|
||||||
# Option 2: Add a pandas DataFrame to a table
|
# Option 1: Add a list of dicts to a table
|
||||||
df = pd.DataFrame(data)
|
data = [{"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
|
||||||
tbl.add(data)
|
{"vector": [9.5, 56.2], "item": "buzz", "price": 200.0}]
|
||||||
```
|
tbl.add(data)
|
||||||
|
|
||||||
=== "Javascript"
|
# Option 2: Add a pandas DataFrame to a table
|
||||||
```javascript
|
df = pd.DataFrame(data)
|
||||||
await tbl.add([{vector: [1.3, 1.4], item: "fizz", price: 100.0},
|
tbl.add(data)
|
||||||
{vector: [9.5, 56.2], item: "buzz", price: 200.0}])
|
```
|
||||||
```
|
|
||||||
|
=== "Typescript"
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
--8<-- "docs/src/basic_legacy.ts:add"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
```rust
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:add"
|
||||||
|
```
|
||||||
|
|
||||||
## How to search for (approximate) nearest neighbors
|
## How to search for (approximate) nearest neighbors
|
||||||
|
|
||||||
Once you've embedded the query, you can find its nearest neighbors using the following code:
|
Once you've embedded the query, you can find its nearest neighbors using the following code:
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
```python
|
|
||||||
tbl.search([100, 100]).limit(2).to_pandas()
|
|
||||||
```
|
|
||||||
|
|
||||||
This returns a pandas DataFrame with the results.
|
```python
|
||||||
|
tbl.search([100, 100]).limit(2).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
This returns a pandas DataFrame with the results.
|
||||||
```javascript
|
|
||||||
const query = await tbl.search([100, 100]).limit(2).execute();
|
=== "Typescript"
|
||||||
```
|
|
||||||
|
```typescript
|
||||||
|
--8<-- "docs/src/basic_legacy.ts:search"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use futures::TryStreamExt;
|
||||||
|
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:search"
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, LanceDB runs a brute-force scan over dataset to find the K nearest neighbours (KNN).
|
||||||
|
For tables with more than 50K vectors, creating an ANN index is recommended to speed up search performance.
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
```py
|
||||||
|
tbl.create_index()
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Typescript"
|
||||||
|
|
||||||
|
```{.typescript .ignore}
|
||||||
|
--8<-- "docs/src/basic_legacy.ts:create_index"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
```rust
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:create_index"
|
||||||
|
```
|
||||||
|
|
||||||
|
Check [Approximate Nearest Neighbor (ANN) Indexes](/ann_indices.md) section for more details.
|
||||||
|
|
||||||
## How to delete rows from a table
|
## How to delete rows from a table
|
||||||
|
|
||||||
@@ -166,20 +278,27 @@ which rows to delete, provide a filter that matches on the metadata columns.
|
|||||||
This can delete any number of rows that match the filter.
|
This can delete any number of rows that match the filter.
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
```python
|
|
||||||
tbl.delete('item = "fizz"')
|
|
||||||
```
|
|
||||||
|
|
||||||
=== "Javascript"
|
```python
|
||||||
```javascript
|
tbl.delete('item = "fizz"')
|
||||||
await tbl.delete('item = "fizz"')
|
```
|
||||||
```
|
|
||||||
|
=== "Typescript"
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
--8<-- "docs/src/basic_legacy.ts:delete"
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
```rust
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:delete"
|
||||||
|
```
|
||||||
|
|
||||||
The deletion predicate is a SQL expression that supports the same expressions
|
The deletion predicate is a SQL expression that supports the same expressions
|
||||||
as the `where()` clause on a search. They can be as simple or complex as needed.
|
as the `where()` clause on a search. They can be as simple or complex as needed.
|
||||||
To see what expressions are supported, see the [SQL filters](sql.md) section.
|
To see what expressions are supported, see the [SQL filters](sql.md) section.
|
||||||
|
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
Read more: [lancedb.table.Table.delete][]
|
Read more: [lancedb.table.Table.delete][]
|
||||||
@@ -193,6 +312,7 @@ To see what expressions are supported, see the [SQL filters](sql.md) section.
|
|||||||
Use the `drop_table()` method on the database to remove a table.
|
Use the `drop_table()` method on the database to remove a table.
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
db.drop_table("my_table")
|
db.drop_table("my_table")
|
||||||
```
|
```
|
||||||
@@ -201,14 +321,21 @@ Use the `drop_table()` method on the database to remove a table.
|
|||||||
By default, if the table does not exist an exception is raised. To suppress this,
|
By default, if the table does not exist an exception is raised. To suppress this,
|
||||||
you can pass in `ignore_missing=True`.
|
you can pass in `ignore_missing=True`.
|
||||||
|
|
||||||
=== "JavaScript"
|
=== "Typescript"
|
||||||
```javascript
|
|
||||||
await db.dropTable('myTable')
|
```typescript
|
||||||
|
--8<-- "docs/src/basic_legacy.ts:drop_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
This permanently removes the table and is not recoverable, unlike deleting rows.
|
This permanently removes the table and is not recoverable, unlike deleting rows.
|
||||||
If the table does not exist an exception is raised.
|
If the table does not exist an exception is raised.
|
||||||
|
|
||||||
|
=== "Rust"
|
||||||
|
|
||||||
|
```rust
|
||||||
|
--8<-- "rust/vectordb/examples/simple.rs:drop_table"
|
||||||
|
```
|
||||||
|
|
||||||
!!! note "Bundling `vectordb` apps with Webpack"
|
!!! note "Bundling `vectordb` apps with Webpack"
|
||||||
|
|
||||||
If you're using the `vectordb` module in JavaScript, since LanceDB contains a prebuilt Node binary, you must configure `next.config.js` to exclude it from webpack. This is required for both using Next.js and deploying a LanceDB app on Vercel.
|
If you're using the `vectordb` module in JavaScript, since LanceDB contains a prebuilt Node binary, you must configure `next.config.js` to exclude it from webpack. This is required for both using Next.js and deploying a LanceDB app on Vercel.
|
||||||
|
|||||||
92
docs/src/basic_legacy.ts
Normal file
92
docs/src/basic_legacy.ts
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
// --8<-- [start:import]
|
||||||
|
import * as lancedb from "vectordb";
|
||||||
|
import { Schema, Field, Float32, FixedSizeList, Int32, Float16 } from "apache-arrow";
|
||||||
|
// --8<-- [end:import]
|
||||||
|
import * as fs from "fs";
|
||||||
|
import { Table as ArrowTable, Utf8 } from "apache-arrow";
|
||||||
|
|
||||||
|
const example = async () => {
|
||||||
|
fs.rmSync("data/sample-lancedb", { recursive: true, force: true });
|
||||||
|
// --8<-- [start:open_db]
|
||||||
|
const lancedb = require("vectordb");
|
||||||
|
const uri = "data/sample-lancedb";
|
||||||
|
const db = await lancedb.connect(uri);
|
||||||
|
// --8<-- [end:open_db]
|
||||||
|
|
||||||
|
// --8<-- [start:create_table]
|
||||||
|
const tbl = await db.createTable(
|
||||||
|
"myTable",
|
||||||
|
[
|
||||||
|
{ vector: [3.1, 4.1], item: "foo", price: 10.0 },
|
||||||
|
{ vector: [5.9, 26.5], item: "bar", price: 20.0 },
|
||||||
|
],
|
||||||
|
{ writeMode: lancedb.WriteMode.Overwrite }
|
||||||
|
);
|
||||||
|
// --8<-- [end:create_table]
|
||||||
|
|
||||||
|
// --8<-- [start:add]
|
||||||
|
const newData = Array.from({ length: 500 }, (_, i) => ({
|
||||||
|
vector: [i, i + 1],
|
||||||
|
item: "fizz",
|
||||||
|
price: i * 0.1,
|
||||||
|
}));
|
||||||
|
await tbl.add(newData);
|
||||||
|
// --8<-- [end:add]
|
||||||
|
|
||||||
|
// --8<-- [start:create_index]
|
||||||
|
await tbl.createIndex({
|
||||||
|
type: "ivf_pq",
|
||||||
|
num_partitions: 2,
|
||||||
|
num_sub_vectors: 2,
|
||||||
|
});
|
||||||
|
// --8<-- [end:create_index]
|
||||||
|
|
||||||
|
// --8<-- [start:create_empty_table]
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field("id", new Int32()),
|
||||||
|
new Field("name", new Utf8()),
|
||||||
|
]);
|
||||||
|
const empty_tbl = await db.createTable({ name: "empty_table", schema });
|
||||||
|
// --8<-- [end:create_empty_table]
|
||||||
|
|
||||||
|
// --8<-- [start:create_f16_table]
|
||||||
|
const dim = 16
|
||||||
|
const total = 10
|
||||||
|
const f16_schema = new Schema([
|
||||||
|
new Field('id', new Int32()),
|
||||||
|
new Field(
|
||||||
|
'vector',
|
||||||
|
new FixedSizeList(dim, new Field('item', new Float16(), true)),
|
||||||
|
false
|
||||||
|
)
|
||||||
|
])
|
||||||
|
const data = lancedb.makeArrowTable(
|
||||||
|
Array.from(Array(total), (_, i) => ({
|
||||||
|
id: i,
|
||||||
|
vector: Array.from(Array(dim), Math.random)
|
||||||
|
})),
|
||||||
|
{ f16_schema }
|
||||||
|
)
|
||||||
|
const table = await db.createTable('f16_tbl', data)
|
||||||
|
// --8<-- [end:create_f16_table]
|
||||||
|
|
||||||
|
// --8<-- [start:search]
|
||||||
|
const query = await tbl.search([100, 100]).limit(2).execute();
|
||||||
|
// --8<-- [end:search]
|
||||||
|
console.log(query);
|
||||||
|
|
||||||
|
// --8<-- [start:delete]
|
||||||
|
await tbl.delete('item = "fizz"');
|
||||||
|
// --8<-- [end:delete]
|
||||||
|
|
||||||
|
// --8<-- [start:drop_table]
|
||||||
|
await db.dropTable("myTable");
|
||||||
|
// --8<-- [end:drop_table]
|
||||||
|
};
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
await example();
|
||||||
|
console.log("Basic example: done");
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
@@ -119,7 +119,7 @@ texts = [{"text": "Capitalism has been dominant in the Western world since the e
|
|||||||
tbl.add(texts)
|
tbl.add(texts)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Gemini Embedding Function
|
### Gemini Embeddings
|
||||||
With Google's Gemini, you can represent text (words, sentences, and blocks of text) in a vectorized form, making it easier to compare and contrast embeddings. For example, two texts that share a similar subject matter or sentiment should have similar embeddings, which can be identified through mathematical comparison techniques such as cosine similarity. For more on how and why you should use embeddings, refer to the Embeddings guide.
|
With Google's Gemini, you can represent text (words, sentences, and blocks of text) in a vectorized form, making it easier to compare and contrast embeddings. For example, two texts that share a similar subject matter or sentiment should have similar embeddings, which can be identified through mathematical comparison techniques such as cosine similarity. For more on how and why you should use embeddings, refer to the Embeddings guide.
|
||||||
The Gemini Embedding Model API supports various task types:
|
The Gemini Embedding Model API supports various task types:
|
||||||
|
|
||||||
@@ -155,6 +155,51 @@ tbl.add(df)
|
|||||||
rs = tbl.search("hello").limit(1).to_pandas()
|
rs = tbl.search("hello").limit(1).to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### AWS Bedrock Text Embedding Functions
|
||||||
|
AWS Bedrock supports multiple base models for generating text embeddings. You need to setup the AWS credentials to use this embedding function.
|
||||||
|
You can do so by using `awscli` and also add your session_token:
|
||||||
|
```shell
|
||||||
|
aws configure
|
||||||
|
aws configure set aws_session_token "<your_session_token>"
|
||||||
|
```
|
||||||
|
to ensure that the credentials are set up correctly, you can run the following command:
|
||||||
|
```shell
|
||||||
|
aws sts get-caller-identity
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported Embedding modelIDs are:
|
||||||
|
* `amazon.titan-embed-text-v1`
|
||||||
|
* `cohere.embed-english-v3`
|
||||||
|
* `cohere.embed-multilingual-v3`
|
||||||
|
|
||||||
|
Supported paramters (to be passed in `create` method) are:
|
||||||
|
| Parameter | Type | Default Value | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| **name** | str | "amazon.titan-embed-text-v1" | The model ID of the bedrock model to use. Supported base models for Text Embeddings: amazon.titan-embed-text-v1, cohere.embed-english-v3, cohere.embed-multilingual-v3 |
|
||||||
|
| **region** | str | "us-east-1" | Optional name of the AWS Region in which the service should be called (e.g., "us-east-1"). |
|
||||||
|
| **profile_name** | str | None | Optional name of the AWS profile to use for calling the Bedrock service. If not specified, the default profile will be used. |
|
||||||
|
| **assumed_role** | str | None | Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not specified, the current active credentials will be used. |
|
||||||
|
| **role_session_name** | str | "lancedb-embeddings" | Optional name of the AWS IAM role session to use for calling the Bedrock service. If not specified, a "lancedb-embeddings" name will be used. |
|
||||||
|
| **runtime** | bool | True | Optional choice of getting different client to perform operations with the Amazon Bedrock service. |
|
||||||
|
| **max_retries** | int | 7 | Optional number of retries to perform when a request fails. |
|
||||||
|
|
||||||
|
Usage Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = get_registry().get("bedrock-text").create()
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
|
db = lancedb.connect("tmp_path")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
rs = tbl.search("hello").limit(1).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
## Multi-modal embedding functions
|
## Multi-modal embedding functions
|
||||||
Multi-modal embedding functions allow you to query your table using both images and text.
|
Multi-modal embedding functions allow you to query your table using both images and text.
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,10 @@ def qanda_langchain(query):
|
|||||||
download_docs()
|
download_docs()
|
||||||
docs = store_docs()
|
docs = store_docs()
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200,)
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=1000,
|
||||||
|
chunk_overlap=200,
|
||||||
|
)
|
||||||
documents = text_splitter.split_documents(docs)
|
documents = text_splitter.split_documents(docs)
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
|
|
||||||
|
|||||||
11
docs/src/extra_js/init_ask_ai_widget.js
Normal file
11
docs/src/extra_js/init_ask_ai_widget.js
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
document.addEventListener("DOMContentLoaded", function () {
|
||||||
|
var script = document.createElement("script");
|
||||||
|
script.src = "https://widget.kapa.ai/kapa-widget.bundle.js";
|
||||||
|
script.setAttribute("data-website-id", "c5881fae-cec0-490b-b45e-d83d131d4f25");
|
||||||
|
script.setAttribute("data-project-name", "LanceDB");
|
||||||
|
script.setAttribute("data-project-color", "#000000");
|
||||||
|
script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/108903835?s=200&v=4");
|
||||||
|
script.setAttribute("data-modal-example-questions","Help me create an IVF_PQ index,How do I do an exhaustive search?,How do I create a LanceDB table?,Can I use my own embedding function?");
|
||||||
|
script.async = true;
|
||||||
|
document.head.appendChild(script);
|
||||||
|
});
|
||||||
@@ -69,3 +69,19 @@ MinIO supports an S3 compatible API. In order to connect to a MinIO instance, yo
|
|||||||
- Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
|
- Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
|
||||||
- Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential
|
- Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential
|
||||||
- Call `lancedb.connect("s3://minio_bucket_name")`
|
- Call `lancedb.connect("s3://minio_bucket_name")`
|
||||||
|
|
||||||
|
### Where can I find benchmarks for LanceDB?
|
||||||
|
|
||||||
|
Refer to this [post](https://blog.lancedb.com/benchmarking-lancedb-92b01032874a) for recent benchmarks.
|
||||||
|
|
||||||
|
### How much data can LanceDB practically manage without effecting performance?
|
||||||
|
|
||||||
|
We target good performance on ~10-50 billion rows and ~10-30 TB of data.
|
||||||
|
|
||||||
|
### Does LanceDB support concurrent operations?
|
||||||
|
|
||||||
|
LanceDB can handle concurrent reads very well, and can scale horizontally. The main constraint is how well the [storage layer](https://lancedb.github.io/lancedb/concepts/storage/) you've chosen scales. For writes, we support concurrent writing, though too many concurrent writers can lead to failing writes as there is a limited number of times a writer retries a commit
|
||||||
|
|
||||||
|
!!! info "Multiprocessing with LanceDB"
|
||||||
|
|
||||||
|
For multiprocessing you should probably not use ```fork``` as lance is multi-threaded internally and ```fork``` and multi-thread do not work well.[Refer to this discussion](https://discuss.python.org/t/concerns-regarding-deprecation-of-fork-with-alive-threads/33555)
|
||||||
|
|||||||
@@ -68,6 +68,82 @@ Alternatively, if you are using AWS SSO, you can use the `AWS_PROFILE` and `AWS_
|
|||||||
|
|
||||||
You can see a full list of environment variables [here](https://docs.rs/object_store/latest/object_store/aws/struct.AmazonS3Builder.html#method.from_env).
|
You can see a full list of environment variables [here](https://docs.rs/object_store/latest/object_store/aws/struct.AmazonS3Builder.html#method.from_env).
|
||||||
|
|
||||||
|
!!! tip "Automatic cleanup for failed writes"
|
||||||
|
|
||||||
|
LanceDB uses [multi-part uploads](https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html) when writing data to S3 in order to maximize write speed. LanceDB will abort these uploads when it shuts down gracefully, such as when cancelled by keyboard interrupt. However, in the rare case that LanceDB crashes, it is possible that some data will be left lingering in your account. To cleanup this data, we recommend (as AWS themselves do) that you setup a lifecycle rule to delete in-progress uploads after 7 days. See the AWS guide:
|
||||||
|
|
||||||
|
**[Configuring a bucket lifecycle configuration to delete incomplete multipart uploads](https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpu-abort-incomplete-mpu-lifecycle-config.html)**
|
||||||
|
|
||||||
|
#### AWS IAM Permissions
|
||||||
|
|
||||||
|
If a bucket is private, then an IAM policy must be specified to allow access to it. For many development scenarios, using broad permissions such as a PowerUser account is more than sufficient for working with LanceDB. However, in many production scenarios, you may wish to have as narrow as possible permissions.
|
||||||
|
|
||||||
|
For **read and write access**, LanceDB will need a policy such as:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"Version": "2012-10-17",
|
||||||
|
"Statement": [
|
||||||
|
{
|
||||||
|
"Effect": "Allow",
|
||||||
|
"Action": [
|
||||||
|
"s3:PutObject",
|
||||||
|
"s3:GetObject",
|
||||||
|
"s3:DeleteObject",
|
||||||
|
],
|
||||||
|
"Resource": "arn:aws:s3:::<bucket>/<prefix>/*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Effect": "Allow",
|
||||||
|
"Action": [
|
||||||
|
"s3:ListBucket",
|
||||||
|
"s3:GetBucketLocation"
|
||||||
|
],
|
||||||
|
"Resource": "arn:aws:s3:::<bucket>",
|
||||||
|
"Condition": {
|
||||||
|
"StringLike": {
|
||||||
|
"s3:prefix": [
|
||||||
|
"<prefix>/*"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
For **read-only access**, LanceDB will need a policy such as:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"Version": "2012-10-17",
|
||||||
|
"Statement": [
|
||||||
|
{
|
||||||
|
"Effect": "Allow",
|
||||||
|
"Action": [
|
||||||
|
"s3:GetObject",
|
||||||
|
],
|
||||||
|
"Resource": "arn:aws:s3:::<bucket>/<prefix>/*"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Effect": "Allow",
|
||||||
|
"Action": [
|
||||||
|
"s3:ListBucket",
|
||||||
|
"s3:GetBucketLocation"
|
||||||
|
],
|
||||||
|
"Resource": "arn:aws:s3:::<bucket>",
|
||||||
|
"Condition": {
|
||||||
|
"StringLike": {
|
||||||
|
"s3:prefix": [
|
||||||
|
"<prefix>/*"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
#### S3-compatible stores
|
#### S3-compatible stores
|
||||||
|
|
||||||
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify two environment variables: `AWS_ENDPOINT` and `AWS_DEFAULT_REGION`. `AWS_ENDPOINT` should be the URL of the S3-compatible store, and `AWS_DEFAULT_REGION` should be the region to use.
|
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify two environment variables: `AWS_ENDPOINT` and `AWS_DEFAULT_REGION`. `AWS_ENDPOINT` should be the URL of the S3-compatible store, and `AWS_DEFAULT_REGION` should be the region to use.
|
||||||
|
|||||||
@@ -16,9 +16,22 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
db = lancedb.connect("./.lancedb")
|
db = lancedb.connect("./.lancedb")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
=== "Javascript"
|
||||||
|
|
||||||
|
Initialize a VectorDB connection and create a table using one of the many methods listed below.
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const lancedb = require("vectordb");
|
||||||
|
|
||||||
|
const uri = "data/sample-lancedb";
|
||||||
|
const db = await lancedb.connect(uri);
|
||||||
|
```
|
||||||
|
|
||||||
LanceDB allows ingesting data from various sources - `dict`, `list[dict]`, `pd.DataFrame`, `pa.Table` or a `Iterator[pa.RecordBatch]`. Let's take a look at some of the these.
|
LanceDB allows ingesting data from various sources - `dict`, `list[dict]`, `pd.DataFrame`, `pa.Table` or a `Iterator[pa.RecordBatch]`. Let's take a look at some of the these.
|
||||||
|
|
||||||
### From list of tuples or dictionaries
|
### From list of tuples or dictionaries
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import lancedb
|
import lancedb
|
||||||
@@ -32,7 +45,6 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
|
|
||||||
db["my_table"].head()
|
db["my_table"].head()
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! info "Note"
|
!!! info "Note"
|
||||||
If the table already exists, LanceDB will raise an error by default.
|
If the table already exists, LanceDB will raise an error by default.
|
||||||
|
|
||||||
@@ -51,6 +63,27 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
db.create_table("name", data, mode="overwrite")
|
db.create_table("name", data, mode="overwrite")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
=== "Javascript"
|
||||||
|
You can create a LanceDB table in JavaScript using an array of JSON records as follows.
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const tb = await db.createTable("my_table", [{
|
||||||
|
"vector": [3.1, 4.1],
|
||||||
|
"item": "foo",
|
||||||
|
"price": 10.0
|
||||||
|
}, {
|
||||||
|
"vector": [5.9, 26.5],
|
||||||
|
"item": "bar",
|
||||||
|
"price": 20.0
|
||||||
|
}]);
|
||||||
|
```
|
||||||
|
!!! info "Note"
|
||||||
|
If the table already exists, LanceDB will raise an error by default. If you want to overwrite the table, you need to specify the `WriteMode` in the createTable function.
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const table = await con.createTable(tableName, data, { writeMode: WriteMode.Overwrite })
|
||||||
|
```
|
||||||
|
|
||||||
### From a Pandas DataFrame
|
### From a Pandas DataFrame
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -67,7 +100,9 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
db["my_table"].head()
|
db["my_table"].head()
|
||||||
```
|
```
|
||||||
!!! info "Note"
|
!!! info "Note"
|
||||||
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
|
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
|
||||||
|
|
||||||
|
The **`vector`** column needs to be a [Vector](../python/pydantic.md#vector-field) (defined as [pyarrow.FixedSizeList](https://arrow.apache.org/docs/python/generated/pyarrow.list_.html)) type.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
custom_schema = pa.schema([
|
custom_schema = pa.schema([
|
||||||
@@ -79,7 +114,7 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
table = db.create_table("my_table", data, schema=custom_schema)
|
table = db.create_table("my_table", data, schema=custom_schema)
|
||||||
```
|
```
|
||||||
|
|
||||||
### From a Polars DataFrame
|
### From a Polars DataFrame
|
||||||
|
|
||||||
LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library
|
LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library
|
||||||
written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow
|
written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow
|
||||||
@@ -97,26 +132,44 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
table = db.create_table("pl_table", data=data)
|
table = db.create_table("pl_table", data=data)
|
||||||
```
|
```
|
||||||
|
|
||||||
### From PyArrow Tables
|
### From an Arrow Table
|
||||||
You can also create LanceDB tables directly from PyArrow tables
|
=== "Python"
|
||||||
|
You can also create LanceDB tables directly from Arrow tables.
|
||||||
|
LanceDB supports float16 data type!
|
||||||
|
|
||||||
```python
|
```python
|
||||||
table = pa.Table.from_arrays(
|
import pyarrows as pa
|
||||||
[
|
import numpy as np
|
||||||
pa.array([[3.1, 4.1, 5.1, 6.1], [5.9, 26.5, 4.7, 32.8]],
|
|
||||||
pa.list_(pa.float32(), 4)),
|
|
||||||
pa.array(["foo", "bar"]),
|
|
||||||
pa.array([10.0, 20.0]),
|
|
||||||
],
|
|
||||||
["vector", "item", "price"],
|
|
||||||
)
|
|
||||||
|
|
||||||
db = lancedb.connect("db")
|
dim = 16
|
||||||
|
total = 2
|
||||||
|
schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("vector", pa.list_(pa.float16(), dim)),
|
||||||
|
pa.field("text", pa.string())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
data = pa.Table.from_arrays(
|
||||||
|
[
|
||||||
|
pa.array([np.random.randn(dim).astype(np.float16) for _ in range(total)],
|
||||||
|
pa.list_(pa.float16(), dim)),
|
||||||
|
pa.array(["foo", "bar"])
|
||||||
|
],
|
||||||
|
["vector", "text"],
|
||||||
|
)
|
||||||
|
tbl = db.create_table("f16_tbl", data, schema=schema)
|
||||||
|
```
|
||||||
|
|
||||||
tbl = db.create_table("my_table", table)
|
=== "Javascript"
|
||||||
|
You can also create LanceDB tables directly from Arrow tables.
|
||||||
|
LanceDB supports Float16 data type!
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
--8<-- "docs/src/basic_legacy.ts:create_f16_table"
|
||||||
```
|
```
|
||||||
|
|
||||||
### From Pydantic Models
|
### From Pydantic Models
|
||||||
|
|
||||||
When you create an empty table without data, you must specify the table schema.
|
When you create an empty table without data, you must specify the table schema.
|
||||||
LanceDB supports creating tables by specifying a PyArrow schema or a specialized
|
LanceDB supports creating tables by specifying a PyArrow schema or a specialized
|
||||||
Pydantic model called `LanceModel`.
|
Pydantic model called `LanceModel`.
|
||||||
@@ -261,37 +314,6 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
|
|
||||||
You can also use iterators of other types like Pandas DataFrame or Pylists directly in the above example.
|
You can also use iterators of other types like Pandas DataFrame or Pylists directly in the above example.
|
||||||
|
|
||||||
=== "JavaScript"
|
|
||||||
Initialize a VectorDB connection and create a table using one of the many methods listed below.
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
const lancedb = require("vectordb");
|
|
||||||
|
|
||||||
const uri = "data/sample-lancedb";
|
|
||||||
const db = await lancedb.connect(uri);
|
|
||||||
```
|
|
||||||
|
|
||||||
You can create a LanceDB table in JavaScript using an array of JSON records as follows.
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
const tb = await db.createTable("my_table", [{
|
|
||||||
"vector": [3.1, 4.1],
|
|
||||||
"item": "foo",
|
|
||||||
"price": 10.0
|
|
||||||
}, {
|
|
||||||
"vector": [5.9, 26.5],
|
|
||||||
"item": "bar",
|
|
||||||
"price": 20.0
|
|
||||||
}]);
|
|
||||||
```
|
|
||||||
|
|
||||||
!!! info "Note"
|
|
||||||
If the table already exists, LanceDB will raise an error by default. If you want to overwrite the table, you need to specify the `WriteMode` in the createTable function.
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
const table = await con.createTable(tableName, data, { writeMode: WriteMode.Overwrite })
|
|
||||||
```
|
|
||||||
|
|
||||||
## Open existing tables
|
## Open existing tables
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|||||||
235
docs/src/hybrid_search.md
Normal file
235
docs/src/hybrid_search.md
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
# Hybrid Search
|
||||||
|
|
||||||
|
LanceDB supports both semantic and keyword-based search. In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques.
|
||||||
|
|
||||||
|
## Hybrid search in LanceDB
|
||||||
|
You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic .
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
import openai
|
||||||
|
from lancedb.embeddings import get_registry
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
|
||||||
|
# Ingest embedding function in LanceDB table
|
||||||
|
# Configuring the environment variable OPENAI_API_KEY
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
# OR set the key here as a variable
|
||||||
|
openai.api_key = "sk-..."
|
||||||
|
embeddings = get_registry().get("openai").create()
|
||||||
|
|
||||||
|
class Documents(LanceModel):
|
||||||
|
vector: Vector(embeddings.ndims()) = embeddings.VectorField()
|
||||||
|
text: str = embeddings.SourceField()
|
||||||
|
|
||||||
|
table = db.create_table("documents", schema=Documents)
|
||||||
|
|
||||||
|
data = [
|
||||||
|
{ "text": "rebel spaceships striking from a hidden base"},
|
||||||
|
{ "text": "have won their first victory against the evil Galactic Empire"},
|
||||||
|
{ "text": "during the battle rebel spies managed to steal secret plans"},
|
||||||
|
{ "text": "to the Empire's ultimate weapon the Death Star"}
|
||||||
|
]
|
||||||
|
|
||||||
|
# ingest docs with auto-vectorization
|
||||||
|
table.add(data)
|
||||||
|
|
||||||
|
# Create a fts index before the hybrid search
|
||||||
|
table.create_fts_index("text")
|
||||||
|
# hybrid search with default re-ranker
|
||||||
|
results = table.search("flower moon", query_type="hybrid").to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
|
||||||
|
|
||||||
|
|
||||||
|
### `rerank()` arguments
|
||||||
|
* `normalize`: `str`, default `"score"`:
|
||||||
|
The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
|
||||||
|
* `reranker`: `Reranker`, default `LinearCombinationReranker(weight=0.7)`.
|
||||||
|
The reranker to use. If not specified, the default reranker is used.
|
||||||
|
|
||||||
|
|
||||||
|
## Available Rerankers
|
||||||
|
LanceDB provides a number of re-rankers out of the box. You can use any of these re-rankers by passing them to the `rerank()` method. Here's a list of available re-rankers:
|
||||||
|
|
||||||
|
### Linear Combination Reranker
|
||||||
|
This is the default re-ranker used by LanceDB. It combines the results of semantic and full-text search using a linear combination of the scores. The weights for the linear combination can be specified. It defaults to 0.7, i.e, 70% weight for semantic search and 30% weight for full-text search.
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.rerankers import LinearCombinationReranker
|
||||||
|
|
||||||
|
reranker = LinearCombinationReranker(weight=0.3) # Use 0.3 as the weight for vector search
|
||||||
|
|
||||||
|
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
----------------
|
||||||
|
* `weight`: `float`, default `0.7`:
|
||||||
|
The weight to use for the semantic search score. The weight for the full-text search score is `1 - weights`.
|
||||||
|
* `fill`: `float`, default `1.0`:
|
||||||
|
The score to give to results that are only in one of the two result sets.This is treated as penalty, so a higher value means a lower score.
|
||||||
|
TODO: We should just hardcode this-- its pretty confusing as we invert scores to calculate final score
|
||||||
|
* `return_score` : str, default `"relevance"`
|
||||||
|
options are "relevance" or "all"
|
||||||
|
The type of score to return. If "relevance", will return only the `_relevance_score. If "all", will return all scores from the vector and FTS search along with the relevance score.
|
||||||
|
|
||||||
|
### Cohere Reranker
|
||||||
|
This re-ranker uses the [Cohere](https://cohere.ai/) API to combine the results of semantic and full-text search. You can use this re-ranker by passing `CohereReranker()` to the `rerank()` method. Note that you'll need to set the `COHERE_API_KEY` environment variable to use this re-ranker.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.rerankers import CohereReranker
|
||||||
|
|
||||||
|
reranker = CohereReranker()
|
||||||
|
|
||||||
|
results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
----------------
|
||||||
|
* `model_name`` : str, default `"rerank-english-v2.0"``
|
||||||
|
The name of the cross encoder model to use. Available cohere models are:
|
||||||
|
- rerank-english-v2.0
|
||||||
|
- rerank-multilingual-v2.0
|
||||||
|
* `column` : str, default `"text"`
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
* `top_n` : str, default `None`
|
||||||
|
The number of results to return. If None, will return all results.
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
Only returns `_relevance_score`. Does not support `return_score = "all"`.
|
||||||
|
|
||||||
|
### Cross Encoder Reranker
|
||||||
|
This reranker uses the [Sentence Transformers](https://www.sbert.net/) library to combine the results of semantic and full-text search. You can use it by passing `CrossEncoderReranker()` to the `rerank()` method.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.rerankers import CrossEncoderReranker
|
||||||
|
|
||||||
|
reranker = CrossEncoderReranker()
|
||||||
|
|
||||||
|
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
----------------
|
||||||
|
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
|
||||||
|
The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
|
||||||
|
* `column` : str, default `"text"`
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
* `device` : str, default `None`
|
||||||
|
The device to use for the cross encoder model. If None, will use "cuda" if available, otherwise "cpu".
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
Only returns `_relevance_score`. Does not support `return_score = "all"`.
|
||||||
|
|
||||||
|
|
||||||
|
### ColBERT Reranker
|
||||||
|
This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertrReranker()` to the `rerank()` method.
|
||||||
|
|
||||||
|
ColBERT reranker model calculates relevance of given docs against the query and don't take existing fts and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for `text` column to rerank the results. But you can specify the column name to use as input to the cross encoder model as described below.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.rerankers import ColbertReranker
|
||||||
|
|
||||||
|
reranker = ColbertReranker()
|
||||||
|
|
||||||
|
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
----------------
|
||||||
|
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
|
||||||
|
The name of the cross encoder model to use.
|
||||||
|
* `column` : `str`, default `"text"`
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
* `return_score` : `str`, default `"relevance"`
|
||||||
|
options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
Only returns `_relevance_score`. Does not support `return_score = "all"`.
|
||||||
|
|
||||||
|
### OpenAI Reranker
|
||||||
|
This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
This prompts chat model to rerank results which is not a dedicated reranker model. This should be treated as experimental.
|
||||||
|
|
||||||
|
!!! Tip
|
||||||
|
You might run out of token limit so set the search `limits` based on your token limit.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.rerankers import OpenaiReranker
|
||||||
|
|
||||||
|
reranker = OpenaiReranker()
|
||||||
|
|
||||||
|
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
|
Arguments
|
||||||
|
----------------
|
||||||
|
`model_name` : `str`, default `"gpt-3.5-turbo-1106"`
|
||||||
|
The name of the cross encoder model to use.
|
||||||
|
`column` : `str`, default `"text"`
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
`return_score` : `str`, default `"relevance"`
|
||||||
|
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||||
|
`api_key` : `str`, default `None`
|
||||||
|
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||||
|
|
||||||
|
|
||||||
|
## Building Custom Rerankers
|
||||||
|
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
|
||||||
|
|
||||||
|
The `Reranker` base interface comes with a `merge_results()` method that can be used to combine the results of semantic and full-text search. This is a vanilla merging algorithm that simply concatenates the results and removes the duplicates without taking the scores into consideration. It only keeps the first copy of the row encountered. This works well in cases that don't require the scores of semantic and full-text search to combine the results. If you want to use the scores or want to support `return_score="all"`, you'll need to implement your own merging algorithm.
|
||||||
|
|
||||||
|
```python
|
||||||
|
|
||||||
|
from lancedb.rerankers import Reranker
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
class MyReranker(Reranker):
|
||||||
|
def __init__(self, param1, param2, ..., return_score="relevance"):
|
||||||
|
super().__init__(return_score)
|
||||||
|
self.param1 = param1
|
||||||
|
self.param2 = param2
|
||||||
|
|
||||||
|
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
|
||||||
|
# Use the built-in merging function
|
||||||
|
combined_result = self.merge_results(vector_results, fts_results)
|
||||||
|
|
||||||
|
# Do something with the combined results
|
||||||
|
# ...
|
||||||
|
|
||||||
|
# Return the combined results
|
||||||
|
return combined_result
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also accept additional arguments like a filter along with fts and vector search results
|
||||||
|
|
||||||
|
```python
|
||||||
|
|
||||||
|
from lancedb.rerankers import Reranker
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
class MyReranker(Reranker):
|
||||||
|
...
|
||||||
|
|
||||||
|
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
|
||||||
|
# Use the built-in merging function
|
||||||
|
combined_result = self.merge_results(vector_results, fts_results)
|
||||||
|
|
||||||
|
# Do something with the combined results & filter
|
||||||
|
# ...
|
||||||
|
|
||||||
|
# Return the combined results
|
||||||
|
return combined_result
|
||||||
|
|
||||||
|
```
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 50,
|
"execution_count": 2,
|
||||||
"id": "c1b4e34b-a49c-471d-a343-a5940bb5138a",
|
"id": "c1b4e34b-a49c-471d-a343-a5940bb5138a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -23,7 +23,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 3,
|
||||||
"id": "4e5a8d07-d9a1-48c1-913a-8e0629289579",
|
"id": "4e5a8d07-d9a1-48c1-913a-8e0629289579",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -44,7 +44,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 4,
|
||||||
"id": "5df12f66-8d99-43ad-8d0b-22189ec0a6b9",
|
"id": "5df12f66-8d99-43ad-8d0b-22189ec0a6b9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -62,7 +62,7 @@
|
|||||||
"long: [[-122.7,-74.1]]"
|
"long: [[-122.7,-74.1]]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 2,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 5,
|
||||||
"id": "f4d87ae9-0ccb-48eb-b31d-bb8f2370e47e",
|
"id": "f4d87ae9-0ccb-48eb-b31d-bb8f2370e47e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -108,7 +108,7 @@
|
|||||||
"long: [[-122.7,-74.1]]"
|
"long: [[-122.7,-74.1]]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 3,
|
"execution_count": 5,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -135,10 +135,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 6,
|
||||||
"id": "25f34bcf-fca0-4431-8601-eac95d1bd347",
|
"id": "25f34bcf-fca0-4431-8601-eac95d1bd347",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[2024-01-31T18:59:33Z WARN lance::dataset] No existing dataset at /Users/qian/Work/LanceDB/lancedb/docs/src/notebooks/.lancedb/table3.lance, it will be created\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
@@ -148,7 +155,7 @@
|
|||||||
"long: float"
|
"long: float"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 8,
|
"execution_count": 6,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -171,45 +178,51 @@
|
|||||||
"id": "4df51925-7ca2-4005-9c72-38b3d26240c6",
|
"id": "4df51925-7ca2-4005-9c72-38b3d26240c6",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"### From PyArrow Tables\n",
|
"### From an Arrow Table\n",
|
||||||
"\n",
|
"\n",
|
||||||
"You can also create LanceDB tables directly from pyarrow tables"
|
"You can also create LanceDB tables directly from pyarrow tables. LanceDB supports float16 type."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 7,
|
||||||
"id": "90a880f6-be43-4c9d-ba65-0b05197c0f6f",
|
"id": "90a880f6-be43-4c9d-ba65-0b05197c0f6f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"vector: fixed_size_list<item: float>[2]\n",
|
"vector: fixed_size_list<item: halffloat>[16]\n",
|
||||||
" child 0, item: float\n",
|
" child 0, item: halffloat\n",
|
||||||
"item: string\n",
|
"text: string"
|
||||||
"price: double"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 12,
|
"execution_count": 7,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"table = pa.Table.from_arrays(\n",
|
"import numpy as np\n",
|
||||||
" [\n",
|
|
||||||
" pa.array([[3.1, 4.1], [5.9, 26.5]],\n",
|
|
||||||
" pa.list_(pa.float32(), 2)),\n",
|
|
||||||
" pa.array([\"foo\", \"bar\"]),\n",
|
|
||||||
" pa.array([10.0, 20.0]),\n",
|
|
||||||
" ],\n",
|
|
||||||
" [\"vector\", \"item\", \"price\"],\n",
|
|
||||||
" )\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"db = lancedb.connect(\"db\")\n",
|
"dim = 16\n",
|
||||||
|
"total = 2\n",
|
||||||
|
"schema = pa.schema(\n",
|
||||||
|
" [\n",
|
||||||
|
" pa.field(\"vector\", pa.list_(pa.float16(), dim)),\n",
|
||||||
|
" pa.field(\"text\", pa.string())\n",
|
||||||
|
" ]\n",
|
||||||
|
")\n",
|
||||||
|
"data = pa.Table.from_arrays(\n",
|
||||||
|
" [\n",
|
||||||
|
" pa.array([np.random.randn(dim).astype(np.float16) for _ in range(total)],\n",
|
||||||
|
" pa.list_(pa.float16(), dim)),\n",
|
||||||
|
" pa.array([\"foo\", \"bar\"])\n",
|
||||||
|
" ],\n",
|
||||||
|
" [\"vector\", \"text\"],\n",
|
||||||
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"tbl = db.create_table(\"test1\", table, mode=\"overwrite\")\n",
|
"tbl = db.create_table(\"f16_tbl\", data, schema=schema)\n",
|
||||||
"tbl.schema"
|
"tbl.schema"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -225,7 +238,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 8,
|
||||||
"id": "d81121d7-e4b7-447c-a48c-974b6ebb464a",
|
"id": "d81121d7-e4b7-447c-a48c-974b6ebb464a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -240,7 +253,7 @@
|
|||||||
"imdb_id: int64 not null"
|
"imdb_id: int64 not null"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 13,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -282,7 +295,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 9,
|
||||||
"id": "bc247142-4e3c-41a2-b94c-8e00d2c2a508",
|
"id": "bc247142-4e3c-41a2-b94c-8e00d2c2a508",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -292,7 +305,7 @@
|
|||||||
"LanceTable(table4)"
|
"LanceTable(table4)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 14,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -333,7 +346,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 10,
|
||||||
"id": "25ad3523-e0c9-4c28-b3df-38189c4e0e5f",
|
"id": "25ad3523-e0c9-4c28-b3df-38189c4e0e5f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -346,7 +359,7 @@
|
|||||||
"price: double not null"
|
"price: double not null"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 16,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -385,7 +398,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 11,
|
||||||
"id": "2814173a-eacc-4dd8-a64d-6312b44582cc",
|
"id": "2814173a-eacc-4dd8-a64d-6312b44582cc",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -411,7 +424,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 12,
|
||||||
"id": "df9e13c0-41f6-437f-9dfa-2fd71d3d9c45",
|
"id": "df9e13c0-41f6-437f-9dfa-2fd71d3d9c45",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -421,7 +434,7 @@
|
|||||||
"['table6', 'table4', 'table5', 'movielens_small']"
|
"['table6', 'table4', 'table5', 'movielens_small']"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 18,
|
"execution_count": 12,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -432,7 +445,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 13,
|
||||||
"id": "9343f5ad-6024-42ee-ac2f-6c1471df8679",
|
"id": "9343f5ad-6024-42ee-ac2f-6c1471df8679",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -541,7 +554,7 @@
|
|||||||
"9 [5.9, 26.5] bar 20.0"
|
"9 [5.9, 26.5] bar 20.0"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 20,
|
"execution_count": 13,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -564,7 +577,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 14,
|
||||||
"id": "8a56250f-73a1-4c26-a6ad-5c7a0ce3a9ab",
|
"id": "8a56250f-73a1-4c26-a6ad-5c7a0ce3a9ab",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -590,7 +603,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 22,
|
"execution_count": 15,
|
||||||
"id": "030c7057-b98e-4e2f-be14-b8c1f927f83c",
|
"id": "030c7057-b98e-4e2f-be14-b8c1f927f83c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -621,7 +634,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 24,
|
"execution_count": 16,
|
||||||
"id": "e7a17de2-08d2-41b7-bd05-f63d1045ab1f",
|
"id": "e7a17de2-08d2-41b7-bd05-f63d1045ab1f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -629,16 +642,16 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"32\n"
|
"22\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"17"
|
"12"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 24,
|
"execution_count": 16,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -661,7 +674,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 30,
|
"execution_count": 17,
|
||||||
"id": "fe3310bd-08f4-4a22-a63b-b3127d22f9f7",
|
"id": "fe3310bd-08f4-4a22-a63b-b3127d22f9f7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -681,25 +694,20 @@
|
|||||||
"8 [3.1, 4.1] foo 10.0\n",
|
"8 [3.1, 4.1] foo 10.0\n",
|
||||||
"9 [3.1, 4.1] foo 10.0\n",
|
"9 [3.1, 4.1] foo 10.0\n",
|
||||||
"10 [3.1, 4.1] foo 10.0\n",
|
"10 [3.1, 4.1] foo 10.0\n",
|
||||||
"11 [3.1, 4.1] foo 10.0\n",
|
"11 [3.1, 4.1] foo 10.0\n"
|
||||||
"12 [3.1, 4.1] foo 10.0\n",
|
|
||||||
"13 [3.1, 4.1] foo 10.0\n",
|
|
||||||
"14 [3.1, 4.1] foo 10.0\n",
|
|
||||||
"15 [3.1, 4.1] foo 10.0\n",
|
|
||||||
"16 [3.1, 4.1] foo 10.0\n"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"ename": "OSError",
|
"ename": "OSError",
|
||||||
"evalue": "LanceError(IO): Error during planning: column foo does not exist",
|
"evalue": "LanceError(IO): Error during planning: column foo does not exist, /Users/runner/work/lance/lance/rust/lance-core/src/error.rs:212:23",
|
||||||
"output_type": "error",
|
"output_type": "error",
|
||||||
"traceback": [
|
"traceback": [
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
||||||
"Cell \u001b[0;32mIn[30], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m to_remove \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(v) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m to_remove)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(tbl\u001b[38;5;241m.\u001b[39mto_pandas())\n\u001b[0;32m----> 4\u001b[0m \u001b[43mtbl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mitem IN (\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mto_remove\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m tbl\u001b[38;5;241m.\u001b[39mto_pandas()\n",
|
"Cell \u001b[0;32mIn[17], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m to_remove \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(v) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m to_remove)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(tbl\u001b[38;5;241m.\u001b[39mto_pandas())\n\u001b[0;32m----> 4\u001b[0m \u001b[43mtbl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mitem IN (\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mto_remove\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||||
"File \u001b[0;32m~/Documents/lancedb/lancedb/python/lancedb/table.py:610\u001b[0m, in \u001b[0;36mLanceTable.delete\u001b[0;34m(self, where)\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdelete\u001b[39m(\u001b[38;5;28mself\u001b[39m, where: \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 610\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwhere\u001b[49m\u001b[43m)\u001b[49m\n",
|
"File \u001b[0;32m~/Work/LanceDB/lancedb/docs/doc-venv/lib/python3.11/site-packages/lancedb/table.py:872\u001b[0m, in \u001b[0;36mLanceTable.delete\u001b[0;34m(self, where)\u001b[0m\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdelete\u001b[39m(\u001b[38;5;28mself\u001b[39m, where: \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 872\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwhere\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
"File \u001b[0;32m~/Documents/lancedb/lancedb/env/lib/python3.11/site-packages/lance/dataset.py:489\u001b[0m, in \u001b[0;36mLanceDataset.delete\u001b[0;34m(self, predicate)\u001b[0m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(predicate, pa\u001b[38;5;241m.\u001b[39mcompute\u001b[38;5;241m.\u001b[39mExpression):\n\u001b[1;32m 488\u001b[0m predicate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(predicate)\n\u001b[0;32m--> 489\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_ds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredicate\u001b[49m\u001b[43m)\u001b[49m\n",
|
"File \u001b[0;32m~/Work/LanceDB/lancedb/docs/doc-venv/lib/python3.11/site-packages/lance/dataset.py:596\u001b[0m, in \u001b[0;36mLanceDataset.delete\u001b[0;34m(self, predicate)\u001b[0m\n\u001b[1;32m 594\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(predicate, pa\u001b[38;5;241m.\u001b[39mcompute\u001b[38;5;241m.\u001b[39mExpression):\n\u001b[1;32m 595\u001b[0m predicate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(predicate)\n\u001b[0;32m--> 596\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_ds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredicate\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
"\u001b[0;31mOSError\u001b[0m: LanceError(IO): Error during planning: column foo does not exist"
|
"\u001b[0;31mOSError\u001b[0m: LanceError(IO): Error during planning: column foo does not exist, /Users/runner/work/lance/lance/rust/lance-core/src/error.rs:212:23"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -712,7 +720,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 43,
|
"execution_count": null,
|
||||||
"id": "87d5bc21-847f-4c81-b56e-f6dbe5d05aac",
|
"id": "87d5bc21-847f-4c81-b56e-f6dbe5d05aac",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -729,7 +737,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 44,
|
"execution_count": null,
|
||||||
"id": "9cba4519-eb3a-4941-ab7e-873d762e750f",
|
"id": "9cba4519-eb3a-4941-ab7e-873d762e750f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -742,7 +750,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 46,
|
"execution_count": null,
|
||||||
"id": "5bdc9801-d5ed-4871-92d0-88b27108e788",
|
"id": "5bdc9801-d5ed-4871-92d0-88b27108e788",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -817,7 +825,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.4"
|
"version": "3.11.7"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ pip install lancedb
|
|||||||
|
|
||||||
::: lancedb.schema.vector
|
::: lancedb.schema.vector
|
||||||
|
|
||||||
|
::: lancedb.merge.LanceMergeInsertBuilder
|
||||||
|
|
||||||
## Integrations
|
## Integrations
|
||||||
|
|
||||||
### Pydantic
|
### Pydantic
|
||||||
|
|||||||
@@ -2,27 +2,26 @@
|
|||||||
|
|
||||||
A vector search finds the approximate or exact nearest neighbors to a given query vector.
|
A vector search finds the approximate or exact nearest neighbors to a given query vector.
|
||||||
|
|
||||||
* In a recommendation system or search engine, you can find similar records to
|
- In a recommendation system or search engine, you can find similar records to
|
||||||
the one you searched.
|
the one you searched.
|
||||||
* In LLM and other AI applications,
|
- In LLM and other AI applications,
|
||||||
each data point can be represented by [embeddings generated from existing models](embeddings/index.md),
|
each data point can be represented by [embeddings generated from existing models](embeddings/index.md),
|
||||||
following which the search returns the most relevant features.
|
following which the search returns the most relevant features.
|
||||||
|
|
||||||
## Distance metrics
|
## Distance metrics
|
||||||
|
|
||||||
Distance metrics are a measure of the similarity between a pair of vectors.
|
Distance metrics are a measure of the similarity between a pair of vectors.
|
||||||
Currently, LanceDB supports the following metrics:
|
Currently, LanceDB supports the following metrics:
|
||||||
|
|
||||||
| Metric | Description |
|
| Metric | Description |
|
||||||
| ----------- | ------------------------------------ |
|
| -------- | --------------------------------------------------------------------------- |
|
||||||
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
|
| `l2` | [Euclidean / L2 distance](https://en.wikipedia.org/wiki/Euclidean_distance) |
|
||||||
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity)|
|
| `cosine` | [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity) |
|
||||||
| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) |
|
| `dot` | [Dot Production](https://en.wikipedia.org/wiki/Dot_product) |
|
||||||
|
|
||||||
|
|
||||||
## Exhaustive search (kNN)
|
## Exhaustive search (kNN)
|
||||||
|
|
||||||
If you do not create a vector index, LanceDB exhaustively scans the *entire* vector space
|
If you do not create a vector index, LanceDB exhaustively scans the _entire_ vector space
|
||||||
and compute the distance to every vector in order to find the exact nearest neighbors. This is effectively a kNN search.
|
and compute the distance to every vector in order to find the exact nearest neighbors. This is effectively a kNN search.
|
||||||
|
|
||||||
<!-- Setup Code
|
<!-- Setup Code
|
||||||
@@ -38,22 +37,9 @@ data = [{"vector": row, "item": f"item {i}"}
|
|||||||
db.create_table("my_vectors", data=data)
|
db.create_table("my_vectors", data=data)
|
||||||
```
|
```
|
||||||
-->
|
-->
|
||||||
<!-- Setup Code
|
|
||||||
```javascript
|
|
||||||
const vectordb_setup = require('vectordb')
|
|
||||||
const db_setup = await vectordb_setup.connect('data/sample-lancedb')
|
|
||||||
|
|
||||||
let data = []
|
|
||||||
for (let i = 0; i < 10_000; i++) {
|
|
||||||
data.push({vector: Array(1536).fill(i), id: `${i}`, content: "", longId: `${i}`},)
|
|
||||||
}
|
|
||||||
await db_setup.createTable('my_vectors', data)
|
|
||||||
```
|
|
||||||
-->
|
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import lancedb
|
import lancedb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -70,14 +56,9 @@ await db_setup.createTable('my_vectors', data)
|
|||||||
=== "JavaScript"
|
=== "JavaScript"
|
||||||
|
|
||||||
```javascript
|
```javascript
|
||||||
const vectordb = require('vectordb')
|
--8<-- "docs/src/search_legacy.ts:import"
|
||||||
const db = await vectordb.connect('data/sample-lancedb')
|
|
||||||
|
|
||||||
const tbl = await db.openTable("my_vectors")
|
--8<-- "docs/src/search_legacy.ts:search1"
|
||||||
|
|
||||||
const results_1 = await tbl.search(Array(1536).fill(1.2))
|
|
||||||
.limit(10)
|
|
||||||
.execute()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
By default, `l2` will be used as metric type. You can specify the metric type as
|
By default, `l2` will be used as metric type. You can specify the metric type as
|
||||||
@@ -92,14 +73,10 @@ By default, `l2` will be used as metric type. You can specify the metric type as
|
|||||||
.to_list()
|
.to_list()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
=== "JavaScript"
|
=== "JavaScript"
|
||||||
|
|
||||||
```javascript
|
```javascript
|
||||||
const results_2 = await tbl.search(Array(1536).fill(1.2))
|
--8<-- "docs/src/search_legacy.ts:search2"
|
||||||
.metricType("cosine")
|
|
||||||
.limit(10)
|
|
||||||
.execute()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Approximate nearest neighbor (ANN) search
|
## Approximate nearest neighbor (ANN) search
|
||||||
@@ -117,7 +94,9 @@ LanceDB returns vector search results via different formats commonly used in pyt
|
|||||||
Let's create a LanceDB table with a nested schema:
|
Let's create a LanceDB table with a nested schema:
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|||||||
41
docs/src/search_legacy.ts
Normal file
41
docs/src/search_legacy.ts
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
// --8<-- [start:import]
|
||||||
|
import * as lancedb from "vectordb";
|
||||||
|
// --8<-- [end:import]
|
||||||
|
import * as fs from "fs";
|
||||||
|
|
||||||
|
async function setup() {
|
||||||
|
fs.rmSync("data/sample-lancedb", { recursive: true, force: true });
|
||||||
|
const db = await lancedb.connect("data/sample-lancedb");
|
||||||
|
|
||||||
|
let data = [];
|
||||||
|
for (let i = 0; i < 10_000; i++) {
|
||||||
|
data.push({
|
||||||
|
vector: Array(1536).fill(i),
|
||||||
|
id: `${i}`,
|
||||||
|
content: "",
|
||||||
|
longId: `${i}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
await db.createTable("my_vectors", data);
|
||||||
|
}
|
||||||
|
|
||||||
|
async () => {
|
||||||
|
await setup();
|
||||||
|
|
||||||
|
// --8<-- [start:search1]
|
||||||
|
const db = await lancedb.connect("data/sample-lancedb");
|
||||||
|
const tbl = await db.openTable("my_vectors");
|
||||||
|
|
||||||
|
const results_1 = await tbl.search(Array(1536).fill(1.2)).limit(10).execute();
|
||||||
|
// --8<-- [end:search1]
|
||||||
|
|
||||||
|
// --8<-- [start:search2]
|
||||||
|
const results_2 = await tbl
|
||||||
|
.search(Array(1536).fill(1.2))
|
||||||
|
.metricType(lancedb.MetricType.Cosine)
|
||||||
|
.limit(10)
|
||||||
|
.execute();
|
||||||
|
// --8<-- [end:search2]
|
||||||
|
|
||||||
|
console.log("search: done");
|
||||||
|
};
|
||||||
@@ -34,6 +34,7 @@ const tbl = await db.createTable('myVectors', data)
|
|||||||
-->
|
-->
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|
||||||
```py
|
```py
|
||||||
result = (
|
result = (
|
||||||
tbl.search([0.5, 0.2])
|
tbl.search([0.5, 0.2])
|
||||||
@@ -44,12 +45,9 @@ const tbl = await db.createTable('myVectors', data)
|
|||||||
```
|
```
|
||||||
|
|
||||||
=== "JavaScript"
|
=== "JavaScript"
|
||||||
|
|
||||||
```javascript
|
```javascript
|
||||||
let result = await tbl.search(Array(1536).fill(0.5))
|
--8<-- "docs/src/sql_legacy.ts:search"
|
||||||
.limit(1)
|
|
||||||
.filter("id = 10")
|
|
||||||
.prefilter(true)
|
|
||||||
.execute()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## SQL filters
|
## SQL filters
|
||||||
@@ -60,14 +58,14 @@ It can be used during vector search, update, and deletion operations.
|
|||||||
|
|
||||||
Currently, Lance supports a growing list of SQL expressions.
|
Currently, Lance supports a growing list of SQL expressions.
|
||||||
|
|
||||||
* ``>``, ``>=``, ``<``, ``<=``, ``=``
|
- `>`, `>=`, `<`, `<=`, `=`
|
||||||
* ``AND``, ``OR``, ``NOT``
|
- `AND`, `OR`, `NOT`
|
||||||
* ``IS NULL``, ``IS NOT NULL``
|
- `IS NULL`, `IS NOT NULL`
|
||||||
* ``IS TRUE``, ``IS NOT TRUE``, ``IS FALSE``, ``IS NOT FALSE``
|
- `IS TRUE`, `IS NOT TRUE`, `IS FALSE`, `IS NOT FALSE`
|
||||||
* ``IN``
|
- `IN`
|
||||||
* ``LIKE``, ``NOT LIKE``
|
- `LIKE`, `NOT LIKE`
|
||||||
* ``CAST``
|
- `CAST`
|
||||||
* ``regexp_match(column, pattern)``
|
- `regexp_match(column, pattern)`
|
||||||
|
|
||||||
For example, the following filter string is acceptable:
|
For example, the following filter string is acceptable:
|
||||||
|
|
||||||
@@ -82,29 +80,27 @@ For example, the following filter string is acceptable:
|
|||||||
=== "Javascript"
|
=== "Javascript"
|
||||||
|
|
||||||
```javascript
|
```javascript
|
||||||
await tbl.search(Array(1536).fill(0))
|
--8<-- "docs/src/sql_legacy.ts:vec_search"
|
||||||
.where("(item IN ('item 0', 'item 2')) AND (id > 10)")
|
|
||||||
.execute()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
If your column name contains special characters or is a [SQL Keyword](https://docs.rs/sqlparser/latest/sqlparser/keywords/index.html),
|
If your column name contains special characters or is a [SQL Keyword](https://docs.rs/sqlparser/latest/sqlparser/keywords/index.html),
|
||||||
you can use backtick (`` ` ``) to escape it. For nested fields, each segment of the
|
you can use backtick (`` ` ``) to escape it. For nested fields, each segment of the
|
||||||
path must be wrapped in backticks.
|
path must be wrapped in backticks.
|
||||||
|
|
||||||
=== "SQL"
|
=== "SQL"
|
||||||
|
|
||||||
```sql
|
```sql
|
||||||
`CUBE` = 10 AND `column name with space` IS NOT NULL
|
`CUBE` = 10 AND `column name with space` IS NOT NULL
|
||||||
AND `nested with space`.`inner with space` < 2
|
AND `nested with space`.`inner with space` < 2
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! warning
|
!!!warning "Field names containing periods (`.`) are not supported."
|
||||||
Field names containing periods (``.``) are not supported.
|
|
||||||
|
|
||||||
Literals for dates, timestamps, and decimals can be written by writing the string
|
Literals for dates, timestamps, and decimals can be written by writing the string
|
||||||
value after the type name. For example
|
value after the type name. For example
|
||||||
|
|
||||||
=== "SQL"
|
=== "SQL"
|
||||||
|
|
||||||
```sql
|
```sql
|
||||||
date_col = date '2021-01-01'
|
date_col = date '2021-01-01'
|
||||||
and timestamp_col = timestamp '2021-01-01 00:00:00'
|
and timestamp_col = timestamp '2021-01-01 00:00:00'
|
||||||
@@ -114,49 +110,47 @@ value after the type name. For example
|
|||||||
For timestamp columns, the precision can be specified as a number in the type
|
For timestamp columns, the precision can be specified as a number in the type
|
||||||
parameter. Microsecond precision (6) is the default.
|
parameter. Microsecond precision (6) is the default.
|
||||||
|
|
||||||
| SQL | Time unit |
|
| SQL | Time unit |
|
||||||
|------------------|--------------|
|
| -------------- | ------------ |
|
||||||
| ``timestamp(0)`` | Seconds |
|
| `timestamp(0)` | Seconds |
|
||||||
| ``timestamp(3)`` | Milliseconds |
|
| `timestamp(3)` | Milliseconds |
|
||||||
| ``timestamp(6)`` | Microseconds |
|
| `timestamp(6)` | Microseconds |
|
||||||
| ``timestamp(9)`` | Nanoseconds |
|
| `timestamp(9)` | Nanoseconds |
|
||||||
|
|
||||||
LanceDB internally stores data in [Apache Arrow](https://arrow.apache.org/) format.
|
LanceDB internally stores data in [Apache Arrow](https://arrow.apache.org/) format.
|
||||||
The mapping from SQL types to Arrow types is:
|
The mapping from SQL types to Arrow types is:
|
||||||
|
|
||||||
| SQL type | Arrow type |
|
| SQL type | Arrow type |
|
||||||
|----------|------------|
|
| --------------------------------------------------------- | ------------------ |
|
||||||
| ``boolean`` | ``Boolean`` |
|
| `boolean` | `Boolean` |
|
||||||
| ``tinyint`` / ``tinyint unsigned`` | ``Int8`` / ``UInt8`` |
|
| `tinyint` / `tinyint unsigned` | `Int8` / `UInt8` |
|
||||||
| ``smallint`` / ``smallint unsigned`` | ``Int16`` / ``UInt16`` |
|
| `smallint` / `smallint unsigned` | `Int16` / `UInt16` |
|
||||||
| ``int`` or ``integer`` / ``int unsigned`` or ``integer unsigned`` | ``Int32`` / ``UInt32`` |
|
| `int` or `integer` / `int unsigned` or `integer unsigned` | `Int32` / `UInt32` |
|
||||||
| ``bigint`` / ``bigint unsigned`` | ``Int64`` / ``UInt64`` |
|
| `bigint` / `bigint unsigned` | `Int64` / `UInt64` |
|
||||||
| ``float`` | ``Float32`` |
|
| `float` | `Float32` |
|
||||||
| ``double`` | ``Float64`` |
|
| `double` | `Float64` |
|
||||||
| ``decimal(precision, scale)`` | ``Decimal128`` |
|
| `decimal(precision, scale)` | `Decimal128` |
|
||||||
| ``date`` | ``Date32`` |
|
| `date` | `Date32` |
|
||||||
| ``timestamp`` | ``Timestamp`` [^1] |
|
| `timestamp` | `Timestamp` [^1] |
|
||||||
| ``string`` | ``Utf8`` |
|
| `string` | `Utf8` |
|
||||||
| ``binary`` | ``Binary`` |
|
| `binary` | `Binary` |
|
||||||
|
|
||||||
[^1]: See precision mapping in previous table.
|
[^1]: See precision mapping in previous table.
|
||||||
|
|
||||||
|
|
||||||
## Filtering without Vector Search
|
## Filtering without Vector Search
|
||||||
|
|
||||||
You can also filter your data without search.
|
You can also filter your data without search.
|
||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
```python
|
|
||||||
tbl.search().where("id = 10").limit(10).to_arrow()
|
```python
|
||||||
```
|
tbl.search().where("id = 10").limit(10).to_arrow()
|
||||||
|
```
|
||||||
|
|
||||||
=== "JavaScript"
|
=== "JavaScript"
|
||||||
```javascript
|
|
||||||
await tbl.where('id = 10').limit(10).execute()
|
|
||||||
```
|
|
||||||
|
|
||||||
!!! warning
|
```javascript
|
||||||
If your table is large, this could potentially return a very large
|
--8<---- "docs/src/sql_legacy.ts:sql_search"
|
||||||
amount of data. Please be sure to use a `limit` clause unless
|
```
|
||||||
you're sure you want to return the whole result set.
|
|
||||||
|
!!!warning "If your table is large, this could potentially return a very large amount of data. Please be sure to use a `limit` clause unless you're sure you want to return the whole result set."
|
||||||
|
|||||||
38
docs/src/sql_legacy.ts
Normal file
38
docs/src/sql_legacy.ts
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import * as vectordb from "vectordb";
|
||||||
|
|
||||||
|
(async () => {
|
||||||
|
const db = await vectordb.connect("data/sample-lancedb");
|
||||||
|
|
||||||
|
let data = [];
|
||||||
|
for (let i = 0; i < 10_000; i++) {
|
||||||
|
data.push({
|
||||||
|
vector: Array(1536).fill(i),
|
||||||
|
id: i,
|
||||||
|
item: `item ${i}`,
|
||||||
|
strId: `${i}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
const tbl = await db.createTable("myVectors", data);
|
||||||
|
|
||||||
|
// --8<-- [start:search]
|
||||||
|
let result = await tbl
|
||||||
|
.search(Array(1536).fill(0.5))
|
||||||
|
.limit(1)
|
||||||
|
.filter("id = 10")
|
||||||
|
.prefilter(true)
|
||||||
|
.execute();
|
||||||
|
// --8<-- [end:search]
|
||||||
|
|
||||||
|
// --8<-- [start:vec_search]
|
||||||
|
await tbl
|
||||||
|
.search(Array(1536).fill(0))
|
||||||
|
.where("(item IN ('item 0', 'item 2')) AND (id > 10)")
|
||||||
|
.execute();
|
||||||
|
// --8<-- [end:vec_search]
|
||||||
|
|
||||||
|
// --8<-- [start:sql_search]
|
||||||
|
await tbl.filter("id = 10").limit(10).execute();
|
||||||
|
// --8<-- [end:sql_search]
|
||||||
|
|
||||||
|
console.log("SQL search: done");
|
||||||
|
})();
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
const glob = require("glob");
|
|
||||||
const fs = require("fs");
|
|
||||||
const path = require("path");
|
|
||||||
|
|
||||||
const globString = "../src/**/*.md";
|
|
||||||
|
|
||||||
const excludedGlobs = [
|
|
||||||
"../src/fts.md",
|
|
||||||
"../src/embedding.md",
|
|
||||||
"../src/examples/*.md",
|
|
||||||
"../src/guides/tables.md",
|
|
||||||
"../src/embeddings/*.md",
|
|
||||||
];
|
|
||||||
|
|
||||||
const nodePrefix = "javascript";
|
|
||||||
const nodeFile = ".js";
|
|
||||||
const nodeFolder = "node";
|
|
||||||
const asyncPrefix = "(async () => {\n";
|
|
||||||
const asyncSuffix = "})();";
|
|
||||||
|
|
||||||
function* yieldLines(lines, prefix, suffix) {
|
|
||||||
let inCodeBlock = false;
|
|
||||||
for (const line of lines) {
|
|
||||||
if (line.trim().startsWith(prefix + nodePrefix)) {
|
|
||||||
inCodeBlock = true;
|
|
||||||
} else if (inCodeBlock && line.trim().startsWith(suffix)) {
|
|
||||||
inCodeBlock = false;
|
|
||||||
yield "\n";
|
|
||||||
} else if (inCodeBlock) {
|
|
||||||
yield line;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const files = glob.sync(globString, { recursive: true });
|
|
||||||
const excludedFiles = glob.sync(excludedGlobs, { recursive: true });
|
|
||||||
|
|
||||||
for (const file of files.filter((file) => !excludedFiles.includes(file))) {
|
|
||||||
const lines = [];
|
|
||||||
const data = fs.readFileSync(file, "utf-8");
|
|
||||||
const fileLines = data.split("\n");
|
|
||||||
|
|
||||||
for (const line of yieldLines(fileLines, "```", "```")) {
|
|
||||||
lines.push(line);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lines.length > 0) {
|
|
||||||
const fileName = path.basename(file, ".md");
|
|
||||||
const outPath = path.join(nodeFolder, fileName, `${fileName}${nodeFile}`);
|
|
||||||
console.log(outPath)
|
|
||||||
fs.mkdirSync(path.dirname(outPath), { recursive: true });
|
|
||||||
fs.writeFileSync(outPath, asyncPrefix + "\n" + lines.join("\n") + asyncSuffix);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -14,6 +14,7 @@ excluded_globs = [
|
|||||||
"../src/concepts/*.md",
|
"../src/concepts/*.md",
|
||||||
"../src/ann_indexes.md",
|
"../src/ann_indexes.md",
|
||||||
"../src/basic.md",
|
"../src/basic.md",
|
||||||
|
"../src/hybrid_search.md",
|
||||||
]
|
]
|
||||||
|
|
||||||
python_prefix = "py"
|
python_prefix = "py"
|
||||||
@@ -48,6 +49,7 @@ def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
|
|||||||
if not skip_test:
|
if not skip_test:
|
||||||
yield line[strip_length:]
|
yield line[strip_length:]
|
||||||
|
|
||||||
|
|
||||||
for file in filter(lambda file: file not in excluded_files, files):
|
for file in filter(lambda file: file not in excluded_files, files):
|
||||||
with open(file, "r") as f:
|
with open(file, "r") as f:
|
||||||
lines = list(yield_lines(iter(f), "```", "```"))
|
lines = list(yield_lines(iter(f), "```", "```"))
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "lancedb-docs-test",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": "",
|
|
||||||
"author": "",
|
|
||||||
"license": "ISC",
|
|
||||||
"dependencies": {
|
|
||||||
"fs": "^0.0.1-security",
|
|
||||||
"glob": "^10.2.7",
|
|
||||||
"path": "^0.12.7",
|
|
||||||
"vectordb": "https://gitpkg.now.sh/lancedb/lancedb/node?main"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
17
docs/tsconfig.json
Normal file
17
docs/tsconfig.json
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"include": [
|
||||||
|
"src/*.ts",
|
||||||
|
],
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "es2022",
|
||||||
|
"module": "nodenext",
|
||||||
|
"declaration": true,
|
||||||
|
"outDir": "./dist",
|
||||||
|
"strict": true,
|
||||||
|
"allowJs": true,
|
||||||
|
"resolveJsonModule": true,
|
||||||
|
},
|
||||||
|
"exclude": [
|
||||||
|
"./dist/*",
|
||||||
|
]
|
||||||
|
}
|
||||||
74
node/package-lock.json
generated
74
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.4.4",
|
"version": "0.4.8",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.4.4",
|
"version": "0.4.8",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
@@ -53,11 +53,11 @@
|
|||||||
"uuid": "^9.0.0"
|
"uuid": "^9.0.0"
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.4.4",
|
"@lancedb/vectordb-darwin-arm64": "0.4.8",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.4.4",
|
"@lancedb/vectordb-darwin-x64": "0.4.8",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.4.4",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.4.8",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.4.4",
|
"@lancedb/vectordb-linux-x64-gnu": "0.4.8",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.4.4"
|
"@lancedb/vectordb-win32-x64-msvc": "0.4.8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@75lb/deep-merge": {
|
"node_modules/@75lb/deep-merge": {
|
||||||
@@ -328,6 +328,66 @@
|
|||||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||||
|
"version": "0.4.8",
|
||||||
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.8.tgz",
|
||||||
|
"integrity": "sha512-FpnJaw7KmNdD/FtOw9AcmPL5P+L04AcnfPj9ZyEjN8iCwB/qaOGYgdfBv+EbEtfHIsqA12q/1BRduu9KdB6BIA==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"darwin"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||||
|
"version": "0.4.8",
|
||||||
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.8.tgz",
|
||||||
|
"integrity": "sha512-RafOEYyZIgphp8wPGuVLFaTc8aAqo0NCO1LQMx0mB0xV96vrdo0Mooivs+dYN3RFfSHtTKPw9O1Jc957Vp1TLg==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"darwin"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||||
|
"version": "0.4.8",
|
||||||
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.8.tgz",
|
||||||
|
"integrity": "sha512-WlbYNfj4+v1hBHUluF+hnlG/A0ZaQFdXBTGDfHQniL11o+n3emWm4ujP5nSAoQHXjSH9DaOTGr/N4Mc9Xe+luw==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||||
|
"version": "0.4.8",
|
||||||
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.8.tgz",
|
||||||
|
"integrity": "sha512-z+qFJrDqnNEv4JcwYDyt51PHmWjuM/XaOlSjpBnyyuUImeY+QcwctMuyXt8+Q4zhuqQR1AhLKrMwCU+YmMfk5g==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||||
|
"version": "0.4.8",
|
||||||
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.8.tgz",
|
||||||
|
"integrity": "sha512-VjUryVvEA04r0j4lU9pJy84cmjuQm1GhBzbPc8kwbn5voT4A6BPglrlNsU0Zc+j8Fbjyvauzw2lMEcMsF4F0rw==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"win32"
|
||||||
|
]
|
||||||
|
},
|
||||||
"node_modules/@neon-rs/cli": {
|
"node_modules/@neon-rs/cli": {
|
||||||
"version": "0.0.160",
|
"version": "0.0.160",
|
||||||
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.4.4",
|
"version": "0.4.8",
|
||||||
"description": " Serverless, low-latency vector database for AI applications",
|
"description": " Serverless, low-latency vector database for AI applications",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"types": "dist/index.d.ts",
|
"types": "dist/index.d.ts",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"tsc": "tsc -b",
|
"tsc": "tsc -b",
|
||||||
"build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json && tsc -b",
|
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
|
||||||
"build-release": "npm run build -- --release",
|
"build-release": "npm run build -- --release",
|
||||||
"test": "npm run tsc && mocha -recursive dist/test",
|
"test": "npm run tsc && mocha -recursive dist/test",
|
||||||
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
||||||
@@ -17,7 +17,11 @@
|
|||||||
},
|
},
|
||||||
"repository": {
|
"repository": {
|
||||||
"type": "git",
|
"type": "git",
|
||||||
"url": "https://github.com/lancedb/lancedb/node"
|
"url": "https://github.com/lancedb/lancedb.git"
|
||||||
|
},
|
||||||
|
"homepage": "https://lancedb.github.io/lancedb/",
|
||||||
|
"bugs": {
|
||||||
|
"url": "https://github.com/lancedb/lancedb/issues"
|
||||||
},
|
},
|
||||||
"keywords": [
|
"keywords": [
|
||||||
"data-format",
|
"data-format",
|
||||||
@@ -81,10 +85,10 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.4.4",
|
"@lancedb/vectordb-darwin-arm64": "0.4.8",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.4.4",
|
"@lancedb/vectordb-darwin-x64": "0.4.8",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.4.4",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.4.8",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.4.4",
|
"@lancedb/vectordb-linux-x64-gnu": "0.4.8",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.4.4"
|
"@lancedb/vectordb-win32-x64-msvc": "0.4.8"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ const {
|
|||||||
tableCountRows,
|
tableCountRows,
|
||||||
tableDelete,
|
tableDelete,
|
||||||
tableUpdate,
|
tableUpdate,
|
||||||
|
tableMergeInsert,
|
||||||
tableCleanupOldVersions,
|
tableCleanupOldVersions,
|
||||||
tableCompactFiles,
|
tableCompactFiles,
|
||||||
tableListIndices,
|
tableListIndices,
|
||||||
@@ -163,6 +164,7 @@ export async function connect (
|
|||||||
{
|
{
|
||||||
uri: '',
|
uri: '',
|
||||||
awsCredentials: undefined,
|
awsCredentials: undefined,
|
||||||
|
awsRegion: defaultAwsRegion,
|
||||||
apiKey: undefined,
|
apiKey: undefined,
|
||||||
region: defaultAwsRegion
|
region: defaultAwsRegion
|
||||||
},
|
},
|
||||||
@@ -174,7 +176,13 @@ export async function connect (
|
|||||||
// Remote connection
|
// Remote connection
|
||||||
return new RemoteConnection(opts)
|
return new RemoteConnection(opts)
|
||||||
}
|
}
|
||||||
const db = await databaseNew(opts.uri)
|
const db = await databaseNew(
|
||||||
|
opts.uri,
|
||||||
|
opts.awsCredentials?.accessKeyId,
|
||||||
|
opts.awsCredentials?.secretKey,
|
||||||
|
opts.awsCredentials?.sessionToken,
|
||||||
|
opts.awsRegion
|
||||||
|
)
|
||||||
return new LocalConnection(db, opts)
|
return new LocalConnection(db, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -364,7 +372,7 @@ export interface Table<T = number[]> {
|
|||||||
/**
|
/**
|
||||||
* Returns the number of rows in this table.
|
* Returns the number of rows in this table.
|
||||||
*/
|
*/
|
||||||
countRows: () => Promise<number>
|
countRows: (filter?: string) => Promise<number>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Delete rows from this table.
|
* Delete rows from this table.
|
||||||
@@ -433,6 +441,38 @@ export interface Table<T = number[]> {
|
|||||||
*/
|
*/
|
||||||
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
|
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs a "merge insert" operation on the table
|
||||||
|
*
|
||||||
|
* This operation can add rows, update rows, and remove rows all in a single
|
||||||
|
* transaction. It is a very generic tool that can be used to create
|
||||||
|
* behaviors like "insert if not exists", "update or insert (i.e. upsert)",
|
||||||
|
* or even replace a portion of existing data with new data (e.g. replace
|
||||||
|
* all data where month="january")
|
||||||
|
*
|
||||||
|
* The merge insert operation works by combining new data from a
|
||||||
|
* **source table** with existing data in a **target table** by using a
|
||||||
|
* join. There are three categories of records.
|
||||||
|
*
|
||||||
|
* "Matched" records are records that exist in both the source table and
|
||||||
|
* the target table. "Not matched" records exist only in the source table
|
||||||
|
* (e.g. these are new data) "Not matched by source" records exist only
|
||||||
|
* in the target table (this is old data)
|
||||||
|
*
|
||||||
|
* The MergeInsertArgs can be used to customize what should happen for
|
||||||
|
* each category of data.
|
||||||
|
*
|
||||||
|
* Please note that the data may appear to be reordered as part of this
|
||||||
|
* operation. This is because updated rows will be deleted from the
|
||||||
|
* dataset and then reinserted at the end with the new values.
|
||||||
|
*
|
||||||
|
* @param on a column to join on. This is how records from the source
|
||||||
|
* table and target table are matched.
|
||||||
|
* @param data the new data to insert
|
||||||
|
* @param args parameters controlling how the operation should behave
|
||||||
|
*/
|
||||||
|
mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List the indicies on this table.
|
* List the indicies on this table.
|
||||||
*/
|
*/
|
||||||
@@ -443,6 +483,8 @@ export interface Table<T = number[]> {
|
|||||||
*/
|
*/
|
||||||
indexStats: (indexUuid: string) => Promise<IndexStats>
|
indexStats: (indexUuid: string) => Promise<IndexStats>
|
||||||
|
|
||||||
|
filter(value: string): Query<T>
|
||||||
|
|
||||||
schema: Promise<Schema>
|
schema: Promise<Schema>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -474,6 +516,47 @@ export interface UpdateSqlArgs {
|
|||||||
valuesSql: Record<string, string>
|
valuesSql: Record<string, string>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface MergeInsertArgs {
|
||||||
|
/**
|
||||||
|
* If true then rows that exist in both the source table (new data) and
|
||||||
|
* the target table (old data) will be updated, replacing the old row
|
||||||
|
* with the corresponding matching row.
|
||||||
|
*
|
||||||
|
* If there are multiple matches then the behavior is undefined.
|
||||||
|
* Currently this causes multiple copies of the row to be created
|
||||||
|
* but that behavior is subject to change.
|
||||||
|
*
|
||||||
|
* Optionally, a filter can be specified. This should be an SQL
|
||||||
|
* filter where fields with the prefix "target." refer to fields
|
||||||
|
* in the target table (old data) and fields with the prefix
|
||||||
|
* "source." refer to fields in the source table (new data). For
|
||||||
|
* example, the filter "target.lastUpdated < source.lastUpdated" will
|
||||||
|
* only update matched rows when the incoming `lastUpdated` value is
|
||||||
|
* newer.
|
||||||
|
*
|
||||||
|
* Rows that do not match the filter will not be updated. Rows that
|
||||||
|
* do not match the filter do become "not matched" rows.
|
||||||
|
*/
|
||||||
|
whenMatchedUpdateAll?: string | boolean
|
||||||
|
/**
|
||||||
|
* If true then rows that exist only in the source table (new data)
|
||||||
|
* will be inserted into the target table.
|
||||||
|
*/
|
||||||
|
whenNotMatchedInsertAll?: boolean
|
||||||
|
/**
|
||||||
|
* If true then rows that exist only in the target table (old data)
|
||||||
|
* will be deleted.
|
||||||
|
*
|
||||||
|
* If this is a string then it will be treated as an SQL filter and
|
||||||
|
* only rows that both do not match any row in the source table and
|
||||||
|
* match the given filter will be deleted.
|
||||||
|
*
|
||||||
|
* This can be used to replace a selection of existing data with
|
||||||
|
* new data.
|
||||||
|
*/
|
||||||
|
whenNotMatchedBySourceDelete?: string | boolean
|
||||||
|
}
|
||||||
|
|
||||||
export interface VectorIndex {
|
export interface VectorIndex {
|
||||||
columns: string[]
|
columns: string[]
|
||||||
name: string
|
name: string
|
||||||
@@ -768,8 +851,8 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
/**
|
/**
|
||||||
* Returns the number of rows in this table.
|
* Returns the number of rows in this table.
|
||||||
*/
|
*/
|
||||||
async countRows (): Promise<number> {
|
async countRows (filter?: string): Promise<number> {
|
||||||
return tableCountRows.call(this._tbl)
|
return tableCountRows.call(this._tbl, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -812,6 +895,46 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
|
||||||
|
let whenMatchedUpdateAll = false
|
||||||
|
let whenMatchedUpdateAllFilt = null
|
||||||
|
if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
|
||||||
|
whenMatchedUpdateAll = true
|
||||||
|
if (args.whenMatchedUpdateAll !== true) {
|
||||||
|
whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
|
||||||
|
let whenNotMatchedBySourceDelete = false
|
||||||
|
let whenNotMatchedBySourceDeleteFilt = null
|
||||||
|
if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) {
|
||||||
|
whenNotMatchedBySourceDelete = true
|
||||||
|
if (args.whenNotMatchedBySourceDelete !== true) {
|
||||||
|
whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const schema = await this.schema
|
||||||
|
let tbl: ArrowTable
|
||||||
|
if (data instanceof ArrowTable) {
|
||||||
|
tbl = data
|
||||||
|
} else {
|
||||||
|
tbl = makeArrowTable(data, { schema })
|
||||||
|
}
|
||||||
|
const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)
|
||||||
|
|
||||||
|
this._tbl = await tableMergeInsert.call(
|
||||||
|
this._tbl,
|
||||||
|
on,
|
||||||
|
whenMatchedUpdateAll,
|
||||||
|
whenMatchedUpdateAllFilt,
|
||||||
|
whenNotMatchedInsertAll,
|
||||||
|
whenNotMatchedBySourceDelete,
|
||||||
|
whenNotMatchedBySourceDeleteFilt,
|
||||||
|
buffer
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clean up old versions of the table, freeing disk space.
|
* Clean up old versions of the table, freeing disk space.
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ import {
|
|||||||
type IndexStats,
|
type IndexStats,
|
||||||
type UpdateArgs,
|
type UpdateArgs,
|
||||||
type UpdateSqlArgs,
|
type UpdateSqlArgs,
|
||||||
makeArrowTable
|
makeArrowTable,
|
||||||
|
type MergeInsertArgs
|
||||||
} from '../index'
|
} from '../index'
|
||||||
import { Query } from '../query'
|
import { Query } from '../query'
|
||||||
|
|
||||||
@@ -270,6 +271,59 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
|||||||
return new RemoteQuery(query, this._client, this._name) //, this._embeddings_new)
|
return new RemoteQuery(query, this._client, this._name) //, this._embeddings_new)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
filter (where: string): Query<T> {
|
||||||
|
throw new Error('Not implemented')
|
||||||
|
}
|
||||||
|
|
||||||
|
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
|
||||||
|
let tbl: ArrowTable
|
||||||
|
if (data instanceof ArrowTable) {
|
||||||
|
tbl = data
|
||||||
|
} else {
|
||||||
|
tbl = makeArrowTable(data, await this.schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
const queryParams: any = {
|
||||||
|
on
|
||||||
|
}
|
||||||
|
if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
|
||||||
|
queryParams.when_matched_update_all = 'true'
|
||||||
|
if (typeof args.whenMatchedUpdateAll === 'string') {
|
||||||
|
queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
queryParams.when_matched_update_all = 'false'
|
||||||
|
}
|
||||||
|
if (args.whenNotMatchedInsertAll ?? false) {
|
||||||
|
queryParams.when_not_matched_insert_all = 'true'
|
||||||
|
} else {
|
||||||
|
queryParams.when_not_matched_insert_all = 'false'
|
||||||
|
}
|
||||||
|
if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) {
|
||||||
|
queryParams.when_not_matched_by_source_delete = 'true'
|
||||||
|
if (typeof args.whenNotMatchedBySourceDelete === 'string') {
|
||||||
|
queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
queryParams.when_not_matched_by_source_delete = 'false'
|
||||||
|
}
|
||||||
|
|
||||||
|
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
|
||||||
|
const res = await this._client.post(
|
||||||
|
`/v1/table/${this._name}/merge_insert/`,
|
||||||
|
buffer,
|
||||||
|
queryParams,
|
||||||
|
'application/vnd.apache.arrow.stream'
|
||||||
|
)
|
||||||
|
if (res.status !== 200) {
|
||||||
|
throw new Error(
|
||||||
|
`Server Error, status: ${res.status}, ` +
|
||||||
|
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||||
|
`message: ${res.statusText}: ${res.data}`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
|
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
|
||||||
let tbl: ArrowTable
|
let tbl: ArrowTable
|
||||||
if (data instanceof ArrowTable) {
|
if (data instanceof ArrowTable) {
|
||||||
|
|||||||
@@ -294,6 +294,7 @@ describe('LanceDB client', function () {
|
|||||||
})
|
})
|
||||||
assert.equal(table.name, 'vectors')
|
assert.equal(table.name, 'vectors')
|
||||||
assert.equal(await table.countRows(), 10)
|
assert.equal(await table.countRows(), 10)
|
||||||
|
assert.equal(await table.countRows('vector IS NULL'), 0)
|
||||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -369,6 +370,7 @@ describe('LanceDB client', function () {
|
|||||||
const table = await con.createTable('f16', data)
|
const table = await con.createTable('f16', data)
|
||||||
assert.equal(table.name, 'f16')
|
assert.equal(table.name, 'f16')
|
||||||
assert.equal(await table.countRows(), total)
|
assert.equal(await table.countRows(), total)
|
||||||
|
assert.equal(await table.countRows('id < 5'), 5)
|
||||||
assert.deepEqual(await con.tableNames(), ['f16'])
|
assert.deepEqual(await con.tableNames(), ['f16'])
|
||||||
assert.deepEqual(await table.schema, schema)
|
assert.deepEqual(await table.schema, schema)
|
||||||
|
|
||||||
@@ -391,24 +393,6 @@ describe('LanceDB client', function () {
|
|||||||
})
|
})
|
||||||
}).timeout(120000)
|
}).timeout(120000)
|
||||||
|
|
||||||
it('fails to create a new table when the vector column is missing', async function () {
|
|
||||||
const dir = await track().mkdir('lancejs')
|
|
||||||
const con = await lancedb.connect(dir)
|
|
||||||
|
|
||||||
const data = [
|
|
||||||
{
|
|
||||||
id: 1,
|
|
||||||
price: 10
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
const create = con.createTable('missing_vector', data)
|
|
||||||
await expect(create).to.be.rejectedWith(
|
|
||||||
Error,
|
|
||||||
"column 'vector' is missing"
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('use overwrite flag to overwrite existing table', async function () {
|
it('use overwrite flag to overwrite existing table', async function () {
|
||||||
const dir = await track().mkdir('lancejs')
|
const dir = await track().mkdir('lancejs')
|
||||||
const con = await lancedb.connect(dir)
|
const con = await lancedb.connect(dir)
|
||||||
@@ -549,6 +533,54 @@ describe('LanceDB client', function () {
|
|||||||
assert.equal(await table.countRows(), 2)
|
assert.equal(await table.countRows(), 2)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('can merge insert records into the table', async function () {
|
||||||
|
const dir = await track().mkdir('lancejs')
|
||||||
|
const con = await lancedb.connect(dir)
|
||||||
|
|
||||||
|
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
|
||||||
|
const table = await con.createTable('my_table', data)
|
||||||
|
|
||||||
|
// insert if not exists
|
||||||
|
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
|
||||||
|
await table.mergeInsert('id', newData, {
|
||||||
|
whenNotMatchedInsertAll: true
|
||||||
|
})
|
||||||
|
assert.equal(await table.countRows(), 3)
|
||||||
|
assert.equal(await table.countRows('age = 2'), 1)
|
||||||
|
|
||||||
|
// conditional update
|
||||||
|
newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
|
||||||
|
await table.mergeInsert('id', newData, {
|
||||||
|
whenMatchedUpdateAll: 'target.age = 1'
|
||||||
|
})
|
||||||
|
assert.equal(await table.countRows(), 3)
|
||||||
|
assert.equal(await table.countRows('age = 1'), 1)
|
||||||
|
assert.equal(await table.countRows('age = 3'), 1)
|
||||||
|
|
||||||
|
newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
|
||||||
|
await table.mergeInsert('id', newData, {
|
||||||
|
whenNotMatchedInsertAll: true,
|
||||||
|
whenMatchedUpdateAll: true
|
||||||
|
})
|
||||||
|
assert.equal(await table.countRows(), 4)
|
||||||
|
assert.equal((await table.filter('age = 4').execute()).length, 2)
|
||||||
|
|
||||||
|
newData = [{ id: 5, age: 5 }]
|
||||||
|
await table.mergeInsert('id', newData, {
|
||||||
|
whenNotMatchedInsertAll: true,
|
||||||
|
whenMatchedUpdateAll: true,
|
||||||
|
whenNotMatchedBySourceDelete: 'age < 4'
|
||||||
|
})
|
||||||
|
assert.equal(await table.countRows(), 3)
|
||||||
|
|
||||||
|
await table.mergeInsert('id', newData, {
|
||||||
|
whenNotMatchedInsertAll: true,
|
||||||
|
whenMatchedUpdateAll: true,
|
||||||
|
whenNotMatchedBySourceDelete: true
|
||||||
|
})
|
||||||
|
assert.equal(await table.countRows(), 1)
|
||||||
|
})
|
||||||
|
|
||||||
it('can update records in the table', async function () {
|
it('can update records in the table', async function () {
|
||||||
const uri = await createTestDB()
|
const uri = await createTestDB()
|
||||||
const con = await lancedb.connect(uri)
|
const con = await lancedb.connect(uri)
|
||||||
|
|||||||
@@ -1,27 +1,27 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb-nodejs"
|
name = "vectordb-nodejs"
|
||||||
edition = "2021"
|
edition.workspace = true
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
description.workspace = true
|
||||||
repository.workspace = true
|
repository.workspace = true
|
||||||
|
keywords.workspace = true
|
||||||
|
categories.workspace = true
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
arrow-ipc.workspace = true
|
arrow-ipc.workspace = true
|
||||||
napi = { version = "2.14", default-features = false, features = [
|
futures.workspace = true
|
||||||
|
lance-linalg.workspace = true
|
||||||
|
lance.workspace = true
|
||||||
|
vectordb = { path = "../rust/vectordb" }
|
||||||
|
napi = { version = "2.15", default-features = false, features = [
|
||||||
"napi7",
|
"napi7",
|
||||||
"async"
|
"async"
|
||||||
] }
|
] }
|
||||||
napi-derive = "2.14"
|
napi-derive = "2"
|
||||||
vectordb = { path = "../rust/vectordb" }
|
|
||||||
lance.workspace = true
|
|
||||||
lance-linalg.workspace = true
|
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
napi-build = "2.1"
|
napi-build = "2.1"
|
||||||
|
|
||||||
[profile.release]
|
|
||||||
lto = true
|
|
||||||
strip = "symbols"
|
|
||||||
|
|||||||
@@ -53,6 +53,16 @@ describe("Test creating index", () => {
|
|||||||
const indexDir = path.join(tmpDir, "test.lance", "_indices");
|
const indexDir = path.join(tmpDir, "test.lance", "_indices");
|
||||||
expect(fs.readdirSync(indexDir)).toHaveLength(1);
|
expect(fs.readdirSync(indexDir)).toHaveLength(1);
|
||||||
// TODO: check index type.
|
// TODO: check index type.
|
||||||
|
|
||||||
|
// Search without specifying the column
|
||||||
|
let query_vector = data.toArray()[5].vec.toJSON();
|
||||||
|
let rst = await tbl.query().nearestTo(query_vector).limit(2).toArrow();
|
||||||
|
expect(rst.numRows).toBe(2);
|
||||||
|
|
||||||
|
// Search with specifying the column
|
||||||
|
let rst2 = await tbl.search(query_vector, "vec").limit(2).toArrow();
|
||||||
|
expect(rst2.numRows).toBe(2);
|
||||||
|
expect(rst.toString()).toEqual(rst2.toString());
|
||||||
});
|
});
|
||||||
|
|
||||||
test("no vector column available", async () => {
|
test("no vector column available", async () => {
|
||||||
@@ -71,6 +81,80 @@ describe("Test creating index", () => {
|
|||||||
await tbl.createIndex("val").build();
|
await tbl.createIndex("val").build();
|
||||||
const indexDir = path.join(tmpDir, "no_vec.lance", "_indices");
|
const indexDir = path.join(tmpDir, "no_vec.lance", "_indices");
|
||||||
expect(fs.readdirSync(indexDir)).toHaveLength(1);
|
expect(fs.readdirSync(indexDir)).toHaveLength(1);
|
||||||
|
|
||||||
|
for await (const r of tbl.query().filter("id > 1").select(["id"])) {
|
||||||
|
expect(r.numRows).toBe(1);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
test("two columns with different dimensions", async () => {
|
||||||
|
const db = await connect(tmpDir);
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field("id", new Int32(), true),
|
||||||
|
new Field("vec", new FixedSizeList(32, new Field("item", new Float32()))),
|
||||||
|
new Field(
|
||||||
|
"vec2",
|
||||||
|
new FixedSizeList(64, new Field("item", new Float32()))
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
const tbl = await db.createTable(
|
||||||
|
"two_vectors",
|
||||||
|
makeArrowTable(
|
||||||
|
Array(300)
|
||||||
|
.fill(1)
|
||||||
|
.map((_, i) => ({
|
||||||
|
id: i,
|
||||||
|
vec: Array(32)
|
||||||
|
.fill(1)
|
||||||
|
.map(() => Math.random()),
|
||||||
|
vec2: Array(64) // different dimension
|
||||||
|
.fill(1)
|
||||||
|
.map(() => Math.random()),
|
||||||
|
})),
|
||||||
|
{ schema }
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// Only build index over v1
|
||||||
|
await expect(tbl.createIndex().build()).rejects.toThrow(
|
||||||
|
/.*More than one vector columns found.*/
|
||||||
|
);
|
||||||
|
tbl
|
||||||
|
.createIndex("vec")
|
||||||
|
.ivf_pq({ num_partitions: 2, num_sub_vectors: 2 })
|
||||||
|
.build();
|
||||||
|
|
||||||
|
const rst = await tbl
|
||||||
|
.query()
|
||||||
|
.nearestTo(
|
||||||
|
Array(32)
|
||||||
|
.fill(1)
|
||||||
|
.map(() => Math.random())
|
||||||
|
)
|
||||||
|
.limit(2)
|
||||||
|
.toArrow();
|
||||||
|
expect(rst.numRows).toBe(2);
|
||||||
|
|
||||||
|
// Search with specifying the column
|
||||||
|
await expect(
|
||||||
|
tbl
|
||||||
|
.search(
|
||||||
|
Array(64)
|
||||||
|
.fill(1)
|
||||||
|
.map(() => Math.random()),
|
||||||
|
"vec"
|
||||||
|
)
|
||||||
|
.limit(2)
|
||||||
|
.toArrow()
|
||||||
|
).rejects.toThrow(/.*does not match the dimension.*/);
|
||||||
|
|
||||||
|
const query64 = Array(64)
|
||||||
|
.fill(1)
|
||||||
|
.map(() => Math.random());
|
||||||
|
const rst64_1 = await tbl.query().nearestTo(query64).limit(2).toArrow();
|
||||||
|
const rst64_2 = await tbl.search(query64, "vec2").limit(2).toArrow();
|
||||||
|
expect(rst64_1.toString()).toEqual(rst64_2.toString());
|
||||||
|
expect(rst64_1.numRows).toBe(2);
|
||||||
});
|
});
|
||||||
|
|
||||||
test("create scalar index", async () => {
|
test("create scalar index", async () => {
|
||||||
|
|||||||
@@ -2,4 +2,6 @@
|
|||||||
module.exports = {
|
module.exports = {
|
||||||
preset: 'ts-jest',
|
preset: 'ts-jest',
|
||||||
testEnvironment: 'node',
|
testEnvironment: 'node',
|
||||||
|
moduleDirectories: ["node_modules", "./dist"],
|
||||||
|
moduleFileExtensions: ["js", "ts"],
|
||||||
};
|
};
|
||||||
@@ -91,7 +91,6 @@ impl IndexBuilder {
|
|||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub async fn build(&self) -> napi::Result<()> {
|
pub async fn build(&self) -> napi::Result<()> {
|
||||||
println!("nodejs::index.rs : build");
|
|
||||||
self.inner
|
self.inner
|
||||||
.build()
|
.build()
|
||||||
.await
|
.await
|
||||||
|
|||||||
47
nodejs/src/iterator.rs
Normal file
47
nodejs/src/iterator.rs
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
// Copyright 2024 Lance Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use futures::StreamExt;
|
||||||
|
use lance::io::RecordBatchStream;
|
||||||
|
use napi::bindgen_prelude::*;
|
||||||
|
use napi_derive::napi;
|
||||||
|
use vectordb::ipc::batches_to_ipc_file;
|
||||||
|
|
||||||
|
/** Typescript-style Async Iterator over RecordBatches */
|
||||||
|
#[napi]
|
||||||
|
pub struct RecordBatchIterator {
|
||||||
|
inner: Box<dyn RecordBatchStream + Unpin>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
impl RecordBatchIterator {
|
||||||
|
pub(crate) fn new(inner: Box<dyn RecordBatchStream + Unpin>) -> Self {
|
||||||
|
Self { inner }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async unsafe fn next(&mut self) -> napi::Result<Option<Buffer>> {
|
||||||
|
if let Some(rst) = self.inner.next().await {
|
||||||
|
let batch = rst.map_err(|e| {
|
||||||
|
napi::Error::from_reason(format!("Failed to get next batch from stream: {}", e))
|
||||||
|
})?;
|
||||||
|
batches_to_ipc_file(&[batch])
|
||||||
|
.map_err(|e| napi::Error::from_reason(format!("Failed to write IPC file: {}", e)))
|
||||||
|
.map(|buf| Some(Buffer::from(buf)))
|
||||||
|
} else {
|
||||||
|
// We are done with the stream.
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@ use napi_derive::*;
|
|||||||
|
|
||||||
mod connection;
|
mod connection;
|
||||||
mod index;
|
mod index;
|
||||||
|
mod iterator;
|
||||||
mod query;
|
mod query;
|
||||||
mod table;
|
mod table;
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ use napi::bindgen_prelude::*;
|
|||||||
use napi_derive::napi;
|
use napi_derive::napi;
|
||||||
use vectordb::query::Query as LanceDBQuery;
|
use vectordb::query::Query as LanceDBQuery;
|
||||||
|
|
||||||
use crate::table::Table;
|
use crate::{iterator::RecordBatchIterator, table::Table};
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub struct Query {
|
pub struct Query {
|
||||||
@@ -32,17 +32,50 @@ impl Query {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn vector(&mut self, vector: Float32Array) {
|
pub fn column(&mut self, column: String) {
|
||||||
let inn = self.inner.clone().nearest_to(&vector);
|
self.inner = self.inner.clone().column(&column);
|
||||||
self.inner = inn;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn to_arrow(&self) -> napi::Result<()> {
|
pub fn filter(&mut self, filter: String) {
|
||||||
// let buf = self.inner.to_arrow().map_err(|e| {
|
self.inner = self.inner.clone().filter(filter);
|
||||||
// napi::Error::from_reason(format!("Failed to convert query to arrow: {}", e))
|
}
|
||||||
// })?;
|
|
||||||
// Ok(buf)
|
#[napi]
|
||||||
todo!()
|
pub fn select(&mut self, columns: Vec<String>) {
|
||||||
|
self.inner = self.inner.clone().select(&columns);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn limit(&mut self, limit: u32) {
|
||||||
|
self.inner = self.inner.clone().limit(limit as usize);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn prefilter(&mut self, prefilter: bool) {
|
||||||
|
self.inner = self.inner.clone().prefilter(prefilter);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn nearest_to(&mut self, vector: Float32Array) {
|
||||||
|
self.inner = self.inner.clone().nearest_to(&vector);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn refine_factor(&mut self, refine_factor: u32) {
|
||||||
|
self.inner = self.inner.clone().refine_factor(refine_factor);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub fn nprobes(&mut self, nprobe: u32) {
|
||||||
|
self.inner = self.inner.clone().nprobes(nprobe as usize);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn execute_stream(&self) -> napi::Result<RecordBatchIterator> {
|
||||||
|
let inner_stream = self.inner.execute_stream().await.map_err(|e| {
|
||||||
|
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
|
||||||
|
})?;
|
||||||
|
Ok(RecordBatchIterator::new(Box::new(inner_stream)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,8 +57,8 @@ impl Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub async fn count_rows(&self) -> napi::Result<usize> {
|
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
|
||||||
self.table.count_rows().await.map_err(|e| {
|
self.table.count_rows(filter).await.map_err(|e| {
|
||||||
napi::Error::from_reason(format!(
|
napi::Error::from_reason(format!(
|
||||||
"Failed to count rows in table {}: {}",
|
"Failed to count rows in table {}: {}",
|
||||||
self.table, e
|
self.table, e
|
||||||
|
|||||||
17
nodejs/vectordb/native.d.ts
vendored
17
nodejs/vectordb/native.d.ts
vendored
@@ -54,15 +54,26 @@ export class IndexBuilder {
|
|||||||
scalar(): void
|
scalar(): void
|
||||||
build(): Promise<void>
|
build(): Promise<void>
|
||||||
}
|
}
|
||||||
|
/** Typescript-style Async Iterator over RecordBatches */
|
||||||
|
export class RecordBatchIterator {
|
||||||
|
next(): Promise<Buffer | null>
|
||||||
|
}
|
||||||
export class Query {
|
export class Query {
|
||||||
vector(vector: Float32Array): void
|
column(column: string): void
|
||||||
toArrow(): void
|
filter(filter: string): void
|
||||||
|
select(columns: Array<string>): void
|
||||||
|
limit(limit: number): void
|
||||||
|
prefilter(prefilter: boolean): void
|
||||||
|
nearestTo(vector: Float32Array): void
|
||||||
|
refineFactor(refineFactor: number): void
|
||||||
|
nprobes(nprobe: number): void
|
||||||
|
executeStream(): Promise<RecordBatchIterator>
|
||||||
}
|
}
|
||||||
export class Table {
|
export class Table {
|
||||||
/** Return Schema as empty Arrow IPC file. */
|
/** Return Schema as empty Arrow IPC file. */
|
||||||
schema(): Buffer
|
schema(): Buffer
|
||||||
add(buf: Buffer): Promise<void>
|
add(buf: Buffer): Promise<void>
|
||||||
countRows(): Promise<bigint>
|
countRows(filter?: string): Promise<bigint>
|
||||||
delete(predicate: string): Promise<void>
|
delete(predicate: string): Promise<void>
|
||||||
createIndex(): IndexBuilder
|
createIndex(): IndexBuilder
|
||||||
query(): Query
|
query(): Query
|
||||||
|
|||||||
@@ -295,12 +295,13 @@ if (!nativeBinding) {
|
|||||||
throw new Error(`Failed to load native binding`)
|
throw new Error(`Failed to load native binding`)
|
||||||
}
|
}
|
||||||
|
|
||||||
const { Connection, IndexType, MetricType, IndexBuilder, Query, Table, WriteMode, connect } = nativeBinding
|
const { Connection, IndexType, MetricType, IndexBuilder, RecordBatchIterator, Query, Table, WriteMode, connect } = nativeBinding
|
||||||
|
|
||||||
module.exports.Connection = Connection
|
module.exports.Connection = Connection
|
||||||
module.exports.IndexType = IndexType
|
module.exports.IndexType = IndexType
|
||||||
module.exports.MetricType = MetricType
|
module.exports.MetricType = MetricType
|
||||||
module.exports.IndexBuilder = IndexBuilder
|
module.exports.IndexBuilder = IndexBuilder
|
||||||
|
module.exports.RecordBatchIterator = RecordBatchIterator
|
||||||
module.exports.Query = Query
|
module.exports.Query = Query
|
||||||
module.exports.Table = Table
|
module.exports.Table = Table
|
||||||
module.exports.WriteMode = WriteMode
|
module.exports.WriteMode = WriteMode
|
||||||
|
|||||||
@@ -12,46 +12,73 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { RecordBatch } from "apache-arrow";
|
import { RecordBatch, tableFromIPC, Table as ArrowTable } from "apache-arrow";
|
||||||
import { Table } from "./table";
|
import {
|
||||||
|
RecordBatchIterator as NativeBatchIterator,
|
||||||
|
Query as NativeQuery,
|
||||||
|
Table as NativeTable,
|
||||||
|
} from "./native";
|
||||||
|
|
||||||
// TODO: re-eanble eslint once we have a real implementation
|
|
||||||
/* eslint-disable */
|
|
||||||
class RecordBatchIterator implements AsyncIterator<RecordBatch> {
|
class RecordBatchIterator implements AsyncIterator<RecordBatch> {
|
||||||
next(
|
private promised_inner?: Promise<NativeBatchIterator>;
|
||||||
...args: [] | [undefined]
|
private inner?: NativeBatchIterator;
|
||||||
): Promise<IteratorResult<RecordBatch<any>, any>> {
|
|
||||||
throw new Error("Method not implemented.");
|
constructor(
|
||||||
|
inner?: NativeBatchIterator,
|
||||||
|
promise?: Promise<NativeBatchIterator>
|
||||||
|
) {
|
||||||
|
// TODO: check promise reliably so we dont need to pass two arguments.
|
||||||
|
this.inner = inner;
|
||||||
|
this.promised_inner = promise;
|
||||||
}
|
}
|
||||||
return?(value?: any): Promise<IteratorResult<RecordBatch<any>, any>> {
|
|
||||||
throw new Error("Method not implemented.");
|
async next(): Promise<IteratorResult<RecordBatch<any>, any>> {
|
||||||
}
|
if (this.inner === undefined) {
|
||||||
throw?(e?: any): Promise<IteratorResult<RecordBatch<any>, any>> {
|
this.inner = await this.promised_inner;
|
||||||
throw new Error("Method not implemented.");
|
}
|
||||||
|
if (this.inner === undefined) {
|
||||||
|
throw new Error("Invalid iterator state state");
|
||||||
|
}
|
||||||
|
const n = await this.inner.next();
|
||||||
|
if (n == null) {
|
||||||
|
return Promise.resolve({ done: true, value: null });
|
||||||
|
}
|
||||||
|
const tbl = tableFromIPC(n);
|
||||||
|
if (tbl.batches.length != 1) {
|
||||||
|
throw new Error("Expected only one batch");
|
||||||
|
}
|
||||||
|
return Promise.resolve({ done: false, value: tbl.batches[0] });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/* eslint-enable */
|
/* eslint-enable */
|
||||||
|
|
||||||
/** Query executor */
|
/** Query executor */
|
||||||
export class Query implements AsyncIterable<RecordBatch> {
|
export class Query implements AsyncIterable<RecordBatch> {
|
||||||
private readonly tbl: Table;
|
private readonly inner: NativeQuery;
|
||||||
private _filter?: string;
|
|
||||||
private _limit?: number;
|
|
||||||
|
|
||||||
// Vector search
|
constructor(tbl: NativeTable) {
|
||||||
private _vector?: Float32Array;
|
this.inner = tbl.query();
|
||||||
private _nprobes?: number;
|
}
|
||||||
private _refine_factor?: number = 1;
|
|
||||||
|
|
||||||
constructor(tbl: Table) {
|
/** Set the column to run query. */
|
||||||
this.tbl = tbl;
|
column(column: string): Query {
|
||||||
|
this.inner.column(column);
|
||||||
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Set the filter predicate, only returns the results that satisfy the filter.
|
/** Set the filter predicate, only returns the results that satisfy the filter.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
filter(predicate: string): Query {
|
filter(predicate: string): Query {
|
||||||
this._filter = predicate;
|
this.inner.filter(predicate);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Select the columns to return. If not set, all columns are returned.
|
||||||
|
*/
|
||||||
|
select(columns: string[]): Query {
|
||||||
|
this.inner.select(columns);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,35 +86,67 @@ export class Query implements AsyncIterable<RecordBatch> {
|
|||||||
* Set the limit of rows to return.
|
* Set the limit of rows to return.
|
||||||
*/
|
*/
|
||||||
limit(limit: number): Query {
|
limit(limit: number): Query {
|
||||||
this._limit = limit;
|
this.inner.limit(limit);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
prefilter(prefilter: boolean): Query {
|
||||||
|
this.inner.prefilter(prefilter);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the query vector.
|
* Set the query vector.
|
||||||
*/
|
*/
|
||||||
vector(vector: number[]): Query {
|
nearestTo(vector: number[]): Query {
|
||||||
this._vector = Float32Array.from(vector);
|
this.inner.nearestTo(Float32Array.from(vector));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the number of probes to use for the query.
|
* Set the number of IVF partitions to use for the query.
|
||||||
*/
|
*/
|
||||||
nprobes(nprobes: number): Query {
|
nprobes(nprobes: number): Query {
|
||||||
this._nprobes = nprobes;
|
this.inner.nprobes(nprobes);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the refine factor for the query.
|
* Set the refine factor for the query.
|
||||||
*/
|
*/
|
||||||
refine_factor(refine_factor: number): Query {
|
refineFactor(refine_factor: number): Query {
|
||||||
this._refine_factor = refine_factor;
|
this.inner.refineFactor(refine_factor);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
|
/**
|
||||||
throw new RecordBatchIterator();
|
* Execute the query and return the results as an AsyncIterator.
|
||||||
|
*/
|
||||||
|
async executeStream(): Promise<RecordBatchIterator> {
|
||||||
|
const inner = await this.inner.executeStream();
|
||||||
|
return new RecordBatchIterator(inner);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Collect the results as an Arrow Table. */
|
||||||
|
async toArrow(): Promise<ArrowTable> {
|
||||||
|
const batches = [];
|
||||||
|
for await (const batch of this) {
|
||||||
|
batches.push(batch);
|
||||||
|
}
|
||||||
|
return new ArrowTable(batches);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns a JSON Array of All results.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
async toArray(): Promise<any[]> {
|
||||||
|
const tbl = await this.toArrow();
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
|
||||||
|
return tbl.toArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>> {
|
||||||
|
const promise = this.inner.executeStream();
|
||||||
|
return new RecordBatchIterator(undefined, promise);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,8 +50,8 @@ export class Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Count the total number of rows in the dataset. */
|
/** Count the total number of rows in the dataset. */
|
||||||
async countRows(): Promise<bigint> {
|
async countRows(filter?: string): Promise<bigint> {
|
||||||
return await this.inner.countRows();
|
return await this.inner.countRows(filter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Delete the rows that satisfy the predicate. */
|
/** Delete the rows that satisfy the predicate. */
|
||||||
@@ -95,10 +95,58 @@ export class Table {
|
|||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
search(vector?: number[]): Query {
|
/**
|
||||||
const q = new Query(this);
|
* Create a generic {@link Query} Builder.
|
||||||
if (vector !== undefined) {
|
*
|
||||||
q.vector(vector);
|
* When appropriate, various indices and statistics based pruning will be used to
|
||||||
|
* accelerate the query.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
*
|
||||||
|
* ### Run a SQL-style query
|
||||||
|
* ```typescript
|
||||||
|
* for await (const batch of table.query()
|
||||||
|
* .filter("id > 1").select(["id"]).limit(20)) {
|
||||||
|
* console.log(batch);
|
||||||
|
* }
|
||||||
|
* ```
|
||||||
|
*
|
||||||
|
* ### Run Top-10 vector similarity search
|
||||||
|
* ```typescript
|
||||||
|
* for await (const batch of table.query()
|
||||||
|
* .nearestTo([1, 2, 3])
|
||||||
|
* .refineFactor(5).nprobe(10)
|
||||||
|
* .limit(10)) {
|
||||||
|
* console.log(batch);
|
||||||
|
* }
|
||||||
|
*```
|
||||||
|
*
|
||||||
|
* ### Scan the full dataset
|
||||||
|
* ```typescript
|
||||||
|
* for await (const batch of table.query()) {
|
||||||
|
* console.log(batch);
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* ### Return the full dataset as Arrow Table
|
||||||
|
* ```typescript
|
||||||
|
* let arrowTbl = await table.query().nearestTo([1.0, 2.0, 0.5, 6.7]).toArrow();
|
||||||
|
* ```
|
||||||
|
*
|
||||||
|
* @returns {@link Query}
|
||||||
|
*/
|
||||||
|
query(): Query {
|
||||||
|
return new Query(this.inner);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Search the table with a given query vector.
|
||||||
|
*
|
||||||
|
* This is a convenience method for preparing an ANN {@link Query}.
|
||||||
|
*/
|
||||||
|
search(vector: number[], column?: string): Query {
|
||||||
|
const q = this.query();
|
||||||
|
q.nearestTo(vector);
|
||||||
|
if (column !== undefined) {
|
||||||
|
q.column(column);
|
||||||
}
|
}
|
||||||
return q;
|
return q;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.5.1
|
current_version = 0.5.4
|
||||||
commit = True
|
commit = True
|
||||||
message = [python] Bump version: {current_version} → {new_version}
|
message = [python] Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
@@ -42,6 +42,12 @@ To run the unit tests:
|
|||||||
pytest
|
pytest
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To run the doc tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest --doctest-modules lancedb
|
||||||
|
```
|
||||||
|
|
||||||
To run linter and automatically fix all errors:
|
To run linter and automatically fix all errors:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
import os
|
import os
|
||||||
|
from datetime import timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
__version__ = importlib.metadata.version("lancedb")
|
__version__ = importlib.metadata.version("lancedb")
|
||||||
@@ -30,6 +31,7 @@ def connect(
|
|||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
region: str = "us-east-1",
|
region: str = "us-east-1",
|
||||||
host_override: Optional[str] = None,
|
host_override: Optional[str] = None,
|
||||||
|
read_consistency_interval: Optional[timedelta] = None,
|
||||||
) -> DBConnection:
|
) -> DBConnection:
|
||||||
"""Connect to a LanceDB database.
|
"""Connect to a LanceDB database.
|
||||||
|
|
||||||
@@ -45,6 +47,18 @@ def connect(
|
|||||||
The region to use for LanceDB Cloud.
|
The region to use for LanceDB Cloud.
|
||||||
host_override: str, optional
|
host_override: str, optional
|
||||||
The override url for LanceDB Cloud.
|
The override url for LanceDB Cloud.
|
||||||
|
read_consistency_interval: timedelta, default None
|
||||||
|
(For LanceDB OSS only)
|
||||||
|
The interval at which to check for updates to the table from other
|
||||||
|
processes. If None, then consistency is not checked. For performance
|
||||||
|
reasons, this is the default. For strong consistency, set this to
|
||||||
|
zero seconds. Then every read will check for updates from other
|
||||||
|
processes. As a compromise, you can set this to a non-zero timedelta
|
||||||
|
for eventual consistency. If more than that interval has passed since
|
||||||
|
the last check, then the table will be checked for updates. Note: this
|
||||||
|
consistency only applies to read operations. Write operations are
|
||||||
|
always consistent.
|
||||||
|
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
@@ -73,4 +87,4 @@ def connect(
|
|||||||
if api_key is None:
|
if api_key is None:
|
||||||
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
||||||
return RemoteDBConnection(uri, api_key, region, host_override)
|
return RemoteDBConnection(uri, api_key, region, host_override)
|
||||||
return LanceDBConnection(uri)
|
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ from .table import LanceTable, Table
|
|||||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
from .common import DATA, URI
|
from .common import DATA, URI
|
||||||
from .embeddings import EmbeddingFunctionConfig
|
from .embeddings import EmbeddingFunctionConfig
|
||||||
from .pydantic import LanceModel
|
from .pydantic import LanceModel
|
||||||
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
|
|||||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||||
>>> db.create_table("my_table", data)
|
>>> db.create_table("my_table", data)
|
||||||
LanceTable(my_table)
|
LanceTable(connection=..., name="my_table")
|
||||||
>>> db["my_table"].head()
|
>>> db["my_table"].head()
|
||||||
pyarrow.Table
|
pyarrow.Table
|
||||||
vector: fixed_size_list<item: float>[2]
|
vector: fixed_size_list<item: float>[2]
|
||||||
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
|
|||||||
... "long": [-122.7, -74.1]
|
... "long": [-122.7, -74.1]
|
||||||
... })
|
... })
|
||||||
>>> db.create_table("table2", data)
|
>>> db.create_table("table2", data)
|
||||||
LanceTable(table2)
|
LanceTable(connection=..., name="table2")
|
||||||
>>> db["table2"].head()
|
>>> db["table2"].head()
|
||||||
pyarrow.Table
|
pyarrow.Table
|
||||||
vector: fixed_size_list<item: float>[2]
|
vector: fixed_size_list<item: float>[2]
|
||||||
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
|
|||||||
... pa.field("long", pa.float32())
|
... pa.field("long", pa.float32())
|
||||||
... ])
|
... ])
|
||||||
>>> db.create_table("table3", data, schema = custom_schema)
|
>>> db.create_table("table3", data, schema = custom_schema)
|
||||||
LanceTable(table3)
|
LanceTable(connection=..., name="table3")
|
||||||
>>> db["table3"].head()
|
>>> db["table3"].head()
|
||||||
pyarrow.Table
|
pyarrow.Table
|
||||||
vector: fixed_size_list<item: float>[2]
|
vector: fixed_size_list<item: float>[2]
|
||||||
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
|
|||||||
... pa.field("price", pa.float32()),
|
... pa.field("price", pa.float32()),
|
||||||
... ])
|
... ])
|
||||||
>>> db.create_table("table4", make_batches(), schema=schema)
|
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||||
LanceTable(table4)
|
LanceTable(connection=..., name="table4")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
|
|||||||
----------
|
----------
|
||||||
uri: str or Path
|
uri: str or Path
|
||||||
The root uri of the database.
|
The root uri of the database.
|
||||||
|
read_consistency_interval: timedelta, default None
|
||||||
|
The interval at which to check for updates to the table from other
|
||||||
|
processes. If None, then consistency is not checked. For performance
|
||||||
|
reasons, this is the default. For strong consistency, set this to
|
||||||
|
zero seconds. Then every read will check for updates from other
|
||||||
|
processes. As a compromise, you can set this to a non-zero timedelta
|
||||||
|
for eventual consistency. If more than that interval has passed since
|
||||||
|
the last check, then the table will be checked for updates. Note: this
|
||||||
|
consistency only applies to read operations. Write operations are
|
||||||
|
always consistent.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
|
|||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
|
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
|
||||||
... {"vector": [0.5, 1.3], "b": 4}])
|
... {"vector": [0.5, 1.3], "b": 4}])
|
||||||
LanceTable(my_table)
|
LanceTable(connection=..., name="my_table")
|
||||||
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
|
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
|
||||||
LanceTable(another_table)
|
LanceTable(connection=..., name="another_table")
|
||||||
>>> sorted(db.table_names())
|
>>> sorted(db.table_names())
|
||||||
['another_table', 'my_table']
|
['another_table', 'my_table']
|
||||||
>>> len(db)
|
>>> len(db)
|
||||||
2
|
2
|
||||||
>>> db["my_table"]
|
>>> db["my_table"]
|
||||||
LanceTable(my_table)
|
LanceTable(connection=..., name="my_table")
|
||||||
>>> "my_table" in db
|
>>> "my_table" in db
|
||||||
True
|
True
|
||||||
>>> db.drop_table("my_table")
|
>>> db.drop_table("my_table")
|
||||||
>>> db.drop_table("another_table")
|
>>> db.drop_table("another_table")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, uri: URI):
|
def __init__(
|
||||||
|
self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
|
||||||
|
):
|
||||||
if not isinstance(uri, Path):
|
if not isinstance(uri, Path):
|
||||||
scheme = get_uri_scheme(uri)
|
scheme = get_uri_scheme(uri)
|
||||||
is_local = isinstance(uri, Path) or scheme == "file"
|
is_local = isinstance(uri, Path) or scheme == "file"
|
||||||
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
|
|||||||
self._uri = str(uri)
|
self._uri = str(uri)
|
||||||
|
|
||||||
self._entered = False
|
self._entered = False
|
||||||
|
self.read_consistency_interval = read_consistency_interval
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
val = f"{self.__class__.__name__}({self._uri}"
|
||||||
|
if self.read_consistency_interval is not None:
|
||||||
|
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
|
||||||
|
val += ")"
|
||||||
|
return val
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def uri(self) -> str:
|
def uri(self) -> str:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
# ruff: noqa: F401
|
# ruff: noqa: F401
|
||||||
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||||
|
from .bedrock import BedRockText
|
||||||
from .cohere import CohereEmbeddingFunction
|
from .cohere import CohereEmbeddingFunction
|
||||||
from .gemini_text import GeminiText
|
from .gemini_text import GeminiText
|
||||||
from .instructor import InstructorEmbeddingFunction
|
from .instructor import InstructorEmbeddingFunction
|
||||||
|
|||||||
223
python/lancedb/embeddings/bedrock.py
Normal file
223
python/lancedb/embeddings/bedrock.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lancedb.pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import TEXT
|
||||||
|
|
||||||
|
|
||||||
|
@register("bedrock-text")
|
||||||
|
class BedRockText(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name: str, default "amazon.titan-embed-text-v1"
|
||||||
|
The model ID of the bedrock model to use. Supported models for are:
|
||||||
|
- amazon.titan-embed-text-v1
|
||||||
|
- cohere.embed-english-v3
|
||||||
|
- cohere.embed-multilingual-v3
|
||||||
|
region: str, default "us-east-1"
|
||||||
|
Optional name of the AWS Region in which the service should be called.
|
||||||
|
profile_name: str, default None
|
||||||
|
Optional name of the AWS profile to use for calling the Bedrock service.
|
||||||
|
If not specified, the default profile will be used.
|
||||||
|
assumed_role: str, default None
|
||||||
|
Optional ARN of an AWS IAM role to assume for calling the Bedrock service.
|
||||||
|
If not specified, the current active credentials will be used.
|
||||||
|
role_session_name: str, default "lancedb-embeddings"
|
||||||
|
Optional name of the AWS IAM role session to use for calling the Bedrock
|
||||||
|
service. If not specified, "lancedb-embeddings" name will be used.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
import lancedb
|
||||||
|
import pandas as pd
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
|
model = get_registry().get("bedrock-text").create()
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
|
db = lancedb.connect("tmp_path")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
|
||||||
|
rs = tbl.search("hello").limit(1).to_pandas()
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "amazon.titan-embed-text-v1"
|
||||||
|
region: str = "us-east-1"
|
||||||
|
assumed_role: Union[str, None] = None
|
||||||
|
profile_name: Union[str, None] = None
|
||||||
|
role_session_name: str = "lancedb-embeddings"
|
||||||
|
|
||||||
|
if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
keep_untouched = (cached_property,)
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
# return len(self._generate_embedding("test"))
|
||||||
|
# TODO: fix hardcoding
|
||||||
|
if self.name == "amazon.titan-embed-text-v1":
|
||||||
|
return 1536
|
||||||
|
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}:
|
||||||
|
return 1024
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model name: {self.name}")
|
||||||
|
|
||||||
|
def compute_query_embeddings(
|
||||||
|
self, query: str, *args, **kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
return self.compute_source_embeddings(query)
|
||||||
|
|
||||||
|
def compute_source_embeddings(
|
||||||
|
self, texts: TEXT, *args, **kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
texts = self.sanitize_input(texts)
|
||||||
|
return self.generate_embeddings(texts)
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: list[str] or np.ndarray (of str)
|
||||||
|
The texts to embed
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[list[float]]
|
||||||
|
The embeddings for the given texts
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for text in texts:
|
||||||
|
response = self._generate_embedding(text)
|
||||||
|
results.append(response)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _generate_embedding(self, text: str) -> List[float]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: str
|
||||||
|
The texts to embed
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[float]
|
||||||
|
The embeddings for the given texts
|
||||||
|
"""
|
||||||
|
# format input body for provider
|
||||||
|
provider = self.name.split(".")[0]
|
||||||
|
_model_kwargs = {}
|
||||||
|
input_body = {**_model_kwargs}
|
||||||
|
if provider == "cohere":
|
||||||
|
if "input_type" not in input_body.keys():
|
||||||
|
input_body["input_type"] = "search_document"
|
||||||
|
input_body["texts"] = [text]
|
||||||
|
else:
|
||||||
|
# includes common provider == "amazon"
|
||||||
|
input_body["inputText"] = text
|
||||||
|
body = json.dumps(input_body)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# invoke bedrock API
|
||||||
|
response = self.client.invoke_model(
|
||||||
|
body=body,
|
||||||
|
modelId=self.name,
|
||||||
|
accept="application/json",
|
||||||
|
contentType="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
# format output based on provider
|
||||||
|
response_body = json.loads(response.get("body").read())
|
||||||
|
if provider == "cohere":
|
||||||
|
return response_body.get("embeddings")[0]
|
||||||
|
else:
|
||||||
|
# includes common provider == "amazon"
|
||||||
|
return response_body.get("embedding")
|
||||||
|
except Exception as e:
|
||||||
|
help_txt = """
|
||||||
|
boto3 client failed to invoke the bedrock API. In case of
|
||||||
|
AWS credentials error:
|
||||||
|
- Please check your AWS credentials and ensure that you have access.
|
||||||
|
You can set up aws credentials using `aws configure` command and
|
||||||
|
verify by running `aws sts get-caller-identity` in your terminal.
|
||||||
|
"""
|
||||||
|
raise ValueError(f"Error raised by boto3 client: {e}. \n {help_txt}")
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def client(self):
|
||||||
|
"""Create a boto3 client for Amazon Bedrock service
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
boto3.client
|
||||||
|
The boto3 client for Amazon Bedrock service
|
||||||
|
"""
|
||||||
|
botocore = self.safe_import("botocore")
|
||||||
|
boto3 = self.safe_import("boto3")
|
||||||
|
|
||||||
|
session_kwargs = {"region_name": self.region}
|
||||||
|
client_kwargs = {**session_kwargs}
|
||||||
|
|
||||||
|
if self.profile_name:
|
||||||
|
session_kwargs["profile_name"] = self.profile_name
|
||||||
|
|
||||||
|
retry_config = botocore.config.Config(
|
||||||
|
region_name=self.region,
|
||||||
|
retries={
|
||||||
|
"max_attempts": 0, # disable this as retries retries are handled
|
||||||
|
"mode": "standard",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
session = (
|
||||||
|
boto3.Session(**session_kwargs) if self.profile_name else boto3.Session()
|
||||||
|
)
|
||||||
|
if self.assumed_role: # if not using default credentials
|
||||||
|
sts = session.client("sts")
|
||||||
|
response = sts.assume_role(
|
||||||
|
RoleArn=str(self.assumed_role),
|
||||||
|
RoleSessionName=self.role_session_name,
|
||||||
|
)
|
||||||
|
client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
|
||||||
|
client_kwargs["aws_secret_access_key"] = response["Credentials"][
|
||||||
|
"SecretAccessKey"
|
||||||
|
]
|
||||||
|
client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]
|
||||||
|
|
||||||
|
service_name = "bedrock-runtime"
|
||||||
|
|
||||||
|
bedrock_client = session.client(
|
||||||
|
service_name=service_name, config=retry_config, **client_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return bedrock_client
|
||||||
130
python/lancedb/embeddings/gte.py
Normal file
130
python/lancedb/embeddings/gte.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import weak_lru
|
||||||
|
|
||||||
|
|
||||||
|
@register("gte-text")
|
||||||
|
class GteEmbeddings(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses GTE-LARGE MLX format(for Apple silicon devices only)
|
||||||
|
as well as the standard cpu/gpu version from: https://huggingface.co/thenlper/gte-large.
|
||||||
|
|
||||||
|
For Apple users, you will need the mlx package insalled, which can be done with:
|
||||||
|
pip install mlx
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name: str, default "thenlper/gte-large"
|
||||||
|
The name of the model to use.
|
||||||
|
device: str, default "cpu"
|
||||||
|
Sets the device type for the model.
|
||||||
|
normalize: str, default "True"
|
||||||
|
Controls normalize param in encode function for the transformer.
|
||||||
|
mlx: bool, default False
|
||||||
|
Controls which model to use. False for gte-large,True for the mlx version.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
import lancedb
|
||||||
|
import lancedb.embeddings.gte
|
||||||
|
from lancedb.embeddings import get_registry
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
model = get_registry().get("gte-text").create() # mlx=True for Apple silicon
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
rs = tbl.search("hello").limit(1).to_pandas()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "thenlper/gte-large"
|
||||||
|
device: str = "cpu"
|
||||||
|
normalize: bool = True
|
||||||
|
mlx: bool = False
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._ndims = None
|
||||||
|
if kwargs:
|
||||||
|
self.mlx = kwargs.get("mlx", False)
|
||||||
|
if self.mlx is True:
|
||||||
|
self.name = "gte-mlx"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_model(self):
|
||||||
|
"""
|
||||||
|
Get the embedding model specified by the flag,
|
||||||
|
name and device. This is cached so that the model is only loaded
|
||||||
|
once per process.
|
||||||
|
"""
|
||||||
|
return self.get_embedding_model()
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self.mlx is True:
|
||||||
|
self._ndims = self.embedding_model.dims
|
||||||
|
if self._ndims is None:
|
||||||
|
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: list[str] or np.ndarray (of str)
|
||||||
|
The texts to embed
|
||||||
|
"""
|
||||||
|
if self.mlx is True:
|
||||||
|
return self.embedding_model.run(list(texts)).tolist()
|
||||||
|
|
||||||
|
return self.embedding_model.encode(
|
||||||
|
list(texts),
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=self.normalize,
|
||||||
|
).tolist()
|
||||||
|
|
||||||
|
@weak_lru(maxsize=1)
|
||||||
|
def get_embedding_model(self):
|
||||||
|
"""
|
||||||
|
Get the embedding model specified by the flag,
|
||||||
|
name and device. This is cached so that the model is only loaded
|
||||||
|
once per process.
|
||||||
|
"""
|
||||||
|
if self.mlx is True:
|
||||||
|
from .gte_mlx_model import Model
|
||||||
|
|
||||||
|
return Model()
|
||||||
|
else:
|
||||||
|
sentence_transformers = self.safe_import(
|
||||||
|
"sentence_transformers", "sentence-transformers"
|
||||||
|
)
|
||||||
|
return sentence_transformers.SentenceTransformer(
|
||||||
|
self.name, device=self.device
|
||||||
|
)
|
||||||
154
python/lancedb/embeddings/gte_mlx_model.py
Normal file
154
python/lancedb/embeddings/gte_mlx_model.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import json
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("You need to install MLX to use this model use - pip install mlx")
|
||||||
|
|
||||||
|
|
||||||
|
def average_pool(last_hidden_state: mx.array, attention_mask: mx.array) -> mx.array:
|
||||||
|
last_hidden = mx.multiply(last_hidden_state, attention_mask[..., None])
|
||||||
|
return last_hidden.sum(axis=1) / attention_mask.sum(axis=1)[..., None]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(BaseModel):
|
||||||
|
dim: int = 1024
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
num_hidden_layers: int = 24
|
||||||
|
vocab_size: int = 30522
|
||||||
|
attention_probs_dropout_prob: float = 0.1
|
||||||
|
hidden_dropout_prob: float = 0.1
|
||||||
|
layer_norm_eps: float = 1e-12
|
||||||
|
max_position_embeddings: int = 512
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
A transformer encoder layer with (the original BERT) post-normalization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_dims: Optional[int] = None,
|
||||||
|
layer_norm_eps: float = 1e-12,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
mlp_dims = mlp_dims or dims * 4
|
||||||
|
self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
|
||||||
|
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||||
|
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||||
|
self.linear1 = nn.Linear(dims, mlp_dims)
|
||||||
|
self.linear2 = nn.Linear(mlp_dims, dims)
|
||||||
|
self.gelu = nn.GELU()
|
||||||
|
|
||||||
|
def __call__(self, x, mask):
|
||||||
|
attention_out = self.attention(x, x, x, mask)
|
||||||
|
add_and_norm = self.ln1(x + attention_out)
|
||||||
|
|
||||||
|
ff = self.linear1(add_and_norm)
|
||||||
|
ff_gelu = self.gelu(ff)
|
||||||
|
ff_out = self.linear2(ff_gelu)
|
||||||
|
x = self.ln2(ff_out + add_and_norm)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = [
|
||||||
|
TransformerEncoderLayer(dims, num_heads, mlp_dims)
|
||||||
|
for i in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, x, mask):
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, mask)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BertEmbeddings(nn.Module):
|
||||||
|
def __init__(self, config: ModelConfig):
|
||||||
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||||
|
self.token_type_embeddings = nn.Embedding(2, config.dim)
|
||||||
|
self.position_embeddings = nn.Embedding(
|
||||||
|
config.max_position_embeddings, config.dim
|
||||||
|
)
|
||||||
|
self.norm = nn.LayerNorm(config.dim, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def __call__(self, input_ids: mx.array, token_type_ids: mx.array) -> mx.array:
|
||||||
|
words = self.word_embeddings(input_ids)
|
||||||
|
position = self.position_embeddings(
|
||||||
|
mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
|
||||||
|
)
|
||||||
|
token_types = self.token_type_embeddings(token_type_ids)
|
||||||
|
|
||||||
|
embeddings = position + words + token_types
|
||||||
|
return self.norm(embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
class Bert(nn.Module):
|
||||||
|
def __init__(self, config: ModelConfig):
|
||||||
|
self.embeddings = BertEmbeddings(config)
|
||||||
|
self.encoder = TransformerEncoder(
|
||||||
|
num_layers=config.num_hidden_layers,
|
||||||
|
dims=config.dim,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
)
|
||||||
|
self.pooler = nn.Linear(config.dim, config.dim)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: mx.array,
|
||||||
|
token_type_ids: mx.array,
|
||||||
|
attention_mask: mx.array = None,
|
||||||
|
) -> tuple[mx.array, mx.array]:
|
||||||
|
x = self.embeddings(input_ids, token_type_ids)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# convert 0's to -infs, 1's to 0's, and make it broadcastable
|
||||||
|
attention_mask = mx.log(attention_mask)
|
||||||
|
attention_mask = mx.expand_dims(attention_mask, (1, 2))
|
||||||
|
|
||||||
|
y = self.encoder(x, attention_mask)
|
||||||
|
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||||
|
|
||||||
|
|
||||||
|
class Model:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# get converted embedding model
|
||||||
|
model_path = snapshot_download(repo_id="vegaluisjose/mlx-rag")
|
||||||
|
with open(f"{model_path}/config.json") as f:
|
||||||
|
model_config = ModelConfig(**json.load(f))
|
||||||
|
self.dims = model_config.dim
|
||||||
|
self.model = Bert(model_config)
|
||||||
|
self.model.load_weights(f"{model_path}/model.npz")
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained("thenlper/gte-large")
|
||||||
|
self.embeddings = []
|
||||||
|
|
||||||
|
def run(self, input_text: List[str]) -> mx.array:
|
||||||
|
tokens = self.tokenizer(input_text, return_tensors="np", padding=True)
|
||||||
|
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
||||||
|
|
||||||
|
last_hidden_state, _ = self.model(**tokens)
|
||||||
|
|
||||||
|
embeddings = average_pool(
|
||||||
|
last_hidden_state, tokens["attention_mask"].astype(mx.float32)
|
||||||
|
)
|
||||||
|
self.embeddings = (
|
||||||
|
embeddings / mx.linalg.norm(embeddings, ord=2, axis=1)[..., None]
|
||||||
|
)
|
||||||
|
|
||||||
|
return np.array(embeddings.astype(mx.float32))
|
||||||
@@ -12,7 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -30,10 +30,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "text-embedding-ada-002"
|
name: str = "text-embedding-ada-002"
|
||||||
|
dim: Optional[int] = None
|
||||||
|
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
# TODO don't hardcode this
|
return self._ndims
|
||||||
return 1536
|
|
||||||
|
@cached_property
|
||||||
|
def _ndims(self):
|
||||||
|
if self.name == "text-embedding-ada-002":
|
||||||
|
return 1536
|
||||||
|
elif self.name == "text-embedding-3-large":
|
||||||
|
return self.dim or 3072
|
||||||
|
elif self.name == "text-embedding-3-small":
|
||||||
|
return self.dim or 1536
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model name {self.name}")
|
||||||
|
|
||||||
def generate_embeddings(
|
def generate_embeddings(
|
||||||
self, texts: Union[List[str], np.ndarray]
|
self, texts: Union[List[str], np.ndarray]
|
||||||
@@ -47,7 +58,12 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
|||||||
The texts to embed
|
The texts to embed
|
||||||
"""
|
"""
|
||||||
# TODO retry, rate limit, token limit
|
# TODO retry, rate limit, token limit
|
||||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
if self.name == "text-embedding-ada-002":
|
||||||
|
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||||
|
else:
|
||||||
|
rs = self._openai_client.embeddings.create(
|
||||||
|
input=texts, model=self.name, dimensions=self.ndims()
|
||||||
|
)
|
||||||
return [v.embedding for v in rs.data]
|
return [v.embedding for v in rs.data]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
|
|||||||
107
python/lancedb/merge.py
Normal file
107
python/lancedb/merge.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .common import DATA
|
||||||
|
|
||||||
|
|
||||||
|
class LanceMergeInsertBuilder(object):
|
||||||
|
"""Builder for a LanceDB merge insert operation
|
||||||
|
|
||||||
|
See [`merge_insert`][lancedb.table.Table.merge_insert] for
|
||||||
|
more context
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, table: "Table", on: List[str]): # noqa: F821
|
||||||
|
# Do not put a docstring here. This method should be hidden
|
||||||
|
# from API docs. Users should use merge_insert to create
|
||||||
|
# this object.
|
||||||
|
self._table = table
|
||||||
|
self._on = on
|
||||||
|
self._when_matched_update_all = False
|
||||||
|
self._when_matched_update_all_condition = None
|
||||||
|
self._when_not_matched_insert_all = False
|
||||||
|
self._when_not_matched_by_source_delete = False
|
||||||
|
self._when_not_matched_by_source_condition = None
|
||||||
|
|
||||||
|
def when_matched_update_all(
|
||||||
|
self, *, where: Optional[str] = None
|
||||||
|
) -> LanceMergeInsertBuilder:
|
||||||
|
"""
|
||||||
|
Rows that exist in both the source table (new data) and
|
||||||
|
the target table (old data) will be updated, replacing
|
||||||
|
the old row with the corresponding matching row.
|
||||||
|
|
||||||
|
If there are multiple matches then the behavior is undefined.
|
||||||
|
Currently this causes multiple copies of the row to be created
|
||||||
|
but that behavior is subject to change.
|
||||||
|
"""
|
||||||
|
self._when_matched_update_all = True
|
||||||
|
self._when_matched_update_all_condition = where
|
||||||
|
return self
|
||||||
|
|
||||||
|
def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
|
||||||
|
"""
|
||||||
|
Rows that exist only in the source table (new data) should
|
||||||
|
be inserted into the target table.
|
||||||
|
"""
|
||||||
|
self._when_not_matched_insert_all = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def when_not_matched_by_source_delete(
|
||||||
|
self, condition: Optional[str] = None
|
||||||
|
) -> LanceMergeInsertBuilder:
|
||||||
|
"""
|
||||||
|
Rows that exist only in the target table (old data) will be
|
||||||
|
deleted. An optional condition can be provided to limit what
|
||||||
|
data is deleted.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
condition: Optional[str], default None
|
||||||
|
If None then all such rows will be deleted. Otherwise the
|
||||||
|
condition will be used as an SQL filter to limit what rows
|
||||||
|
are deleted.
|
||||||
|
"""
|
||||||
|
self._when_not_matched_by_source_delete = True
|
||||||
|
if condition is not None:
|
||||||
|
self._when_not_matched_by_source_condition = condition
|
||||||
|
return self
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
new_data: DATA,
|
||||||
|
on_bad_vectors: str = "error",
|
||||||
|
fill_value: float = 0.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Executes the merge insert operation
|
||||||
|
|
||||||
|
Nothing is returned but the [`Table`][lancedb.table.Table] is updated
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
new_data: DATA
|
||||||
|
New records which will be matched against the existing records
|
||||||
|
to potentially insert or update into the table. This parameter
|
||||||
|
can be anything you use for [`add`][lancedb.table.Table.add]
|
||||||
|
on_bad_vectors: str, default "error"
|
||||||
|
What to do if any of the vectors are not the same size or contains NaNs.
|
||||||
|
One of "error", "drop", "fill".
|
||||||
|
fill_value: float, default 0.
|
||||||
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
|
"""
|
||||||
|
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
|
||||||
@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
|
|||||||
... name: str
|
... name: str
|
||||||
... vector: Vector(2)
|
... vector: Vector(2)
|
||||||
...
|
...
|
||||||
>>> db = lancedb.connect("/tmp")
|
>>> db = lancedb.connect("./example")
|
||||||
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
|
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
|
||||||
>>> table.add([
|
>>> table.add([
|
||||||
... TestModel(name="test", vector=[1.0, 2.0])
|
... TestModel(name="test", vector=[1.0, 2.0])
|
||||||
|
|||||||
@@ -14,8 +14,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
|
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import deprecation
|
import deprecation
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -23,7 +24,9 @@ import pyarrow as pa
|
|||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .common import VECTOR_COLUMN_NAME
|
from .common import VEC, VECTOR_COLUMN_NAME
|
||||||
|
from .rerankers.base import Reranker
|
||||||
|
from .rerankers.linear_combination import LinearCombinationReranker
|
||||||
from .util import safe_import_pandas
|
from .util import safe_import_pandas
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -99,6 +102,8 @@ class Query(pydantic.BaseModel):
|
|||||||
# Refine factor.
|
# Refine factor.
|
||||||
refine_factor: Optional[int] = None
|
refine_factor: Optional[int] = None
|
||||||
|
|
||||||
|
with_row_id: bool = False
|
||||||
|
|
||||||
|
|
||||||
class LanceQueryBuilder(ABC):
|
class LanceQueryBuilder(ABC):
|
||||||
"""Build LanceDB query based on specific query type:
|
"""Build LanceDB query based on specific query type:
|
||||||
@@ -109,19 +114,26 @@ class LanceQueryBuilder(ABC):
|
|||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
table: "Table",
|
table: "Table",
|
||||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
|
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||||
query_type: str,
|
query_type: str,
|
||||||
vector_column_name: str,
|
vector_column_name: str,
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
if query is None:
|
if query is None:
|
||||||
return LanceEmptyQueryBuilder(table)
|
return LanceEmptyQueryBuilder(table)
|
||||||
|
|
||||||
# convert "auto" query_type to "vector" or "fts"
|
if query_type == "hybrid":
|
||||||
# and convert the query to vector if needed
|
# hybrid fts and vector query
|
||||||
|
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||||
|
|
||||||
|
# convert "auto" query_type to "vector", "fts"
|
||||||
|
# or "hybrid" and convert the query to vector if needed
|
||||||
query, query_type = cls._resolve_query(
|
query, query_type = cls._resolve_query(
|
||||||
table, query, query_type, vector_column_name
|
table, query, query_type, vector_column_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if query_type == "hybrid":
|
||||||
|
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||||
|
|
||||||
if isinstance(query, str):
|
if isinstance(query, str):
|
||||||
# fts
|
# fts
|
||||||
return LanceFtsQueryBuilder(table, query)
|
return LanceFtsQueryBuilder(table, query)
|
||||||
@@ -144,17 +156,13 @@ class LanceQueryBuilder(ABC):
|
|||||||
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
raise TypeError(f"'fts' queries must be a string: {type(query)}")
|
||||||
return query, query_type
|
return query, query_type
|
||||||
elif query_type == "vector":
|
elif query_type == "vector":
|
||||||
if not isinstance(query, (list, np.ndarray)):
|
query = cls._query_to_vector(table, query, vector_column_name)
|
||||||
conf = table.embedding_functions.get(vector_column_name)
|
|
||||||
if conf is not None:
|
|
||||||
query = conf.function.compute_query_embeddings_with_retry(query)[0]
|
|
||||||
else:
|
|
||||||
msg = f"No embedding function for {vector_column_name}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
return query, query_type
|
return query, query_type
|
||||||
elif query_type == "auto":
|
elif query_type == "auto":
|
||||||
if isinstance(query, (list, np.ndarray)):
|
if isinstance(query, (list, np.ndarray)):
|
||||||
return query, "vector"
|
return query, "vector"
|
||||||
|
if isinstance(query, tuple):
|
||||||
|
return query, "hybrid"
|
||||||
else:
|
else:
|
||||||
conf = table.embedding_functions.get(vector_column_name)
|
conf = table.embedding_functions.get(vector_column_name)
|
||||||
if conf is not None:
|
if conf is not None:
|
||||||
@@ -167,11 +175,23 @@ class LanceQueryBuilder(ABC):
|
|||||||
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _query_to_vector(cls, table, query, vector_column_name):
|
||||||
|
if isinstance(query, (list, np.ndarray)):
|
||||||
|
return query
|
||||||
|
conf = table.embedding_functions.get(vector_column_name)
|
||||||
|
if conf is not None:
|
||||||
|
return conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||||
|
else:
|
||||||
|
msg = f"No embedding function for {vector_column_name}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
def __init__(self, table: "Table"):
|
def __init__(self, table: "Table"):
|
||||||
self._table = table
|
self._table = table
|
||||||
self._limit = 10
|
self._limit = 10
|
||||||
self._columns = None
|
self._columns = None
|
||||||
self._where = None
|
self._where = None
|
||||||
|
self._with_row_id = False
|
||||||
|
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
deprecated_in="0.3.1",
|
deprecated_in="0.3.1",
|
||||||
@@ -341,6 +361,22 @@ class LanceQueryBuilder(ABC):
|
|||||||
self._prefilter = prefilter
|
self._prefilter = prefilter
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_row_id(self, with_row_id: bool) -> LanceQueryBuilder:
|
||||||
|
"""Set whether to return row ids.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
with_row_id: bool
|
||||||
|
If True, return _rowid column in the results.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceQueryBuilder
|
||||||
|
The LanceQueryBuilder object.
|
||||||
|
"""
|
||||||
|
self._with_row_id = with_row_id
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||||
"""
|
"""
|
||||||
@@ -459,6 +495,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
nprobes=self._nprobes,
|
nprobes=self._nprobes,
|
||||||
refine_factor=self._refine_factor,
|
refine_factor=self._refine_factor,
|
||||||
vector_column=self._vector_column,
|
vector_column=self._vector_column,
|
||||||
|
with_row_id=self._with_row_id,
|
||||||
)
|
)
|
||||||
return self._table._execute_query(query)
|
return self._table._execute_query(query)
|
||||||
|
|
||||||
@@ -568,6 +605,10 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
|||||||
ds = lance.write_dataset(output_tbl, tmp)
|
ds = lance.write_dataset(output_tbl, tmp)
|
||||||
output_tbl = ds.to_table(filter=self._where)
|
output_tbl = ds.to_table(filter=self._where)
|
||||||
|
|
||||||
|
if self._with_row_id:
|
||||||
|
# Need to set this to uint explicitly as vector results are in uint64
|
||||||
|
row_ids = pa.array(row_ids, type=pa.uint64())
|
||||||
|
output_tbl = output_tbl.append_column("_rowid", row_ids)
|
||||||
return output_tbl
|
return output_tbl
|
||||||
|
|
||||||
|
|
||||||
@@ -579,3 +620,265 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
|||||||
filter=self._where,
|
filter=self._where,
|
||||||
limit=self._limit,
|
limit=self._limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||||
|
def __init__(self, table: "Table", query: str, vector_column: str):
|
||||||
|
super().__init__(table)
|
||||||
|
self._validate_fts_index()
|
||||||
|
vector_query, fts_query = self._validate_query(query)
|
||||||
|
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
||||||
|
vector_query = self._query_to_vector(table, vector_query, vector_column)
|
||||||
|
self._vector_query = LanceVectorQueryBuilder(table, vector_query, vector_column)
|
||||||
|
self._norm = "score"
|
||||||
|
self._reranker = LinearCombinationReranker(weight=0.7, fill=1.0)
|
||||||
|
|
||||||
|
def _validate_fts_index(self):
|
||||||
|
if self._table._get_fts_index_path() is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Please create a full-text search index " "to perform hybrid search."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_query(self, query):
|
||||||
|
# Temp hack to support vectorized queries for hybrid search
|
||||||
|
if isinstance(query, str):
|
||||||
|
return query, query
|
||||||
|
elif isinstance(query, tuple):
|
||||||
|
if len(query) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"The query must be a tuple of (vector_query, fts_query)."
|
||||||
|
)
|
||||||
|
if not isinstance(query[0], (list, np.ndarray, pa.Array, pa.ChunkedArray)):
|
||||||
|
raise ValueError(f"The vector query must be one of {VEC}.")
|
||||||
|
if not isinstance(query[1], str):
|
||||||
|
raise ValueError("The fts query must be a string.")
|
||||||
|
return query[0], query[1]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The query must be either a string or a tuple of (vector, string)."
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_arrow(self) -> pa.Table:
|
||||||
|
with ThreadPoolExecutor() as executor:
|
||||||
|
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
|
||||||
|
vector_future = executor.submit(
|
||||||
|
self._vector_query.with_row_id(True).to_arrow
|
||||||
|
)
|
||||||
|
fts_results = fts_future.result()
|
||||||
|
vector_results = vector_future.result()
|
||||||
|
|
||||||
|
# convert to ranks first if needed
|
||||||
|
if self._norm == "rank":
|
||||||
|
vector_results = self._rank(vector_results, "_distance")
|
||||||
|
fts_results = self._rank(fts_results, "score")
|
||||||
|
# normalize the scores to be between 0 and 1, 0 being most relevant
|
||||||
|
vector_results = self._normalize_scores(vector_results, "_distance")
|
||||||
|
|
||||||
|
# In fts higher scores represent relevance. Not inverting them here as
|
||||||
|
# rerankers might need to preserve this score to support `return_score="all"`
|
||||||
|
fts_results = self._normalize_scores(fts_results, "score")
|
||||||
|
|
||||||
|
results = self._reranker.rerank_hybrid(
|
||||||
|
self._fts_query._query, vector_results, fts_results
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(results, pa.Table): # Enforce type
|
||||||
|
raise TypeError(
|
||||||
|
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# apply limit after reranking
|
||||||
|
results = results.slice(length=self._limit)
|
||||||
|
|
||||||
|
if not self._with_row_id:
|
||||||
|
results = results.drop(["_rowid"])
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _rank(self, results: pa.Table, column: str, ascending: bool = True):
|
||||||
|
if len(results) == 0:
|
||||||
|
return results
|
||||||
|
# Get the _score column from results
|
||||||
|
scores = results.column(column).to_numpy()
|
||||||
|
sort_indices = np.argsort(scores)
|
||||||
|
if not ascending:
|
||||||
|
sort_indices = sort_indices[::-1]
|
||||||
|
ranks = np.empty_like(sort_indices)
|
||||||
|
ranks[sort_indices] = np.arange(len(scores)) + 1
|
||||||
|
# replace the _score column with the ranks
|
||||||
|
_score_idx = results.column_names.index(column)
|
||||||
|
results = results.set_column(
|
||||||
|
_score_idx, column, pa.array(ranks, type=pa.float32())
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _normalize_scores(self, results: pa.Table, column: str, invert=False):
|
||||||
|
if len(results) == 0:
|
||||||
|
return results
|
||||||
|
# Get the _score column from results
|
||||||
|
scores = results.column(column).to_numpy()
|
||||||
|
# normalize the scores by subtracting the min and dividing by the max
|
||||||
|
max, min = np.max(scores), np.min(scores)
|
||||||
|
if np.isclose(max, min):
|
||||||
|
rng = max
|
||||||
|
else:
|
||||||
|
rng = max - min
|
||||||
|
scores = (scores - min) / rng
|
||||||
|
if invert:
|
||||||
|
scores = 1 - scores
|
||||||
|
# replace the _score column with the ranks
|
||||||
|
_score_idx = results.column_names.index(column)
|
||||||
|
results = results.set_column(
|
||||||
|
_score_idx, column, pa.array(scores, type=pa.float32())
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def rerank(
|
||||||
|
self,
|
||||||
|
normalize="score",
|
||||||
|
reranker: Reranker = LinearCombinationReranker(weight=0.7, fill=1.0),
|
||||||
|
) -> LanceHybridQueryBuilder:
|
||||||
|
"""
|
||||||
|
Rerank the hybrid search results using the specified reranker. The reranker
|
||||||
|
must be an instance of Reranker class.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
normalize: str, default "score"
|
||||||
|
The method to normalize the scores. Can be "rank" or "score". If "rank",
|
||||||
|
the scores are converted to ranks and then normalized. If "score", the
|
||||||
|
scores are normalized directly.
|
||||||
|
reranker: Reranker, default LinearCombinationReranker(weight=0.7, fill=1.0)
|
||||||
|
The reranker to use. Must be an instance of Reranker class.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceHybridQueryBuilder
|
||||||
|
The LanceHybridQueryBuilder object.
|
||||||
|
"""
|
||||||
|
if normalize not in ["rank", "score"]:
|
||||||
|
raise ValueError("normalize must be 'rank' or 'score'.")
|
||||||
|
if reranker and not isinstance(reranker, Reranker):
|
||||||
|
raise ValueError("reranker must be an instance of Reranker class.")
|
||||||
|
|
||||||
|
self._norm = normalize
|
||||||
|
self._reranker = reranker
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def limit(self, limit: int) -> LanceHybridQueryBuilder:
|
||||||
|
"""
|
||||||
|
Set the maximum number of results to return for both vector and fts search
|
||||||
|
components.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
limit: int
|
||||||
|
The maximum number of results to return.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceHybridQueryBuilder
|
||||||
|
The LanceHybridQueryBuilder object.
|
||||||
|
"""
|
||||||
|
self._vector_query.limit(limit)
|
||||||
|
self._fts_query.limit(limit)
|
||||||
|
self._limit = limit
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def select(self, columns: list) -> LanceHybridQueryBuilder:
|
||||||
|
"""
|
||||||
|
Set the columns to return for both vector and fts search.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
columns: list
|
||||||
|
The columns to return.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceHybridQueryBuilder
|
||||||
|
The LanceHybridQueryBuilder object.
|
||||||
|
"""
|
||||||
|
self._vector_query.select(columns)
|
||||||
|
self._fts_query.select(columns)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def where(self, where: str, prefilter: bool = False) -> LanceHybridQueryBuilder:
|
||||||
|
"""
|
||||||
|
Set the where clause for both vector and fts search.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
where: str
|
||||||
|
The where clause which is a valid SQL where clause. See
|
||||||
|
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
|
||||||
|
for valid SQL expressions.
|
||||||
|
|
||||||
|
prefilter: bool, default False
|
||||||
|
If True, apply the filter before vector search, otherwise the
|
||||||
|
filter is applied on the result of vector search.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceHybridQueryBuilder
|
||||||
|
The LanceHybridQueryBuilder object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._vector_query.where(where, prefilter=prefilter)
|
||||||
|
self._fts_query.where(where)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def metric(self, metric: Literal["L2", "cosine"]) -> LanceHybridQueryBuilder:
|
||||||
|
"""
|
||||||
|
Set the distance metric to use for vector search.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
metric: "L2" or "cosine"
|
||||||
|
The distance metric to use. By default "L2" is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceHybridQueryBuilder
|
||||||
|
The LanceHybridQueryBuilder object.
|
||||||
|
"""
|
||||||
|
self._vector_query.metric(metric)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
|
||||||
|
"""
|
||||||
|
Set the number of probes to use for vector search.
|
||||||
|
|
||||||
|
Higher values will yield better recall (more likely to find vectors if
|
||||||
|
they exist) at the expense of latency.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
nprobes: int
|
||||||
|
The number of probes to use.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceHybridQueryBuilder
|
||||||
|
The LanceHybridQueryBuilder object.
|
||||||
|
"""
|
||||||
|
self._vector_query.nprobes(nprobes)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
|
||||||
|
"""
|
||||||
|
Refine the vector search results by reading extra elements and
|
||||||
|
re-ranking them in memory.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
refine_factor: int
|
||||||
|
The refine factor to use.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceHybridQueryBuilder
|
||||||
|
The LanceHybridQueryBuilder object.
|
||||||
|
"""
|
||||||
|
self._vector_query.refine_factor(refine_factor)
|
||||||
|
return self
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
|
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
@@ -20,6 +22,8 @@ import attrs
|
|||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
from urllib3 import Retry
|
||||||
|
|
||||||
from lancedb.common import Credential
|
from lancedb.common import Credential
|
||||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||||
@@ -57,6 +61,10 @@ class RestfulLanceDBClient:
|
|||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def session(self) -> requests.Session:
|
def session(self) -> requests.Session:
|
||||||
sess = requests.Session()
|
sess = requests.Session()
|
||||||
|
|
||||||
|
retry_adapter_instance = retry_adapter(retry_adapter_options())
|
||||||
|
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
|
||||||
|
|
||||||
adapter_class = LanceDBClientHTTPAdapterFactory()
|
adapter_class = LanceDBClientHTTPAdapterFactory()
|
||||||
sess.mount("https://", adapter_class())
|
sess.mount("https://", adapter_class())
|
||||||
return sess
|
return sess
|
||||||
@@ -109,7 +117,7 @@ class RestfulLanceDBClient:
|
|||||||
urljoin(self.url, uri),
|
urljoin(self.url, uri),
|
||||||
params=params,
|
params=params,
|
||||||
headers=self.headers,
|
headers=self.headers,
|
||||||
timeout=(10.0, 300.0),
|
timeout=(120.0, 300.0),
|
||||||
) as resp:
|
) as resp:
|
||||||
self._check_status(resp)
|
self._check_status(resp)
|
||||||
return resp.json()
|
return resp.json()
|
||||||
@@ -151,7 +159,7 @@ class RestfulLanceDBClient:
|
|||||||
urljoin(self.url, uri),
|
urljoin(self.url, uri),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
params=params,
|
params=params,
|
||||||
timeout=(10.0, 300.0),
|
timeout=(120.0, 300.0),
|
||||||
**req_kwargs,
|
**req_kwargs,
|
||||||
) as resp:
|
) as resp:
|
||||||
self._check_status(resp)
|
self._check_status(resp)
|
||||||
@@ -170,3 +178,72 @@ class RestfulLanceDBClient:
|
|||||||
"""Query a table."""
|
"""Query a table."""
|
||||||
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
||||||
return VectorQueryResult(tbl)
|
return VectorQueryResult(tbl)
|
||||||
|
|
||||||
|
def mount_retry_adapter_for_table(self, table_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Adds an http adapter to session that will retry retryable requests to the table.
|
||||||
|
"""
|
||||||
|
retry_options = retry_adapter_options(methods=["GET", "POST"])
|
||||||
|
retry_adapter_instance = retry_adapter(retry_options)
|
||||||
|
session = self.session
|
||||||
|
|
||||||
|
session.mount(
|
||||||
|
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
|
||||||
|
)
|
||||||
|
session.mount(
|
||||||
|
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
|
||||||
|
retry_adapter_instance,
|
||||||
|
)
|
||||||
|
session.mount(
|
||||||
|
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
|
||||||
|
retry_adapter_instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
|
||||||
|
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
|
||||||
|
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
|
||||||
|
"backoff_factor": float(
|
||||||
|
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
|
||||||
|
),
|
||||||
|
"backoff_jitter": float(
|
||||||
|
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
|
||||||
|
),
|
||||||
|
"statuses": [
|
||||||
|
int(i.strip())
|
||||||
|
for i in os.environ.get(
|
||||||
|
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
|
||||||
|
).split(",")
|
||||||
|
],
|
||||||
|
"methods": methods,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
|
||||||
|
total_retries = options["retries"]
|
||||||
|
connect_retries = options["connect_retries"]
|
||||||
|
read_retries = options["read_retries"]
|
||||||
|
backoff_factor = options["backoff_factor"]
|
||||||
|
backoff_jitter = options["backoff_jitter"]
|
||||||
|
statuses = options["statuses"]
|
||||||
|
methods = frozenset(options["methods"])
|
||||||
|
logging.debug(
|
||||||
|
f"Setting up retry adapter with {total_retries} retries," # noqa G003
|
||||||
|
+ f"connect retries {connect_retries}, read retries {read_retries},"
|
||||||
|
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
|
||||||
|
+ f"methods {methods}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return HTTPAdapter(
|
||||||
|
max_retries=Retry(
|
||||||
|
total=total_retries,
|
||||||
|
connect=connect_retries,
|
||||||
|
read=read_retries,
|
||||||
|
backoff_factor=backoff_factor,
|
||||||
|
backoff_jitter=backoff_jitter,
|
||||||
|
status_forcelist=statuses,
|
||||||
|
allowed_methods=methods,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -95,6 +95,8 @@ class RemoteDBConnection(DBConnection):
|
|||||||
"""
|
"""
|
||||||
from .table import RemoteTable
|
from .table import RemoteTable
|
||||||
|
|
||||||
|
self._client.mount_retry_adapter_for_table(name)
|
||||||
|
|
||||||
# check if table exists
|
# check if table exists
|
||||||
try:
|
try:
|
||||||
self._client.post(f"/v1/table/{name}/describe/")
|
self._client.post(f"/v1/table/{name}/describe/")
|
||||||
@@ -116,6 +118,7 @@ class RemoteDBConnection(DBConnection):
|
|||||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
|
mode: Optional[str] = None,
|
||||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||||
) -> Table:
|
) -> Table:
|
||||||
"""Create a [Table][lancedb.table.Table] in the database.
|
"""Create a [Table][lancedb.table.Table] in the database.
|
||||||
@@ -213,11 +216,13 @@ class RemoteDBConnection(DBConnection):
|
|||||||
if data is None and schema is None:
|
if data is None and schema is None:
|
||||||
raise ValueError("Either data or schema must be provided.")
|
raise ValueError("Either data or schema must be provided.")
|
||||||
if embedding_functions is not None:
|
if embedding_functions is not None:
|
||||||
raise NotImplementedError(
|
logging.warning(
|
||||||
"embedding_functions is not supported for remote databases."
|
"embedding_functions is not yet supported on LanceDB Cloud."
|
||||||
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
||||||
"for this feature."
|
"for this feature."
|
||||||
)
|
)
|
||||||
|
if mode is not None:
|
||||||
|
logging.warning("mode is not yet supported on LanceDB Cloud.")
|
||||||
|
|
||||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||||
# convert LanceModel to pyarrow schema
|
# convert LanceModel to pyarrow schema
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
@@ -19,6 +20,7 @@ import pyarrow as pa
|
|||||||
from lance import json_to_schema
|
from lance import json_to_schema
|
||||||
|
|
||||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
|
from lancedb.merge import LanceMergeInsertBuilder
|
||||||
|
|
||||||
from ..query import LanceVectorQueryBuilder
|
from ..query import LanceVectorQueryBuilder
|
||||||
from ..table import Query, Table, _sanitize_data
|
from ..table import Query, Table, _sanitize_data
|
||||||
@@ -36,6 +38,9 @@ class RemoteTable(Table):
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"RemoteTable({self._conn.db_name}.{self._name})"
|
return f"RemoteTable({self._conn.db_name}.{self._name})"
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
self.count_rows(None)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def schema(self) -> pa.Schema:
|
def schema(self) -> pa.Schema:
|
||||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||||
@@ -53,17 +58,17 @@ class RemoteTable(Table):
|
|||||||
return resp["version"]
|
return resp["version"]
|
||||||
|
|
||||||
def to_arrow(self) -> pa.Table:
|
def to_arrow(self) -> pa.Table:
|
||||||
"""to_arrow() is not supported on the LanceDB cloud"""
|
"""to_arrow() is not yet supported on LanceDB cloud."""
|
||||||
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
|
raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.")
|
||||||
|
|
||||||
def to_pandas(self):
|
def to_pandas(self):
|
||||||
"""to_pandas() is not supported on the LanceDB cloud"""
|
"""to_pandas() is not yet supported on LanceDB cloud."""
|
||||||
return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
|
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
|
||||||
|
|
||||||
def create_scalar_index(self, *args, **kwargs):
|
def create_scalar_index(self, *args, **kwargs):
|
||||||
"""Creates a scalar index"""
|
"""Creates a scalar index"""
|
||||||
return NotImplementedError(
|
return NotImplementedError(
|
||||||
"create_scalar_index() is not supported on the LanceDB cloud"
|
"create_scalar_index() is not yet supported on LanceDB cloud."
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_index(
|
def create_index(
|
||||||
@@ -71,6 +76,10 @@ class RemoteTable(Table):
|
|||||||
metric="L2",
|
metric="L2",
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
index_cache_size: Optional[int] = None,
|
index_cache_size: Optional[int] = None,
|
||||||
|
num_partitions: Optional[int] = None,
|
||||||
|
num_sub_vectors: Optional[int] = None,
|
||||||
|
replace: Optional[bool] = None,
|
||||||
|
accelerator: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Create an index on the table.
|
"""Create an index on the table.
|
||||||
Currently, the only parameters that matter are
|
Currently, the only parameters that matter are
|
||||||
@@ -104,6 +113,28 @@ class RemoteTable(Table):
|
|||||||
... )
|
... )
|
||||||
>>> table.create_index("L2", "vector") # doctest: +SKIP
|
>>> table.create_index("L2", "vector") # doctest: +SKIP
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if num_partitions is not None:
|
||||||
|
logging.warning(
|
||||||
|
"num_partitions is not supported on LanceDB cloud."
|
||||||
|
"This parameter will be tuned automatically."
|
||||||
|
)
|
||||||
|
if num_sub_vectors is not None:
|
||||||
|
logging.warning(
|
||||||
|
"num_sub_vectors is not supported on LanceDB cloud."
|
||||||
|
"This parameter will be tuned automatically."
|
||||||
|
)
|
||||||
|
if accelerator is not None:
|
||||||
|
logging.warning(
|
||||||
|
"GPU accelerator is not yet supported on LanceDB cloud."
|
||||||
|
"If you have 100M+ vectors to index,"
|
||||||
|
"please contact us at contact@lancedb.com"
|
||||||
|
)
|
||||||
|
if replace is not None:
|
||||||
|
logging.warning(
|
||||||
|
"replace is not supported on LanceDB cloud."
|
||||||
|
"Existing indexes will always be replaced."
|
||||||
|
)
|
||||||
index_type = "vector"
|
index_type = "vector"
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
@@ -244,6 +275,51 @@ class RemoteTable(Table):
|
|||||||
result = self._conn._client.query(self._name, query)
|
result = self._conn._client.query(self._name, query)
|
||||||
return result.to_arrow()
|
return result.to_arrow()
|
||||||
|
|
||||||
|
def _do_merge(
|
||||||
|
self,
|
||||||
|
merge: LanceMergeInsertBuilder,
|
||||||
|
new_data: DATA,
|
||||||
|
on_bad_vectors: str,
|
||||||
|
fill_value: float,
|
||||||
|
):
|
||||||
|
data = _sanitize_data(
|
||||||
|
new_data,
|
||||||
|
self.schema,
|
||||||
|
metadata=None,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
fill_value=fill_value,
|
||||||
|
)
|
||||||
|
payload = to_ipc_binary(data)
|
||||||
|
|
||||||
|
params = {}
|
||||||
|
if len(merge._on) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"RemoteTable only supports a single on key in merge_insert"
|
||||||
|
)
|
||||||
|
params["on"] = merge._on[0]
|
||||||
|
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
|
||||||
|
if merge._when_matched_update_all_condition is not None:
|
||||||
|
params[
|
||||||
|
"when_matched_update_all_filt"
|
||||||
|
] = merge._when_matched_update_all_condition
|
||||||
|
params["when_not_matched_insert_all"] = str(
|
||||||
|
merge._when_not_matched_insert_all
|
||||||
|
).lower()
|
||||||
|
params["when_not_matched_by_source_delete"] = str(
|
||||||
|
merge._when_not_matched_by_source_delete
|
||||||
|
).lower()
|
||||||
|
if merge._when_not_matched_by_source_condition is not None:
|
||||||
|
params[
|
||||||
|
"when_not_matched_by_source_delete_filt"
|
||||||
|
] = merge._when_not_matched_by_source_condition
|
||||||
|
|
||||||
|
self._conn._client.post(
|
||||||
|
f"/v1/table/{self._name}/merge_insert/",
|
||||||
|
data=payload,
|
||||||
|
params=params,
|
||||||
|
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||||
|
)
|
||||||
|
|
||||||
def delete(self, predicate: str):
|
def delete(self, predicate: str):
|
||||||
"""Delete rows from the table.
|
"""Delete rows from the table.
|
||||||
|
|
||||||
@@ -355,6 +431,25 @@ class RemoteTable(Table):
|
|||||||
payload = {"predicate": where, "updates": updates}
|
payload = {"predicate": where, "updates": updates}
|
||||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||||
|
|
||||||
|
def cleanup_old_versions(self, *_):
|
||||||
|
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"cleanup_old_versions() is not supported on the LanceDB cloud"
|
||||||
|
)
|
||||||
|
|
||||||
|
def compact_files(self, *_):
|
||||||
|
"""compact_files() is not supported on the LanceDB cloud"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"compact_files() is not supported on the LanceDB cloud"
|
||||||
|
)
|
||||||
|
|
||||||
|
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||||
|
# payload = {"filter": filter}
|
||||||
|
# self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
|
||||||
|
return NotImplementedError(
|
||||||
|
"count_rows() is not yet supported on the LanceDB cloud"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
||||||
return tbl.add_column(
|
return tbl.add_column(
|
||||||
|
|||||||
15
python/lancedb/rerankers/__init__.py
Normal file
15
python/lancedb/rerankers/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from .base import Reranker
|
||||||
|
from .cohere import CohereReranker
|
||||||
|
from .colbert import ColbertReranker
|
||||||
|
from .cross_encoder import CrossEncoderReranker
|
||||||
|
from .linear_combination import LinearCombinationReranker
|
||||||
|
from .openai import OpenaiReranker
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Reranker",
|
||||||
|
"CrossEncoderReranker",
|
||||||
|
"CohereReranker",
|
||||||
|
"LinearCombinationReranker",
|
||||||
|
"OpenaiReranker",
|
||||||
|
"ColbertReranker",
|
||||||
|
]
|
||||||
75
python/lancedb/rerankers/base.py
Normal file
75
python/lancedb/rerankers/base.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
|
||||||
|
class Reranker(ABC):
|
||||||
|
def __init__(self, return_score: str = "relevance"):
|
||||||
|
"""
|
||||||
|
Interface for a reranker. A reranker is used to rerank the results from a
|
||||||
|
vector and FTS search. This is useful for combining the results from both
|
||||||
|
search methods.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
return_score : str, default "relevance"
|
||||||
|
opntions are "relevance" or "all"
|
||||||
|
The type of score to return. If "relevance", will return only the relevance
|
||||||
|
score. If "all", will return all scores from the vector and FTS search along
|
||||||
|
with the relevance score.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if return_score not in ["relevance", "all"]:
|
||||||
|
raise ValueError("score must be either 'relevance' or 'all'")
|
||||||
|
self.score = return_score
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def rerank_hybrid(
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rerank function receives the individual results from the vector and FTS search
|
||||||
|
results. You can choose to use any of the results to generate the final results,
|
||||||
|
allowing maximum flexibility. This is mandatory to implement
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : str
|
||||||
|
The input query
|
||||||
|
vector_results : pa.Table
|
||||||
|
The results from the vector search
|
||||||
|
fts_results : pa.Table
|
||||||
|
The results from the FTS search
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
|
||||||
|
"""
|
||||||
|
Merge the results from the vector and FTS search. This is a vanilla merging
|
||||||
|
function that just concatenates the results and removes the duplicates.
|
||||||
|
|
||||||
|
NOTE: This doesn't take score into account. It'll keep the instance that was
|
||||||
|
encountered first. This is designed for rerankers that don't use the score.
|
||||||
|
In case you want to use the score, or support `return_scores="all"` you'll
|
||||||
|
have to implement your own merging function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
vector_results : pa.Table
|
||||||
|
The results from the vector search
|
||||||
|
fts_results : pa.Table
|
||||||
|
The results from the FTS search
|
||||||
|
"""
|
||||||
|
combined = pa.concat_tables([vector_results, fts_results], promote=True)
|
||||||
|
row_id = combined.column("_rowid")
|
||||||
|
|
||||||
|
# deduplicate
|
||||||
|
mask = np.full((combined.shape[0]), False)
|
||||||
|
_, mask_indices = np.unique(np.array(row_id), return_index=True)
|
||||||
|
mask[mask_indices] = True
|
||||||
|
combined = combined.filter(mask=mask)
|
||||||
|
|
||||||
|
return combined
|
||||||
81
python/lancedb/rerankers/cohere.py
Normal file
81
python/lancedb/rerankers/cohere.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import os
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from ..util import safe_import
|
||||||
|
from .base import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
class CohereReranker(Reranker):
|
||||||
|
"""
|
||||||
|
Reranks the results using the Cohere Rerank API.
|
||||||
|
https://docs.cohere.com/docs/rerank-guide
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_name : str, default "rerank-english-v2.0"
|
||||||
|
The name of the cross encoder model to use. Available cohere models are:
|
||||||
|
- rerank-english-v2.0
|
||||||
|
- rerank-multilingual-v2.0
|
||||||
|
column : str, default "text"
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
top_n : str, default None
|
||||||
|
The number of results to return. If None, will return all results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "rerank-english-v2.0",
|
||||||
|
column: str = "text",
|
||||||
|
top_n: Union[int, None] = None,
|
||||||
|
return_score="relevance",
|
||||||
|
api_key: Union[str, None] = None,
|
||||||
|
):
|
||||||
|
super().__init__(return_score)
|
||||||
|
self.model_name = model_name
|
||||||
|
self.column = column
|
||||||
|
self.top_n = top_n
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _client(self):
|
||||||
|
cohere = safe_import("cohere")
|
||||||
|
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"COHERE_API_KEY not set. Either set it in your environment or \
|
||||||
|
pass it as `api_key` argument to the CohereReranker."
|
||||||
|
)
|
||||||
|
return cohere.Client(os.environ.get("COHERE_API_KEY") or self.api_key)
|
||||||
|
|
||||||
|
def rerank_hybrid(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
combined_results = self.merge_results(vector_results, fts_results)
|
||||||
|
docs = combined_results[self.column].to_pylist()
|
||||||
|
results = self._client.rerank(
|
||||||
|
query=query,
|
||||||
|
documents=docs,
|
||||||
|
top_n=self.top_n,
|
||||||
|
model=self.model_name,
|
||||||
|
) # returns list (text, idx, relevance) attributes sorted descending by score
|
||||||
|
indices, scores = list(
|
||||||
|
zip(*[(result.index, result.relevance_score) for result in results])
|
||||||
|
) # tuples
|
||||||
|
combined_results = combined_results.take(list(indices))
|
||||||
|
# add the scores
|
||||||
|
combined_results = combined_results.append_column(
|
||||||
|
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.score == "relevance":
|
||||||
|
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||||
|
elif self.score == "all":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"return_score='all' not implemented for cohere reranker"
|
||||||
|
)
|
||||||
|
return combined_results
|
||||||
107
python/lancedb/rerankers/colbert.py
Normal file
107
python/lancedb/rerankers/colbert.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from ..util import safe_import
|
||||||
|
from .base import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
class ColbertReranker(Reranker):
|
||||||
|
"""
|
||||||
|
Reranks the results using the ColBERT model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_name : str, default "colbert-ir/colbertv2.0"
|
||||||
|
The name of the cross encoder model to use.
|
||||||
|
column : str, default "text"
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
return_score : str, default "relevance"
|
||||||
|
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "colbert-ir/colbertv2.0",
|
||||||
|
column: str = "text",
|
||||||
|
return_score="relevance",
|
||||||
|
):
|
||||||
|
super().__init__(return_score)
|
||||||
|
self.model_name = model_name
|
||||||
|
self.column = column
|
||||||
|
self.torch = safe_import("torch") # import here for faster ops later
|
||||||
|
|
||||||
|
def rerank_hybrid(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
combined_results = self.merge_results(vector_results, fts_results)
|
||||||
|
docs = combined_results[self.column].to_pylist()
|
||||||
|
|
||||||
|
tokenizer, model = self._model
|
||||||
|
|
||||||
|
# Encode the query
|
||||||
|
query_encoding = tokenizer(query, return_tensors="pt")
|
||||||
|
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
|
||||||
|
scores = []
|
||||||
|
# Get score for each document
|
||||||
|
for document in docs:
|
||||||
|
document_encoding = tokenizer(
|
||||||
|
document, return_tensors="pt", truncation=True, max_length=512
|
||||||
|
)
|
||||||
|
document_embedding = model(**document_encoding).last_hidden_state
|
||||||
|
# Calculate MaxSim score
|
||||||
|
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
|
||||||
|
scores.append(score.item())
|
||||||
|
|
||||||
|
# replace the self.column column with the docs
|
||||||
|
combined_results = combined_results.drop(self.column)
|
||||||
|
combined_results = combined_results.append_column(
|
||||||
|
self.column, pa.array(docs, type=pa.string())
|
||||||
|
)
|
||||||
|
# add the scores
|
||||||
|
combined_results = combined_results.append_column(
|
||||||
|
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||||
|
)
|
||||||
|
if self.score == "relevance":
|
||||||
|
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||||
|
elif self.score == "all":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"OpenAI Reranker does not support score='all' yet"
|
||||||
|
)
|
||||||
|
|
||||||
|
combined_results = combined_results.sort_by(
|
||||||
|
[("_relevance_score", "descending")]
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_results
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _model(self):
|
||||||
|
transformers = safe_import("transformers")
|
||||||
|
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
model = transformers.AutoModel.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
return tokenizer, model
|
||||||
|
|
||||||
|
def maxsim(self, query_embedding, document_embedding):
|
||||||
|
# Expand dimensions for broadcasting
|
||||||
|
# Query: [batch, length, size] -> [batch, query, 1, size]
|
||||||
|
# Document: [batch, length, size] -> [batch, 1, length, size]
|
||||||
|
expanded_query = query_embedding.unsqueeze(2)
|
||||||
|
expanded_doc = document_embedding.unsqueeze(1)
|
||||||
|
|
||||||
|
# Compute cosine similarity across the embedding dimension
|
||||||
|
sim_matrix = self.torch.nn.functional.cosine_similarity(
|
||||||
|
expanded_query, expanded_doc, dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Take the maximum similarity for each query token (across all document tokens)
|
||||||
|
# sim_matrix shape: [batch_size, query_length, doc_length]
|
||||||
|
max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
|
||||||
|
|
||||||
|
# Average these maximum scores across all query tokens
|
||||||
|
avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
|
||||||
|
return avg_max_sim
|
||||||
74
python/lancedb/rerankers/cross_encoder.py
Normal file
74
python/lancedb/rerankers/cross_encoder.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from functools import cached_property
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from ..util import safe_import
|
||||||
|
from .base import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEncoderReranker(Reranker):
|
||||||
|
"""
|
||||||
|
Reranks the results using a cross encoder model. The cross encoder model is
|
||||||
|
used to score the query and each result. The results are then sorted by the score.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model : str, default "cross-encoder/ms-marco-TinyBERT-L-6"
|
||||||
|
The name of the cross encoder model to use. See the sentence transformers
|
||||||
|
documentation for a list of available models.
|
||||||
|
column : str, default "text"
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
device : str, default None
|
||||||
|
The device to use for the cross encoder model. If None, will use "cuda"
|
||||||
|
if available, otherwise "cpu".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "cross-encoder/ms-marco-TinyBERT-L-6",
|
||||||
|
column: str = "text",
|
||||||
|
device: Union[str, None] = None,
|
||||||
|
return_score="relevance",
|
||||||
|
):
|
||||||
|
super().__init__(return_score)
|
||||||
|
torch = safe_import("torch")
|
||||||
|
self.model_name = model_name
|
||||||
|
self.column = column
|
||||||
|
self.device = device
|
||||||
|
if self.device is None:
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def model(self):
|
||||||
|
sbert = safe_import("sentence_transformers")
|
||||||
|
cross_encoder = sbert.CrossEncoder(self.model_name)
|
||||||
|
|
||||||
|
return cross_encoder
|
||||||
|
|
||||||
|
def rerank_hybrid(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
combined_results = self.merge_results(vector_results, fts_results)
|
||||||
|
passages = combined_results[self.column].to_pylist()
|
||||||
|
cross_inp = [[query, passage] for passage in passages]
|
||||||
|
cross_scores = self.model.predict(cross_inp)
|
||||||
|
combined_results = combined_results.append_column(
|
||||||
|
"_relevance_score", pa.array(cross_scores, type=pa.float32())
|
||||||
|
)
|
||||||
|
|
||||||
|
# sort the results by _score
|
||||||
|
if self.score == "relevance":
|
||||||
|
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||||
|
elif self.score == "all":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"return_score='all' not implemented for CrossEncoderReranker"
|
||||||
|
)
|
||||||
|
combined_results = combined_results.sort_by(
|
||||||
|
[("_relevance_score", "descending")]
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_results
|
||||||
117
python/lancedb/rerankers/linear_combination.py
Normal file
117
python/lancedb/rerankers/linear_combination.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from .base import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
class LinearCombinationReranker(Reranker):
|
||||||
|
"""
|
||||||
|
Reranks the results using a linear combination of the scores from the
|
||||||
|
vector and FTS search. For missing scores, fill with `fill` value.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
weight : float, default 0.7
|
||||||
|
The weight to give to the vector score. Must be between 0 and 1.
|
||||||
|
fill : float, default 1.0
|
||||||
|
The score to give to results that are only in one of the two result sets.
|
||||||
|
This is treated as penalty, so a higher value means a lower score.
|
||||||
|
TODO: We should just hardcode this--
|
||||||
|
its pretty confusing as we invert scores to calculate final score
|
||||||
|
return_score : str, default "relevance"
|
||||||
|
opntions are "relevance" or "all"
|
||||||
|
The type of score to return. If "relevance", will return only the relevance
|
||||||
|
score. If "all", will return all scores from the vector and FTS search along
|
||||||
|
with the relevance score.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, weight: float = 0.7, fill: float = 1.0, return_score="relevance"
|
||||||
|
):
|
||||||
|
if weight < 0 or weight > 1:
|
||||||
|
raise ValueError("weight must be between 0 and 1.")
|
||||||
|
super().__init__(return_score)
|
||||||
|
self.weight = weight
|
||||||
|
self.fill = fill
|
||||||
|
|
||||||
|
def rerank_hybrid(
|
||||||
|
self,
|
||||||
|
query: str, # noqa: F821
|
||||||
|
vector_results: pa.Table,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
combined_results = self.merge_results(vector_results, fts_results, self.fill)
|
||||||
|
|
||||||
|
return combined_results
|
||||||
|
|
||||||
|
def merge_results(
|
||||||
|
self, vector_results: pa.Table, fts_results: pa.Table, fill: float
|
||||||
|
):
|
||||||
|
# If both are empty then just return an empty table
|
||||||
|
if len(vector_results) == 0 and len(fts_results) == 0:
|
||||||
|
return vector_results
|
||||||
|
# If one is empty then return the other
|
||||||
|
if len(vector_results) == 0:
|
||||||
|
return fts_results
|
||||||
|
if len(fts_results) == 0:
|
||||||
|
return vector_results
|
||||||
|
|
||||||
|
# sort both input tables on _rowid
|
||||||
|
combined_list = []
|
||||||
|
vector_list = vector_results.sort_by("_rowid").to_pylist()
|
||||||
|
fts_list = fts_results.sort_by("_rowid").to_pylist()
|
||||||
|
i, j = 0, 0
|
||||||
|
while i < len(vector_list):
|
||||||
|
if j >= len(fts_list):
|
||||||
|
for vi in vector_list[i:]:
|
||||||
|
vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
|
||||||
|
combined_list.append(vi)
|
||||||
|
break
|
||||||
|
|
||||||
|
vi = vector_list[i]
|
||||||
|
fj = fts_list[j]
|
||||||
|
# invert the fts score from relevance to distance
|
||||||
|
inverted_fts_score = self._invert_score(fj["score"])
|
||||||
|
if vi["_rowid"] == fj["_rowid"]:
|
||||||
|
vi["_relevance_score"] = self._combine_score(
|
||||||
|
vi["_distance"], inverted_fts_score
|
||||||
|
)
|
||||||
|
vi["score"] = fj["score"] # keep the original score
|
||||||
|
combined_list.append(vi)
|
||||||
|
i += 1
|
||||||
|
j += 1
|
||||||
|
elif vector_list[i]["_rowid"] < fts_list[j]["_rowid"]:
|
||||||
|
vi["_relevance_score"] = self._combine_score(vi["_distance"], fill)
|
||||||
|
combined_list.append(vi)
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
|
||||||
|
combined_list.append(fj)
|
||||||
|
j += 1
|
||||||
|
if j < len(fts_list) - 1:
|
||||||
|
for fj in fts_list[j:]:
|
||||||
|
fj["_relevance_score"] = self._combine_score(inverted_fts_score, fill)
|
||||||
|
combined_list.append(fj)
|
||||||
|
|
||||||
|
relevance_score_schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("_relevance_score", pa.float32()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
combined_schema = pa.unify_schemas(
|
||||||
|
[vector_results.schema, fts_results.schema, relevance_score_schema]
|
||||||
|
)
|
||||||
|
tbl = pa.Table.from_pylist(combined_list, schema=combined_schema).sort_by(
|
||||||
|
[("_relevance_score", "descending")]
|
||||||
|
)
|
||||||
|
if self.score == "relevance":
|
||||||
|
tbl = tbl.drop_columns(["score", "_distance"])
|
||||||
|
return tbl
|
||||||
|
|
||||||
|
def _combine_score(self, score1, score2):
|
||||||
|
# these scores represent distance
|
||||||
|
return 1 - (self.weight * score1 + (1 - self.weight) * score2)
|
||||||
|
|
||||||
|
def _invert_score(self, scores: List[float]):
|
||||||
|
# Invert the scores between relevance and distance
|
||||||
|
return 1 - scores
|
||||||
102
python/lancedb/rerankers/openai.py
Normal file
102
python/lancedb/rerankers/openai.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from ..util import safe_import
|
||||||
|
from .base import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
class OpenaiReranker(Reranker):
|
||||||
|
"""
|
||||||
|
Reranks the results using the OpenAI API.
|
||||||
|
WARNING: This is a prompt based reranker that uses chat model that is
|
||||||
|
not a dedicated reranker API. This should be treated as experimental.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_name : str, default "gpt-3.5-turbo-1106 "
|
||||||
|
The name of the cross encoder model to use.
|
||||||
|
column : str, default "text"
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
return_score : str, default "relevance"
|
||||||
|
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||||
|
api_key : str, default None
|
||||||
|
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "gpt-3.5-turbo-1106",
|
||||||
|
column: str = "text",
|
||||||
|
return_score="relevance",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(return_score)
|
||||||
|
self.model_name = model_name
|
||||||
|
self.column = column
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
def rerank_hybrid(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
combined_results = self.merge_results(vector_results, fts_results)
|
||||||
|
docs = combined_results[self.column].to_pylist()
|
||||||
|
response = self._client.chat.completions.create(
|
||||||
|
model=self.model_name,
|
||||||
|
response_format={"type": "json_object"},
|
||||||
|
temperature=0,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are an expert relevance ranker. Given a list of\
|
||||||
|
documents and a query, your job is to determine the relevance\
|
||||||
|
each document is for answering the query. Your output is JSON,\
|
||||||
|
which is a list of documents. Each document has two fields,\
|
||||||
|
content and relevance_score. relevance_score is from 0.0 to\
|
||||||
|
1.0 indicating the relevance of the text to the given query.\
|
||||||
|
Make sure to include all documents in the response.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": f"Query: {query} Docs: {docs}"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
results = json.loads(response.choices[0].message.content)["documents"]
|
||||||
|
docs, scores = list(
|
||||||
|
zip(*[(result["content"], result["relevance_score"]) for result in results])
|
||||||
|
) # tuples
|
||||||
|
# replace the self.column column with the docs
|
||||||
|
combined_results = combined_results.drop(self.column)
|
||||||
|
combined_results = combined_results.append_column(
|
||||||
|
self.column, pa.array(docs, type=pa.string())
|
||||||
|
)
|
||||||
|
# add the scores
|
||||||
|
combined_results = combined_results.append_column(
|
||||||
|
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||||
|
)
|
||||||
|
if self.score == "relevance":
|
||||||
|
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||||
|
elif self.score == "all":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"OpenAI Reranker does not support score='all' yet"
|
||||||
|
)
|
||||||
|
|
||||||
|
combined_results = combined_results.sort_by(
|
||||||
|
[("_relevance_score", "descending")]
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_results
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _client(self):
|
||||||
|
openai = safe_import("openai") # TODO: force version or handle versions < 1.0
|
||||||
|
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"OPENAI_API_KEY not set. Either set it in your environment or \
|
||||||
|
pass it as `api_key` argument to the CohereReranker."
|
||||||
|
)
|
||||||
|
return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)
|
||||||
@@ -14,9 +14,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import timedelta
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import lance
|
import lance
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -28,6 +31,7 @@ from lance.vector import vec_to_table
|
|||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
|
from .merge import LanceMergeInsertBuilder
|
||||||
from .pydantic import LanceModel, model_to_dict
|
from .pydantic import LanceModel, model_to_dict
|
||||||
from .query import LanceQueryBuilder, Query
|
from .query import LanceQueryBuilder, Query
|
||||||
from .util import (
|
from .util import (
|
||||||
@@ -40,8 +44,6 @@ from .util import (
|
|||||||
from .utils.events import register_event
|
from .utils.events import register_event
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
import PIL
|
import PIL
|
||||||
from lance.dataset import CleanupStats, ReaderLike
|
from lance.dataset import CleanupStats, ReaderLike
|
||||||
|
|
||||||
@@ -175,6 +177,18 @@ class Table(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||||
|
"""
|
||||||
|
Count the number of rows in the table.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filter: str, optional
|
||||||
|
A SQL where clause to filter the rows to count.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def to_pandas(self) -> "pd.DataFrame":
|
def to_pandas(self) -> "pd.DataFrame":
|
||||||
"""Return the table as a pandas DataFrame.
|
"""Return the table as a pandas DataFrame.
|
||||||
|
|
||||||
@@ -298,7 +312,7 @@ class Table(ABC):
|
|||||||
|
|
||||||
import lance
|
import lance
|
||||||
|
|
||||||
dataset = lance.dataset("/tmp/images.lance")
|
dataset = lance.dataset("./images.lance")
|
||||||
dataset.create_scalar_index("category")
|
dataset.create_scalar_index("category")
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -335,10 +349,70 @@ class Table(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||||
|
"""
|
||||||
|
Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||||
|
that can be used to create a "merge insert" operation
|
||||||
|
|
||||||
|
This operation can add rows, update rows, and remove rows all in a single
|
||||||
|
transaction. It is a very generic tool that can be used to create
|
||||||
|
behaviors like "insert if not exists", "update or insert (i.e. upsert)",
|
||||||
|
or even replace a portion of existing data with new data (e.g. replace
|
||||||
|
all data where month="january")
|
||||||
|
|
||||||
|
The merge insert operation works by combining new data from a
|
||||||
|
**source table** with existing data in a **target table** by using a
|
||||||
|
join. There are three categories of records.
|
||||||
|
|
||||||
|
"Matched" records are records that exist in both the source table and
|
||||||
|
the target table. "Not matched" records exist only in the source table
|
||||||
|
(e.g. these are new data) "Not matched by source" records exist only
|
||||||
|
in the target table (this is old data)
|
||||||
|
|
||||||
|
The builder returned by this method can be used to customize what
|
||||||
|
should happen for each category of data.
|
||||||
|
|
||||||
|
Please note that the data may appear to be reordered as part of this
|
||||||
|
operation. This is because updated rows will be deleted from the
|
||||||
|
dataset and then reinserted at the end with the new values.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
|
||||||
|
on: Union[str, Iterable[str]]
|
||||||
|
A column (or columns) to join on. This is how records from the
|
||||||
|
source table and target table are matched. Typically this is some
|
||||||
|
kind of key or id column.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> data = pa.table({"a": [2, 1, 3], "b": ["a", "b", "c"]})
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> table = db.create_table("my_table", data)
|
||||||
|
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||||
|
>>> # Perform a "upsert" operation
|
||||||
|
>>> table.merge_insert("a") \\
|
||||||
|
... .when_matched_update_all() \\
|
||||||
|
... .when_not_matched_insert_all() \\
|
||||||
|
... .execute(new_data)
|
||||||
|
>>> # The order of new rows is non-deterministic since we use
|
||||||
|
>>> # a hash-join as part of this operation and so we sort here
|
||||||
|
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||||
|
a b
|
||||||
|
0 1 b
|
||||||
|
1 2 x
|
||||||
|
2 3 y
|
||||||
|
3 4 z
|
||||||
|
"""
|
||||||
|
on = [on] if isinstance(on, str) else list(on.iter())
|
||||||
|
|
||||||
|
return LanceMergeInsertBuilder(self, on)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
@@ -380,10 +454,12 @@ class Table(ABC):
|
|||||||
the table
|
the table
|
||||||
vector_column_name: str
|
vector_column_name: str
|
||||||
The name of the vector column to search.
|
The name of the vector column to search.
|
||||||
|
|
||||||
|
The vector column needs to be a pyarrow fixed size list type
|
||||||
*default "vector"*
|
*default "vector"*
|
||||||
query_type: str
|
query_type: str
|
||||||
*default "auto"*.
|
*default "auto"*.
|
||||||
Acceptable types are: "vector", "fts", or "auto"
|
Acceptable types are: "vector", "fts", "hybrid", or "auto"
|
||||||
|
|
||||||
- If "auto" then the query type is inferred from the query;
|
- If "auto" then the query type is inferred from the query;
|
||||||
|
|
||||||
@@ -415,6 +491,16 @@ class Table(ABC):
|
|||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _do_merge(
|
||||||
|
self,
|
||||||
|
merge: LanceMergeInsertBuilder,
|
||||||
|
new_data: DATA,
|
||||||
|
on_bad_vectors: str,
|
||||||
|
fill_value: float,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, where: str):
|
def delete(self, where: str):
|
||||||
"""Delete rows from the table.
|
"""Delete rows from the table.
|
||||||
@@ -522,24 +608,192 @@ class Table(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cleanup_old_versions(
|
||||||
|
self,
|
||||||
|
older_than: Optional[timedelta] = None,
|
||||||
|
*,
|
||||||
|
delete_unverified: bool = False,
|
||||||
|
) -> CleanupStats:
|
||||||
|
"""
|
||||||
|
Clean up old versions of the table, freeing disk space.
|
||||||
|
|
||||||
|
Note: This function is not available in LanceDb Cloud (since LanceDb
|
||||||
|
Cloud manages cleanup for you automatically)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
older_than: timedelta, default None
|
||||||
|
The minimum age of the version to delete. If None, then this defaults
|
||||||
|
to two weeks.
|
||||||
|
delete_unverified: bool, default False
|
||||||
|
Because they may be part of an in-progress transaction, files newer
|
||||||
|
than 7 days old are not deleted by default. If you are sure that
|
||||||
|
there are no in-progress transactions, then you can set this to True
|
||||||
|
to delete all files older than `older_than`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
CleanupStats
|
||||||
|
The stats of the cleanup operation, including how many bytes were
|
||||||
|
freed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compact_files(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run the compaction process on the table.
|
||||||
|
|
||||||
|
Note: This function is not available in LanceDb Cloud (since LanceDb
|
||||||
|
Cloud manages compaction for you automatically)
|
||||||
|
|
||||||
|
This can be run after making several small appends to optimize the table
|
||||||
|
for faster reads.
|
||||||
|
|
||||||
|
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
|
||||||
|
For most cases, the default should be fine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class _LanceDatasetRef(ABC):
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def dataset(self) -> LanceDataset:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def dataset_mut(self) -> LanceDataset:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _LanceLatestDatasetRef(_LanceDatasetRef):
|
||||||
|
"""Reference to the latest version of a LanceDataset."""
|
||||||
|
|
||||||
|
uri: str
|
||||||
|
read_consistency_interval: Optional[timedelta] = None
|
||||||
|
last_consistency_check: Optional[float] = None
|
||||||
|
_dataset: Optional[LanceDataset] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset(self) -> LanceDataset:
|
||||||
|
if not self._dataset:
|
||||||
|
self._dataset = lance.dataset(self.uri)
|
||||||
|
self.last_consistency_check = time.monotonic()
|
||||||
|
elif self.read_consistency_interval is not None:
|
||||||
|
now = time.monotonic()
|
||||||
|
diff = timedelta(seconds=now - self.last_consistency_check)
|
||||||
|
if (
|
||||||
|
self.last_consistency_check is None
|
||||||
|
or diff > self.read_consistency_interval
|
||||||
|
):
|
||||||
|
self._dataset = self._dataset.checkout_version(
|
||||||
|
self._dataset.latest_version
|
||||||
|
)
|
||||||
|
self.last_consistency_check = time.monotonic()
|
||||||
|
return self._dataset
|
||||||
|
|
||||||
|
@dataset.setter
|
||||||
|
def dataset(self, value: LanceDataset):
|
||||||
|
self._dataset = value
|
||||||
|
self.last_consistency_check = time.monotonic()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset_mut(self) -> LanceDataset:
|
||||||
|
return self.dataset
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _LanceTimeTravelRef(_LanceDatasetRef):
|
||||||
|
uri: str
|
||||||
|
version: int
|
||||||
|
_dataset: Optional[LanceDataset] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset(self) -> LanceDataset:
|
||||||
|
if not self._dataset:
|
||||||
|
self._dataset = lance.dataset(self.uri, version=self.version)
|
||||||
|
return self._dataset
|
||||||
|
|
||||||
|
@dataset.setter
|
||||||
|
def dataset(self, value: LanceDataset):
|
||||||
|
self._dataset = value
|
||||||
|
self.version = value.version
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dataset_mut(self) -> LanceDataset:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot mutate table reference fixed at version "
|
||||||
|
f"{self.version}. Call checkout_latest() to get a mutable "
|
||||||
|
"table reference."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LanceTable(Table):
|
class LanceTable(Table):
|
||||||
"""
|
"""
|
||||||
A table in a LanceDB database.
|
A table in a LanceDB database.
|
||||||
|
|
||||||
|
This can be opened in two modes: standard and time-travel.
|
||||||
|
|
||||||
|
Standard mode is the default. In this mode, the table is mutable and tracks
|
||||||
|
the latest version of the table. The level of read consistency is controlled
|
||||||
|
by the `read_consistency_interval` parameter on the connection.
|
||||||
|
|
||||||
|
Time-travel mode is activated by specifying a version number. In this mode,
|
||||||
|
the table is immutable and fixed to a specific version. This is useful for
|
||||||
|
querying historical versions of the table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection: "LanceDBConnection",
|
||||||
|
name: str,
|
||||||
|
version: Optional[int] = None,
|
||||||
|
):
|
||||||
self._conn = connection
|
self._conn = connection
|
||||||
self.name = name
|
self.name = name
|
||||||
self._version = version
|
|
||||||
|
|
||||||
def _reset_dataset(self, version=None):
|
if version is not None:
|
||||||
try:
|
self._ref = _LanceTimeTravelRef(
|
||||||
if "_dataset" in self.__dict__:
|
uri=self._dataset_uri,
|
||||||
del self.__dict__["_dataset"]
|
version=version,
|
||||||
self._version = version
|
)
|
||||||
except AttributeError:
|
else:
|
||||||
pass
|
self._ref = _LanceLatestDatasetRef(
|
||||||
|
uri=self._dataset_uri,
|
||||||
|
read_consistency_interval=connection.read_consistency_interval,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def open(cls, db, name, **kwargs):
|
||||||
|
tbl = cls(db, name, **kwargs)
|
||||||
|
fs, path = fs_from_uri(tbl._dataset_uri)
|
||||||
|
file_info = fs.get_file_info(path)
|
||||||
|
if file_info.type != pa.fs.FileType.Directory:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Table {name} does not exist."
|
||||||
|
f"Please first call db.create_table({name}, data)"
|
||||||
|
)
|
||||||
|
register_event("open_table")
|
||||||
|
|
||||||
|
return tbl
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _dataset_uri(self) -> str:
|
||||||
|
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _dataset(self) -> LanceDataset:
|
||||||
|
return self._ref.dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _dataset_mut(self) -> LanceDataset:
|
||||||
|
return self._ref.dataset_mut
|
||||||
|
|
||||||
|
def to_lance(self) -> LanceDataset:
|
||||||
|
"""Return the LanceDataset backing this table."""
|
||||||
|
return self._dataset
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def schema(self) -> pa.Schema:
|
def schema(self) -> pa.Schema:
|
||||||
@@ -567,6 +821,9 @@ class LanceTable(Table):
|
|||||||
keep writing to the dataset starting from an old version, then use
|
keep writing to the dataset starting from an old version, then use
|
||||||
the `restore` function.
|
the `restore` function.
|
||||||
|
|
||||||
|
Calling this method will set the table into time-travel mode. If you
|
||||||
|
wish to return to standard mode, call `checkout_latest`.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
version : int
|
version : int
|
||||||
@@ -591,15 +848,13 @@ class LanceTable(Table):
|
|||||||
vector type
|
vector type
|
||||||
0 [1.1, 0.9] vector
|
0 [1.1, 0.9] vector
|
||||||
"""
|
"""
|
||||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
max_ver = self._dataset.latest_version
|
||||||
if version < 1 or version > max_ver:
|
if version < 1 or version > max_ver:
|
||||||
raise ValueError(f"Invalid version {version}")
|
raise ValueError(f"Invalid version {version}")
|
||||||
self._reset_dataset(version=version)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Accessing the property updates the cached value
|
ds = self._dataset.checkout_version(version)
|
||||||
_ = self._dataset
|
except IOError as e:
|
||||||
except Exception as e:
|
|
||||||
if "not found" in str(e):
|
if "not found" in str(e):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Version {version} no longer exists. Was it cleaned up?"
|
f"Version {version} no longer exists. Was it cleaned up?"
|
||||||
@@ -607,6 +862,27 @@ class LanceTable(Table):
|
|||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
self._ref = _LanceTimeTravelRef(
|
||||||
|
uri=self._dataset_uri,
|
||||||
|
version=version,
|
||||||
|
)
|
||||||
|
# We've already loaded the version so we can populate it directly.
|
||||||
|
self._ref.dataset = ds
|
||||||
|
|
||||||
|
def checkout_latest(self):
|
||||||
|
"""Checkout the latest version of the table. This is an in-place operation.
|
||||||
|
|
||||||
|
The table will be set back into standard mode, and will track the latest
|
||||||
|
version of the table.
|
||||||
|
"""
|
||||||
|
self.checkout(self._dataset.latest_version)
|
||||||
|
ds = self._ref.dataset
|
||||||
|
self._ref = _LanceLatestDatasetRef(
|
||||||
|
uri=self._dataset_uri,
|
||||||
|
read_consistency_interval=self._conn.read_consistency_interval,
|
||||||
|
)
|
||||||
|
self._ref.dataset = ds
|
||||||
|
|
||||||
def restore(self, version: int = None):
|
def restore(self, version: int = None):
|
||||||
"""Restore a version of the table. This is an in-place operation.
|
"""Restore a version of the table. This is an in-place operation.
|
||||||
|
|
||||||
@@ -641,7 +917,7 @@ class LanceTable(Table):
|
|||||||
>>> len(table.list_versions())
|
>>> len(table.list_versions())
|
||||||
4
|
4
|
||||||
"""
|
"""
|
||||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
max_ver = self._dataset.latest_version
|
||||||
if version is None:
|
if version is None:
|
||||||
version = self.version
|
version = self.version
|
||||||
elif version < 1 or version > max_ver:
|
elif version < 1 or version > max_ver:
|
||||||
@@ -649,29 +925,30 @@ class LanceTable(Table):
|
|||||||
else:
|
else:
|
||||||
self.checkout(version)
|
self.checkout(version)
|
||||||
|
|
||||||
if version == max_ver:
|
ds = self._dataset
|
||||||
# no-op if restoring the latest version
|
|
||||||
return
|
|
||||||
|
|
||||||
self._dataset.restore()
|
# no-op if restoring the latest version
|
||||||
self._reset_dataset()
|
if version != max_ver:
|
||||||
|
ds.restore()
|
||||||
|
|
||||||
|
self._ref = _LanceLatestDatasetRef(
|
||||||
|
uri=self._dataset_uri,
|
||||||
|
read_consistency_interval=self._conn.read_consistency_interval,
|
||||||
|
)
|
||||||
|
self._ref.dataset = ds
|
||||||
|
|
||||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||||
"""
|
|
||||||
Count the number of rows in the table.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
filter: str, optional
|
|
||||||
A SQL where clause to filter the rows to count.
|
|
||||||
"""
|
|
||||||
return self._dataset.count_rows(filter)
|
return self._dataset.count_rows(filter)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.count_rows()
|
return self.count_rows()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"LanceTable({self.name})"
|
val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
|
||||||
|
if isinstance(self._ref, _LanceTimeTravelRef):
|
||||||
|
val += f", version={self._ref.version}"
|
||||||
|
val += ")"
|
||||||
|
return val
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
@@ -721,10 +998,6 @@ class LanceTable(Table):
|
|||||||
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
|
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
|
||||||
def _dataset_uri(self) -> str:
|
|
||||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
|
||||||
|
|
||||||
def create_index(
|
def create_index(
|
||||||
self,
|
self,
|
||||||
metric="L2",
|
metric="L2",
|
||||||
@@ -736,7 +1009,7 @@ class LanceTable(Table):
|
|||||||
index_cache_size: Optional[int] = None,
|
index_cache_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""Create an index on the table."""
|
"""Create an index on the table."""
|
||||||
self._dataset.create_index(
|
self._dataset_mut.create_index(
|
||||||
column=vector_column_name,
|
column=vector_column_name,
|
||||||
index_type="IVF_PQ",
|
index_type="IVF_PQ",
|
||||||
metric=metric,
|
metric=metric,
|
||||||
@@ -746,11 +1019,12 @@ class LanceTable(Table):
|
|||||||
accelerator=accelerator,
|
accelerator=accelerator,
|
||||||
index_cache_size=index_cache_size,
|
index_cache_size=index_cache_size,
|
||||||
)
|
)
|
||||||
self._reset_dataset()
|
|
||||||
register_event("create_index")
|
register_event("create_index")
|
||||||
|
|
||||||
def create_scalar_index(self, column: str, *, replace: bool = True):
|
def create_scalar_index(self, column: str, *, replace: bool = True):
|
||||||
self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
|
self._dataset_mut.create_scalar_index(
|
||||||
|
column, index_type="BTREE", replace=replace
|
||||||
|
)
|
||||||
|
|
||||||
def create_fts_index(
|
def create_fts_index(
|
||||||
self,
|
self,
|
||||||
@@ -793,14 +1067,6 @@ class LanceTable(Table):
|
|||||||
def _get_fts_index_path(self):
|
def _get_fts_index_path(self):
|
||||||
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def _dataset(self) -> LanceDataset:
|
|
||||||
return lance.dataset(self._dataset_uri, version=self._version)
|
|
||||||
|
|
||||||
def to_lance(self) -> LanceDataset:
|
|
||||||
"""Return the LanceDataset backing this table."""
|
|
||||||
return self._dataset
|
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
data: DATA,
|
data: DATA,
|
||||||
@@ -839,8 +1105,11 @@ class LanceTable(Table):
|
|||||||
on_bad_vectors=on_bad_vectors,
|
on_bad_vectors=on_bad_vectors,
|
||||||
fill_value=fill_value,
|
fill_value=fill_value,
|
||||||
)
|
)
|
||||||
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
# Access the dataset_mut property to ensure that the dataset is mutable.
|
||||||
self._reset_dataset()
|
self._ref.dataset_mut
|
||||||
|
self._ref.dataset = lance.write_dataset(
|
||||||
|
data, self._dataset_uri, schema=self.schema, mode=mode
|
||||||
|
)
|
||||||
register_event("add")
|
register_event("add")
|
||||||
|
|
||||||
def merge(
|
def merge(
|
||||||
@@ -901,10 +1170,9 @@ class LanceTable(Table):
|
|||||||
other_table = other_table.to_lance()
|
other_table = other_table.to_lance()
|
||||||
if isinstance(other_table, LanceDataset):
|
if isinstance(other_table, LanceDataset):
|
||||||
other_table = other_table.to_table()
|
other_table = other_table.to_table()
|
||||||
self._dataset.merge(
|
self._ref.dataset = self._dataset_mut.merge(
|
||||||
other_table, left_on=left_on, right_on=right_on, schema=schema
|
other_table, left_on=left_on, right_on=right_on, schema=schema
|
||||||
)
|
)
|
||||||
self._reset_dataset()
|
|
||||||
register_event("merge")
|
register_event("merge")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -924,7 +1192,7 @@ class LanceTable(Table):
|
|||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
@@ -1107,22 +1375,8 @@ class LanceTable(Table):
|
|||||||
register_event("create_table")
|
register_event("create_table")
|
||||||
return new_table
|
return new_table
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def open(cls, db, name):
|
|
||||||
tbl = cls(db, name)
|
|
||||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
|
||||||
file_info = fs.get_file_info(path)
|
|
||||||
if file_info.type != pa.fs.FileType.Directory:
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Table {name} does not exist."
|
|
||||||
f"Please first call db.create_table({name}, data)"
|
|
||||||
)
|
|
||||||
register_event("open_table")
|
|
||||||
|
|
||||||
return tbl
|
|
||||||
|
|
||||||
def delete(self, where: str):
|
def delete(self, where: str):
|
||||||
self._dataset.delete(where)
|
self._dataset_mut.delete(where)
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@@ -1176,8 +1430,7 @@ class LanceTable(Table):
|
|||||||
if values is not None:
|
if values is not None:
|
||||||
values_sql = {k: value_to_sql(v) for k, v in values.items()}
|
values_sql = {k: value_to_sql(v) for k, v in values.items()}
|
||||||
|
|
||||||
self.to_lance().update(values_sql, where)
|
self._dataset_mut.update(values_sql, where)
|
||||||
self._reset_dataset()
|
|
||||||
register_event("update")
|
register_event("update")
|
||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
@@ -1194,8 +1447,34 @@ class LanceTable(Table):
|
|||||||
"nprobes": query.nprobes,
|
"nprobes": query.nprobes,
|
||||||
"refine_factor": query.refine_factor,
|
"refine_factor": query.refine_factor,
|
||||||
},
|
},
|
||||||
|
with_row_id=query.with_row_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _do_merge(
|
||||||
|
self,
|
||||||
|
merge: LanceMergeInsertBuilder,
|
||||||
|
new_data: DATA,
|
||||||
|
on_bad_vectors: str,
|
||||||
|
fill_value: float,
|
||||||
|
):
|
||||||
|
new_data = _sanitize_data(
|
||||||
|
new_data,
|
||||||
|
self.schema,
|
||||||
|
metadata=self.schema.metadata,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
fill_value=fill_value,
|
||||||
|
)
|
||||||
|
ds = self.to_lance()
|
||||||
|
builder = ds.merge_insert(merge._on)
|
||||||
|
if merge._when_matched_update_all:
|
||||||
|
builder.when_matched_update_all(merge._when_matched_update_all_condition)
|
||||||
|
if merge._when_not_matched_insert_all:
|
||||||
|
builder.when_not_matched_insert_all()
|
||||||
|
if merge._when_not_matched_by_source_delete:
|
||||||
|
cond = merge._when_not_matched_by_source_condition
|
||||||
|
builder.when_not_matched_by_source_delete(cond)
|
||||||
|
builder.execute(new_data)
|
||||||
|
|
||||||
def cleanup_old_versions(
|
def cleanup_old_versions(
|
||||||
self,
|
self,
|
||||||
older_than: Optional[timedelta] = None,
|
older_than: Optional[timedelta] = None,
|
||||||
@@ -1233,8 +1512,9 @@ class LanceTable(Table):
|
|||||||
This can be run after making several small appends to optimize the table
|
This can be run after making several small appends to optimize the table
|
||||||
for faster reads.
|
for faster reads.
|
||||||
|
|
||||||
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
|
Arguments are passed onto `lance.dataset.DatasetOptimizer.compact_files`.
|
||||||
For most cases, the default should be fine.
|
(see Lance documentation for more details) For most cases, the default
|
||||||
|
should be fine.
|
||||||
"""
|
"""
|
||||||
return self.to_lance().optimize.compact_files(*args, **kwargs)
|
return self.to_lance().optimize.compact_files(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import importlib
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
@@ -114,6 +115,25 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
|
|||||||
return "/".join([p.rstrip("/") for p in [base, *parts]])
|
return "/".join([p.rstrip("/") for p in [base, *parts]])
|
||||||
|
|
||||||
|
|
||||||
|
def safe_import(module: str, mitigation=None):
|
||||||
|
"""
|
||||||
|
Import the specified module. If the module is not installed,
|
||||||
|
raise an ImportError with a helpful message.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
module : str
|
||||||
|
The name of the module to import
|
||||||
|
mitigation : Optional[str]
|
||||||
|
The package(s) to install to mitigate the error.
|
||||||
|
If not provided then the module name will be used.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return importlib.import_module(module)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(f"Please install {mitigation or module}")
|
||||||
|
|
||||||
|
|
||||||
def safe_import_pandas():
|
def safe_import_pandas():
|
||||||
try:
|
try:
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.5.1"
|
version = "0.5.4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"deprecation",
|
"deprecation",
|
||||||
"pylance==0.9.9",
|
"pylance==0.9.15",
|
||||||
"ratelimiter~=1.0",
|
"ratelimiter~=1.0",
|
||||||
"retry>=0.9.2",
|
"retry>=0.9.2",
|
||||||
"tqdm>=4.27.0",
|
"tqdm>=4.27.0",
|
||||||
@@ -48,11 +48,12 @@ classifiers = [
|
|||||||
repository = "https://github.com/lancedb/lancedb"
|
repository = "https://github.com/lancedb/lancedb"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
|
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
|
||||||
dev = ["ruff", "pre-commit"]
|
dev = ["ruff", "pre-commit"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
clip = ["torch", "pillow", "open-clip"]
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
|
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
|
||||||
|
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
lancedb = "lancedb.cli.cli:cli"
|
lancedb = "lancedb.cli.cli:cli"
|
||||||
@@ -65,7 +66,8 @@ build-backend = "setuptools.build_meta"
|
|||||||
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "--strict-markers"
|
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
|
||||||
|
|
||||||
markers = [
|
markers = [
|
||||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
"asyncio"
|
"asyncio"
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
|
|||||||
assert np.allclose(actual, expected)
|
assert np.allclose(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
def test_embedding_function_rate_limit(tmp_path):
|
def test_embedding_function_rate_limit(tmp_path):
|
||||||
def _get_schema_from_model(model):
|
def _get_schema_from_model(model):
|
||||||
class Schema(LanceModel):
|
class Schema(LanceModel):
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import importlib
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -202,3 +203,114 @@ def test_gemini_embedding(tmp_path):
|
|||||||
tbl.add(df)
|
tbl.add(df)
|
||||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
if importlib.util.find_spec("mlx.core") is not None:
|
||||||
|
_mlx = True
|
||||||
|
except ImportError:
|
||||||
|
_mlx = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
_mlx is None,
|
||||||
|
reason="mlx tests only required for apple users.",
|
||||||
|
)
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_gte_embedding(tmp_path):
|
||||||
|
import lancedb.embeddings.gte
|
||||||
|
|
||||||
|
model = get_registry().get("gte-text").create()
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||||
|
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def aws_setup():
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
sts = boto3.client("sts")
|
||||||
|
sts.get_caller_identity()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not aws_setup(), reason="AWS credentials not set or libraries not installed"
|
||||||
|
)
|
||||||
|
def test_bedrock_embedding(tmp_path):
|
||||||
|
for name in [
|
||||||
|
"amazon.titan-embed-text-v1",
|
||||||
|
"cohere.embed-english-v3",
|
||||||
|
"cohere.embed-multilingual-v3",
|
||||||
|
]:
|
||||||
|
model = get_registry().get("bedrock-text").create(max_retries=0, name=name)
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||||
|
)
|
||||||
|
def test_openai_embedding(tmp_path):
|
||||||
|
def _get_table(model):
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
return tbl
|
||||||
|
|
||||||
|
model = get_registry().get("openai").create(max_retries=0)
|
||||||
|
tbl = _get_table(model)
|
||||||
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||||
|
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||||
|
|
||||||
|
model = (
|
||||||
|
get_registry()
|
||||||
|
.get("openai")
|
||||||
|
.create(max_retries=0, name="text-embedding-3-large")
|
||||||
|
)
|
||||||
|
tbl = _get_table(model)
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||||
|
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||||
|
|
||||||
|
model = (
|
||||||
|
get_registry()
|
||||||
|
.get("openai")
|
||||||
|
.create(max_retries=0, name="text-embedding-3-large", dim=1024)
|
||||||
|
)
|
||||||
|
tbl = _get_table(model)
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||||
|
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ class FakeLanceDBClient:
|
|||||||
def post(self, path: str):
|
def post(self, path: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def mount_retry_adapter_for_table(self, table_name: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_remote_db():
|
def test_remote_db():
|
||||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||||
|
|||||||
259
python/tests/test_rerankers.py
Normal file
259
python/tests/test_rerankers.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||||
|
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
from lancedb.rerankers import (
|
||||||
|
CohereReranker,
|
||||||
|
ColbertReranker,
|
||||||
|
CrossEncoderReranker,
|
||||||
|
OpenaiReranker,
|
||||||
|
)
|
||||||
|
from lancedb.table import LanceTable
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_table(tmp_path):
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
# Create a LanceDB table schema with a vector and a text column
|
||||||
|
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||||
|
|
||||||
|
class MyTable(LanceModel):
|
||||||
|
text: str = emb.SourceField()
|
||||||
|
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||||
|
|
||||||
|
# Initialize the table using the schema
|
||||||
|
table = LanceTable.create(
|
||||||
|
db,
|
||||||
|
"my_table",
|
||||||
|
schema=MyTable,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Need to test with a bunch of phrases to make sure sorting is consistent
|
||||||
|
phrases = [
|
||||||
|
"great kid don't get cocky",
|
||||||
|
"now that's a name I haven't heard in a long time",
|
||||||
|
"if you strike me down I shall become more powerful than you imagine",
|
||||||
|
"I find your lack of faith disturbing",
|
||||||
|
"I've got a bad feeling about this",
|
||||||
|
"never tell me the odds",
|
||||||
|
"I am your father",
|
||||||
|
"somebody has to save our skins",
|
||||||
|
"New strategy R2 let the wookiee win",
|
||||||
|
"Arrrrggghhhhhhh",
|
||||||
|
"I see a mansard roof through the trees",
|
||||||
|
"I see a salty message written in the eves",
|
||||||
|
"the ground beneath my feet",
|
||||||
|
"the hot garbage and concrete",
|
||||||
|
"and now the tops of buildings",
|
||||||
|
"everybody with a worried mind could never forgive the sight",
|
||||||
|
"of wicked snakes inside a place you thought was dignified",
|
||||||
|
"I don't wanna live like this",
|
||||||
|
"but I don't wanna die",
|
||||||
|
"The templars want control",
|
||||||
|
"the brotherhood of assassins want freedom",
|
||||||
|
"if only they could both see the world as it really is",
|
||||||
|
"there would be peace",
|
||||||
|
"but the war goes on",
|
||||||
|
"altair's legacy was a warning",
|
||||||
|
"Kratos had a son",
|
||||||
|
"he was a god",
|
||||||
|
"the god of war",
|
||||||
|
"but his son was mortal",
|
||||||
|
"there hasn't been a good battlefield game since 2142",
|
||||||
|
"I wish they would make another one",
|
||||||
|
"campains are not as good as they used to be",
|
||||||
|
"Multiplayer and open world games have destroyed the single player experience",
|
||||||
|
"Maybe the future is console games",
|
||||||
|
"I don't know",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add the phrases and vectors to the table
|
||||||
|
table.add([{"text": p} for p in phrases])
|
||||||
|
|
||||||
|
# Create a fts index
|
||||||
|
table.create_fts_index("text")
|
||||||
|
|
||||||
|
return table, MyTable
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_combination(tmp_path):
|
||||||
|
table, schema = get_test_table(tmp_path)
|
||||||
|
# The default reranker
|
||||||
|
result1 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(normalize="score")
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
result2 = ( # noqa
|
||||||
|
table.search("Our father who art in heaven.", query_type="hybrid")
|
||||||
|
.rerank(normalize="rank")
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
result3 = table.search(
|
||||||
|
"Our father who art in heaven..", query_type="hybrid"
|
||||||
|
).to_pydantic(schema)
|
||||||
|
|
||||||
|
assert result1 == result3 # 2 & 3 should be the same as they use score as score
|
||||||
|
|
||||||
|
query = "Our father who art in heaven"
|
||||||
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
|
result = (
|
||||||
|
table.search((query_vector, query))
|
||||||
|
.limit(30)
|
||||||
|
.rerank(normalize="score")
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 30
|
||||||
|
|
||||||
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||||
|
"The _relevance_score column of the results returned by the reranker "
|
||||||
|
"represents the relevance of the result to the query & should "
|
||||||
|
"be descending."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||||
|
)
|
||||||
|
def test_cohere_reranker(tmp_path):
|
||||||
|
pytest.importorskip("cohere")
|
||||||
|
table, schema = get_test_table(tmp_path)
|
||||||
|
# The default reranker
|
||||||
|
result1 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(normalize="score", reranker=CohereReranker())
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
result2 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(reranker=CohereReranker())
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
query = "Our father who art in heaven"
|
||||||
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
|
result = (
|
||||||
|
table.search((query_vector, query))
|
||||||
|
.limit(30)
|
||||||
|
.rerank(reranker=CohereReranker())
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 30
|
||||||
|
|
||||||
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||||
|
"The _relevance_score column of the results returned by the reranker "
|
||||||
|
"represents the relevance of the result to the query & should "
|
||||||
|
"be descending."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cross_encoder_reranker(tmp_path):
|
||||||
|
pytest.importorskip("sentence_transformers")
|
||||||
|
table, schema = get_test_table(tmp_path)
|
||||||
|
result1 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(normalize="score", reranker=CrossEncoderReranker())
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
result2 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(reranker=CrossEncoderReranker())
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
# test explicit hybrid query
|
||||||
|
query = "Our father who art in heaven"
|
||||||
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
|
result = (
|
||||||
|
table.search((query_vector, query), query_type="hybrid")
|
||||||
|
.limit(30)
|
||||||
|
.rerank(reranker=CrossEncoderReranker())
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 30
|
||||||
|
|
||||||
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||||
|
"The _relevance_score column of the results returned by the reranker "
|
||||||
|
"represents the relevance of the result to the query & should "
|
||||||
|
"be descending."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_colbert_reranker(tmp_path):
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
table, schema = get_test_table(tmp_path)
|
||||||
|
result1 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(normalize="score", reranker=ColbertReranker())
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
result2 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(reranker=ColbertReranker())
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
# test explicit hybrid query
|
||||||
|
query = "Our father who art in heaven"
|
||||||
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
|
result = (
|
||||||
|
table.search((query_vector, query))
|
||||||
|
.limit(30)
|
||||||
|
.rerank(reranker=ColbertReranker())
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 30
|
||||||
|
|
||||||
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||||
|
"The _relevance_score column of the results returned by the reranker "
|
||||||
|
"represents the relevance of the result to the query & should "
|
||||||
|
"be descending."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||||
|
)
|
||||||
|
def test_openai_reranker(tmp_path):
|
||||||
|
pytest.importorskip("openai")
|
||||||
|
table, schema = get_test_table(tmp_path)
|
||||||
|
result1 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(normalize="score", reranker=OpenaiReranker())
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
result2 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(reranker=OpenaiReranker())
|
||||||
|
.to_pydantic(schema)
|
||||||
|
)
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
# test explicit hybrid query
|
||||||
|
query = "Our father who art in heaven"
|
||||||
|
query_vector = table.to_pandas()["vector"][0]
|
||||||
|
result = (
|
||||||
|
table.search((query_vector, query))
|
||||||
|
.limit(30)
|
||||||
|
.rerank(reranker=OpenaiReranker())
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 30
|
||||||
|
|
||||||
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||||
|
"The _relevance_score column of the results returned by the reranker "
|
||||||
|
"represents the relevance of the result to the query & should "
|
||||||
|
"be descending."
|
||||||
|
)
|
||||||
@@ -12,8 +12,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
from copy import copy
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from time import sleep
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest.mock import PropertyMock, patch
|
from unittest.mock import PropertyMock, patch
|
||||||
|
|
||||||
@@ -25,6 +27,7 @@ import pyarrow as pa
|
|||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import lancedb
|
||||||
from lancedb.conftest import MockTextEmbeddingFunction
|
from lancedb.conftest import MockTextEmbeddingFunction
|
||||||
from lancedb.db import LanceDBConnection
|
from lancedb.db import LanceDBConnection
|
||||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
|
|||||||
class MockDB:
|
class MockDB:
|
||||||
def __init__(self, uri: Path):
|
def __init__(self, uri: Path):
|
||||||
self.uri = uri
|
self.uri = uri
|
||||||
|
self.read_consistency_interval = None
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def is_managed_remote(self) -> bool:
|
def is_managed_remote(self) -> bool:
|
||||||
@@ -267,39 +271,38 @@ def test_versioning(db):
|
|||||||
|
|
||||||
|
|
||||||
def test_create_index_method():
|
def test_create_index_method():
|
||||||
with patch.object(LanceTable, "_reset_dataset", return_value=None):
|
with patch.object(
|
||||||
with patch.object(
|
LanceTable, "_dataset_mut", new_callable=PropertyMock
|
||||||
LanceTable, "_dataset", new_callable=PropertyMock
|
) as mock_dataset:
|
||||||
) as mock_dataset:
|
# Setup mock responses
|
||||||
# Setup mock responses
|
mock_dataset.return_value.create_index.return_value = None
|
||||||
mock_dataset.return_value.create_index.return_value = None
|
|
||||||
|
|
||||||
# Create a LanceTable object
|
# Create a LanceTable object
|
||||||
connection = LanceDBConnection(uri="mock.uri")
|
connection = LanceDBConnection(uri="mock.uri")
|
||||||
table = LanceTable(connection, "test_table")
|
table = LanceTable(connection, "test_table")
|
||||||
|
|
||||||
# Call the create_index method
|
# Call the create_index method
|
||||||
table.create_index(
|
table.create_index(
|
||||||
metric="L2",
|
metric="L2",
|
||||||
num_partitions=256,
|
num_partitions=256,
|
||||||
num_sub_vectors=96,
|
num_sub_vectors=96,
|
||||||
vector_column_name="vector",
|
vector_column_name="vector",
|
||||||
replace=True,
|
replace=True,
|
||||||
index_cache_size=256,
|
index_cache_size=256,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the _dataset.create_index method was called
|
# Check that the _dataset.create_index method was called
|
||||||
# with the right parameters
|
# with the right parameters
|
||||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||||
column="vector",
|
column="vector",
|
||||||
index_type="IVF_PQ",
|
index_type="IVF_PQ",
|
||||||
metric="L2",
|
metric="L2",
|
||||||
num_partitions=256,
|
num_partitions=256,
|
||||||
num_sub_vectors=96,
|
num_sub_vectors=96,
|
||||||
replace=True,
|
replace=True,
|
||||||
accelerator=None,
|
accelerator=None,
|
||||||
index_cache_size=256,
|
index_cache_size=256,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_add_with_nans(db):
|
def test_add_with_nans(db):
|
||||||
@@ -493,6 +496,69 @@ def test_update_types(db):
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_insert(db):
|
||||||
|
table = LanceTable.create(
|
||||||
|
db,
|
||||||
|
"my_table",
|
||||||
|
data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}),
|
||||||
|
)
|
||||||
|
assert len(table) == 3
|
||||||
|
version = table.version
|
||||||
|
|
||||||
|
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||||
|
|
||||||
|
# upsert
|
||||||
|
table.merge_insert(
|
||||||
|
"a"
|
||||||
|
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
|
||||||
|
|
||||||
|
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
|
||||||
|
assert table.to_arrow().sort_by("a") == expected
|
||||||
|
|
||||||
|
table.restore(version)
|
||||||
|
|
||||||
|
# conditional update
|
||||||
|
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
|
||||||
|
new_data
|
||||||
|
)
|
||||||
|
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
|
||||||
|
assert table.to_arrow().sort_by("a") == expected
|
||||||
|
|
||||||
|
table.restore(version)
|
||||||
|
|
||||||
|
# insert-if-not-exists
|
||||||
|
table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
|
||||||
|
|
||||||
|
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
|
||||||
|
assert table.to_arrow().sort_by("a") == expected
|
||||||
|
|
||||||
|
table.restore(version)
|
||||||
|
|
||||||
|
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||||
|
|
||||||
|
# replace-range
|
||||||
|
table.merge_insert(
|
||||||
|
"a"
|
||||||
|
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete(
|
||||||
|
"a > 2"
|
||||||
|
).execute(new_data)
|
||||||
|
|
||||||
|
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||||
|
assert table.to_arrow().sort_by("a") == expected
|
||||||
|
|
||||||
|
table.restore(version)
|
||||||
|
|
||||||
|
# replace-range no condition
|
||||||
|
table.merge_insert(
|
||||||
|
"a"
|
||||||
|
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete().execute(
|
||||||
|
new_data
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||||
|
assert table.to_arrow().sort_by("a") == expected
|
||||||
|
|
||||||
|
|
||||||
def test_create_with_embedding_function(db):
|
def test_create_with_embedding_function(db):
|
||||||
class MyTable(LanceModel):
|
class MyTable(LanceModel):
|
||||||
text: str
|
text: str
|
||||||
@@ -682,3 +748,102 @@ def test_count_rows(db):
|
|||||||
assert len(table) == 2
|
assert len(table) == 2
|
||||||
assert table.count_rows() == 2
|
assert table.count_rows() == 2
|
||||||
assert table.count_rows(filter="text='bar'") == 1
|
assert table.count_rows(filter="text='bar'") == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_hybrid_search(db):
|
||||||
|
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
|
||||||
|
# Probably not being parsed right by the fts
|
||||||
|
db = MockDB("~/lancedb_")
|
||||||
|
# Create a LanceDB table schema with a vector and a text column
|
||||||
|
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||||
|
|
||||||
|
class MyTable(LanceModel):
|
||||||
|
text: str = emb.SourceField()
|
||||||
|
vector: Vector(emb.ndims()) = emb.VectorField()
|
||||||
|
|
||||||
|
# Initialize the table using the schema
|
||||||
|
table = LanceTable.create(
|
||||||
|
db,
|
||||||
|
"my_table",
|
||||||
|
schema=MyTable,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a list of 10 unique english phrases
|
||||||
|
phrases = [
|
||||||
|
"great kid don't get cocky",
|
||||||
|
"now that's a name I haven't heard in a long time",
|
||||||
|
"if you strike me down I shall become more powerful than you imagine",
|
||||||
|
"I find your lack of faith disturbing",
|
||||||
|
"I've got a bad feeling about this",
|
||||||
|
"never tell me the odds",
|
||||||
|
"I am your father",
|
||||||
|
"somebody has to save our skins",
|
||||||
|
"New strategy R2 let the wookiee win",
|
||||||
|
"Arrrrggghhhhhhh",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add the phrases and vectors to the table
|
||||||
|
table.add([{"text": p} for p in phrases])
|
||||||
|
|
||||||
|
# Create a fts index
|
||||||
|
table.create_fts_index("text")
|
||||||
|
|
||||||
|
result1 = (
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(normalize="score")
|
||||||
|
.to_pydantic(MyTable)
|
||||||
|
)
|
||||||
|
result2 = ( # noqa
|
||||||
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
|
.rerank(normalize="rank")
|
||||||
|
.to_pydantic(MyTable)
|
||||||
|
)
|
||||||
|
result3 = table.search(
|
||||||
|
"Our father who art in heaven", query_type="hybrid"
|
||||||
|
).to_pydantic(MyTable)
|
||||||
|
assert result1 == result3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
|
||||||
|
)
|
||||||
|
def test_consistency(tmp_path, consistency_interval):
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
|
||||||
|
|
||||||
|
db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
|
||||||
|
table2 = db2.open_table("my_table")
|
||||||
|
assert table2.version == table.version
|
||||||
|
|
||||||
|
table.add([{"id": 1}])
|
||||||
|
|
||||||
|
if consistency_interval is None:
|
||||||
|
assert table2.version == table.version - 1
|
||||||
|
table2.checkout_latest()
|
||||||
|
assert table2.version == table.version
|
||||||
|
elif consistency_interval == timedelta(seconds=0):
|
||||||
|
assert table2.version == table.version
|
||||||
|
else:
|
||||||
|
# (consistency_interval == timedelta(seconds=0.1)
|
||||||
|
assert table2.version == table.version - 1
|
||||||
|
sleep(0.1)
|
||||||
|
assert table2.version == table.version
|
||||||
|
|
||||||
|
|
||||||
|
def test_restore_consistency(tmp_path):
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
|
||||||
|
|
||||||
|
db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||||
|
table2 = db2.open_table("my_table")
|
||||||
|
assert table2.version == table.version
|
||||||
|
|
||||||
|
# If we call checkout, it should lose consistency
|
||||||
|
table_fixed = copy(table2)
|
||||||
|
table_fixed.checkout(table.version)
|
||||||
|
# But if we call checkout_latest, it should be consistent again
|
||||||
|
table_ref_latest = copy(table_fixed)
|
||||||
|
table_ref_latest.checkout_latest()
|
||||||
|
table.add([{"id": 2}])
|
||||||
|
assert table_fixed.version == table.version - 1
|
||||||
|
assert table_ref_latest.version == table.version
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb-node"
|
name = "vectordb-node"
|
||||||
version = "0.4.4"
|
version = "0.4.8"
|
||||||
description = "Serverless, low-latency vector database for AI applications"
|
description = "Serverless, low-latency vector database for AI applications"
|
||||||
license = "Apache-2.0"
|
license.workspace = true
|
||||||
edition = "2018"
|
edition.workspace = true
|
||||||
|
repository.workspace = true
|
||||||
|
keywords.workspace = true
|
||||||
|
categories.workspace = true
|
||||||
exclude = ["index.node"]
|
exclude = ["index.node"]
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Copyright 2023 Lance Developers.
|
// Copyright 2024 Lance Developers.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@@ -19,33 +19,21 @@ use arrow_array::RecordBatch;
|
|||||||
use arrow_ipc::reader::FileReader;
|
use arrow_ipc::reader::FileReader;
|
||||||
use arrow_ipc::writer::FileWriter;
|
use arrow_ipc::writer::FileWriter;
|
||||||
use arrow_schema::SchemaRef;
|
use arrow_schema::SchemaRef;
|
||||||
use vectordb::table::VECTOR_COLUMN_NAME;
|
|
||||||
|
|
||||||
use crate::error::{MissingColumnSnafu, Result};
|
use crate::error::Result;
|
||||||
use snafu::prelude::*;
|
|
||||||
|
|
||||||
fn validate_vector_column(record_batch: &RecordBatch) -> Result<()> {
|
pub fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
|
||||||
record_batch
|
|
||||||
.column_by_name(VECTOR_COLUMN_NAME)
|
|
||||||
.map(|_| ())
|
|
||||||
.context(MissingColumnSnafu {
|
|
||||||
name: VECTOR_COLUMN_NAME,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
|
|
||||||
let mut batches: Vec<RecordBatch> = Vec::new();
|
let mut batches: Vec<RecordBatch> = Vec::new();
|
||||||
let file_reader = FileReader::try_new(Cursor::new(slice), None)?;
|
let file_reader = FileReader::try_new(Cursor::new(slice), None)?;
|
||||||
let schema = file_reader.schema();
|
let schema = file_reader.schema();
|
||||||
for b in file_reader {
|
for b in file_reader {
|
||||||
let record_batch = b?;
|
let record_batch = b?;
|
||||||
validate_vector_column(&record_batch)?;
|
|
||||||
batches.push(record_batch);
|
batches.push(record_batch);
|
||||||
}
|
}
|
||||||
Ok((batches, schema))
|
Ok((batches, schema))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
|
pub fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
|
||||||
if batches.is_empty() {
|
if batches.is_empty() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ use neon::types::buffer::TypedArray;
|
|||||||
|
|
||||||
use crate::error::ResultExt;
|
use crate::error::ResultExt;
|
||||||
|
|
||||||
pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
|
pub fn vec_str_to_array<'a, C: Context<'a>>(
|
||||||
vec: &Vec<String>,
|
vec: &Vec<String>,
|
||||||
cx: &mut C,
|
cx: &mut C,
|
||||||
) -> JsResult<'a, JsArray> {
|
) -> JsResult<'a, JsArray> {
|
||||||
@@ -29,7 +29,7 @@ pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
|
|||||||
Ok(a)
|
Ok(a)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
|
pub fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
|
||||||
let mut query_vec: Vec<f32> = Vec::new();
|
let mut query_vec: Vec<f32> = Vec::new();
|
||||||
for i in 0..array.len(cx) {
|
for i in 0..array.len(cx) {
|
||||||
let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
|
let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
|
||||||
@@ -39,7 +39,7 @@ pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new JsBuffer from a rust buffer with a special logic for electron
|
// Creates a new JsBuffer from a rust buffer with a special logic for electron
|
||||||
pub(crate) fn new_js_buffer<'a>(
|
pub fn new_js_buffer<'a>(
|
||||||
buffer: Vec<u8>,
|
buffer: Vec<u8>,
|
||||||
cx: &mut TaskContext<'a>,
|
cx: &mut TaskContext<'a>,
|
||||||
is_electron: bool,
|
is_electron: bool,
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ use neon::prelude::NeonResult;
|
|||||||
use snafu::Snafu;
|
use snafu::Snafu;
|
||||||
|
|
||||||
#[derive(Debug, Snafu)]
|
#[derive(Debug, Snafu)]
|
||||||
#[snafu(visibility(pub(crate)))]
|
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
#[snafu(display("column '{name}' is missing"))]
|
#[snafu(display("column '{name}' is missing"))]
|
||||||
MissingColumn { name: String },
|
MissingColumn { name: String },
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ use neon::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use crate::{error::ResultExt, runtime, table::JsTable};
|
use crate::{error::ResultExt, runtime, table::JsTable};
|
||||||
|
use vectordb::Table;
|
||||||
|
|
||||||
pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||||
let column = cx.argument::<JsString>(0)?.value(&mut cx);
|
let column = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
let replace = cx.argument::<JsBoolean>(1)?.value(&mut cx);
|
let replace = cx.argument::<JsBoolean>(1)?.value(&mut cx);
|
||||||
@@ -35,7 +36,9 @@ pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsP
|
|||||||
let idx_result = table
|
let idx_result = table
|
||||||
.as_native()
|
.as_native()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.create_scalar_index(&column, replace)
|
.create_index(&[&column])
|
||||||
|
.replace(replace)
|
||||||
|
.build()
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
|
|||||||
use crate::runtime;
|
use crate::runtime;
|
||||||
use crate::table::JsTable;
|
use crate::table::JsTable;
|
||||||
|
|
||||||
pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||||
let index_params = cx.argument::<JsObject>(0)?;
|
let index_params = cx.argument::<JsObject>(0)?;
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ use tokio::runtime::Runtime;
|
|||||||
|
|
||||||
use vectordb::connection::Database;
|
use vectordb::connection::Database;
|
||||||
use vectordb::table::ReadParams;
|
use vectordb::table::ReadParams;
|
||||||
use vectordb::Connection;
|
use vectordb::{ConnectOptions, Connection};
|
||||||
|
|
||||||
use crate::error::ResultExt;
|
use crate::error::ResultExt;
|
||||||
use crate::query::JsQuery;
|
use crate::query::JsQuery;
|
||||||
@@ -82,13 +82,26 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
|
|||||||
|
|
||||||
fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let path = cx.argument::<JsString>(0)?.value(&mut cx);
|
let path = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
|
let aws_creds = get_aws_creds(&mut cx, 1)?;
|
||||||
|
let region = get_aws_region(&mut cx, 4)?;
|
||||||
|
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
let channel = cx.channel();
|
let channel = cx.channel();
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
|
|
||||||
|
let mut conn_options = ConnectOptions::new(&path);
|
||||||
|
if let Some(region) = region {
|
||||||
|
conn_options = conn_options.region(®ion);
|
||||||
|
}
|
||||||
|
if let Some(aws_creds) = aws_creds {
|
||||||
|
conn_options = conn_options.aws_creds(AwsCredential {
|
||||||
|
key_id: aws_creds.key_id,
|
||||||
|
secret_key: aws_creds.secret_key,
|
||||||
|
token: aws_creds.token,
|
||||||
|
});
|
||||||
|
}
|
||||||
rt.spawn(async move {
|
rt.spawn(async move {
|
||||||
let database = Database::connect(&path).await;
|
let database = Database::connect_with_options(&conn_options).await;
|
||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
let db = JsDatabase {
|
let db = JsDatabase {
|
||||||
@@ -127,7 +140,7 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
|||||||
fn get_aws_creds(
|
fn get_aws_creds(
|
||||||
cx: &mut FunctionContext,
|
cx: &mut FunctionContext,
|
||||||
arg_starting_location: i32,
|
arg_starting_location: i32,
|
||||||
) -> NeonResult<Option<AwsCredentialProvider>> {
|
) -> NeonResult<Option<AwsCredential>> {
|
||||||
let secret_key_id = cx
|
let secret_key_id = cx
|
||||||
.argument_opt(arg_starting_location)
|
.argument_opt(arg_starting_location)
|
||||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||||
@@ -147,18 +160,26 @@ fn get_aws_creds(
|
|||||||
.map(|v| v.value(cx));
|
.map(|v| v.value(cx));
|
||||||
|
|
||||||
match (secret_key_id, secret_key, temp_token) {
|
match (secret_key_id, secret_key, temp_token) {
|
||||||
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
|
(Some(key_id), Some(key), optional_token) => Ok(Some(AwsCredential {
|
||||||
StaticCredentialProvider::new(AwsCredential {
|
key_id,
|
||||||
key_id,
|
secret_key: key,
|
||||||
secret_key: key,
|
token: optional_token,
|
||||||
token: optional_token,
|
})),
|
||||||
}),
|
|
||||||
))),
|
|
||||||
(None, None, None) => Ok(None),
|
(None, None, None) => Ok(None),
|
||||||
_ => cx.throw_error("Invalid credentials configuration"),
|
_ => cx.throw_error("Invalid credentials configuration"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_aws_credential_provider(
|
||||||
|
cx: &mut FunctionContext,
|
||||||
|
arg_starting_location: i32,
|
||||||
|
) -> NeonResult<Option<AwsCredentialProvider>> {
|
||||||
|
Ok(get_aws_creds(cx, arg_starting_location)?.map(|aws_cred| {
|
||||||
|
Arc::new(StaticCredentialProvider::new(aws_cred))
|
||||||
|
as Arc<dyn CredentialProvider<Credential = AwsCredential>>
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
/// Get AWS region arguments from the context
|
/// Get AWS region arguments from the context
|
||||||
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
|
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
|
||||||
let region = cx
|
let region = cx
|
||||||
@@ -179,7 +200,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
|||||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
|
|
||||||
let aws_creds = get_aws_creds(&mut cx, 1)?;
|
let aws_creds = get_aws_credential_provider(&mut cx, 1)?;
|
||||||
|
|
||||||
let aws_region = get_aws_region(&mut cx, 4)?;
|
let aws_region = get_aws_region(&mut cx, 4)?;
|
||||||
|
|
||||||
@@ -239,6 +260,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
|
|||||||
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
|
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
|
||||||
cx.export_function("tableDelete", JsTable::js_delete)?;
|
cx.export_function("tableDelete", JsTable::js_delete)?;
|
||||||
cx.export_function("tableUpdate", JsTable::js_update)?;
|
cx.export_function("tableUpdate", JsTable::js_update)?;
|
||||||
|
cx.export_function("tableMergeInsert", JsTable::js_merge_insert)?;
|
||||||
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
|
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
|
||||||
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
|
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
|
||||||
cx.export_function("tableListIndices", JsTable::js_list_indices)?;
|
cx.export_function("tableListIndices", JsTable::js_list_indices)?;
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
|
|||||||
use crate::table::JsTable;
|
use crate::table::JsTable;
|
||||||
use crate::{convert, runtime};
|
use crate::{convert, runtime};
|
||||||
|
|
||||||
pub(crate) struct JsQuery {}
|
pub struct JsQuery {}
|
||||||
|
|
||||||
impl JsQuery {
|
impl JsQuery {
|
||||||
pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
|
|||||||
@@ -12,10 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::ops::Deref;
|
||||||
|
|
||||||
use arrow_array::{RecordBatch, RecordBatchIterator};
|
use arrow_array::{RecordBatch, RecordBatchIterator};
|
||||||
use lance::dataset::optimize::CompactionOptions;
|
use lance::dataset::optimize::CompactionOptions;
|
||||||
use lance::dataset::{WriteMode, WriteParams};
|
use lance::dataset::{WriteMode, WriteParams};
|
||||||
use lance::io::ObjectStoreParams;
|
use lance::io::ObjectStoreParams;
|
||||||
|
use vectordb::table::OptimizeAction;
|
||||||
|
|
||||||
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
|
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
|
||||||
use neon::prelude::*;
|
use neon::prelude::*;
|
||||||
@@ -23,9 +26,9 @@ use neon::types::buffer::TypedArray;
|
|||||||
use vectordb::TableRef;
|
use vectordb::TableRef;
|
||||||
|
|
||||||
use crate::error::ResultExt;
|
use crate::error::ResultExt;
|
||||||
use crate::{convert, get_aws_creds, get_aws_region, runtime, JsDatabase};
|
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
|
||||||
|
|
||||||
pub(crate) struct JsTable {
|
pub struct JsTable {
|
||||||
pub table: TableRef,
|
pub table: TableRef,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,7 +36,7 @@ impl Finalize for JsTable {}
|
|||||||
|
|
||||||
impl From<TableRef> for JsTable {
|
impl From<TableRef> for JsTable {
|
||||||
fn from(table: TableRef) -> Self {
|
fn from(table: TableRef) -> Self {
|
||||||
JsTable { table }
|
Self { table }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -63,7 +66,7 @@ impl JsTable {
|
|||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
let database = db.database.clone();
|
let database = db.database.clone();
|
||||||
|
|
||||||
let aws_creds = get_aws_creds(&mut cx, 3)?;
|
let aws_creds = get_aws_credential_provider(&mut cx, 3)?;
|
||||||
let aws_region = get_aws_region(&mut cx, 6)?;
|
let aws_region = get_aws_region(&mut cx, 6)?;
|
||||||
|
|
||||||
let params = WriteParams {
|
let params = WriteParams {
|
||||||
@@ -82,14 +85,14 @@ impl JsTable {
|
|||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
let table = table_rst.or_throw(&mut cx)?;
|
let table = table_rst.or_throw(&mut cx)?;
|
||||||
Ok(cx.boxed(JsTable::from(table)))
|
Ok(cx.boxed(Self::from(table)))
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
Ok(promise)
|
Ok(promise)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
let buffer = cx.argument::<JsBuffer>(0)?;
|
let buffer = cx.argument::<JsBuffer>(0)?;
|
||||||
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
|
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
|
||||||
let (batches, schema) =
|
let (batches, schema) =
|
||||||
@@ -105,7 +108,7 @@ impl JsTable {
|
|||||||
"overwrite" => WriteMode::Overwrite,
|
"overwrite" => WriteMode::Overwrite,
|
||||||
s => return cx.throw_error(format!("invalid write mode {}", s)),
|
s => return cx.throw_error(format!("invalid write mode {}", s)),
|
||||||
};
|
};
|
||||||
let aws_creds = get_aws_creds(&mut cx, 2)?;
|
let aws_creds = get_aws_credential_provider(&mut cx, 2)?;
|
||||||
let aws_region = get_aws_region(&mut cx, 5)?;
|
let aws_region = get_aws_region(&mut cx, 5)?;
|
||||||
|
|
||||||
let params = WriteParams {
|
let params = WriteParams {
|
||||||
@@ -122,21 +125,34 @@ impl JsTable {
|
|||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
add_result.or_throw(&mut cx)?;
|
add_result.or_throw(&mut cx)?;
|
||||||
Ok(cx.boxed(JsTable::from(table)))
|
Ok(cx.boxed(Self::from(table)))
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
Ok(promise)
|
Ok(promise)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
|
let filter = cx
|
||||||
|
.argument_opt(0)
|
||||||
|
.and_then(|filt| {
|
||||||
|
if filt.is_a::<JsUndefined, _>(&mut cx) || filt.is_a::<JsNull, _>(&mut cx) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
filt.downcast_or_throw::<JsString, _>(&mut cx)
|
||||||
|
.map(|js_filt| js_filt.deref().value(&mut cx)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
let channel = cx.channel();
|
let channel = cx.channel();
|
||||||
let table = js_table.table.clone();
|
let table = js_table.table.clone();
|
||||||
|
|
||||||
rt.spawn(async move {
|
rt.spawn(async move {
|
||||||
let num_rows_result = table.count_rows().await;
|
let num_rows_result = table.count_rows(filter).await;
|
||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
let num_rows = num_rows_result.or_throw(&mut cx)?;
|
let num_rows = num_rows_result.or_throw(&mut cx)?;
|
||||||
@@ -147,7 +163,7 @@ impl JsTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
@@ -159,14 +175,67 @@ impl JsTable {
|
|||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
delete_result.or_throw(&mut cx)?;
|
delete_result.or_throw(&mut cx)?;
|
||||||
Ok(cx.boxed(JsTable::from(table)))
|
Ok(cx.boxed(Self::from(table)))
|
||||||
|
})
|
||||||
|
});
|
||||||
|
Ok(promise)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
|
let rt = runtime(&mut cx)?;
|
||||||
|
let (deferred, promise) = cx.promise();
|
||||||
|
let channel = cx.channel();
|
||||||
|
let table = js_table.table.clone();
|
||||||
|
|
||||||
|
let key = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
|
let mut builder = table.merge_insert(&[&key]);
|
||||||
|
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
|
||||||
|
let filter = cx.argument_opt(2).unwrap();
|
||||||
|
if filter.is_a::<JsNull, _>(&mut cx) {
|
||||||
|
builder.when_matched_update_all(None);
|
||||||
|
} else {
|
||||||
|
let filter = filter
|
||||||
|
.downcast_or_throw::<JsString, _>(&mut cx)?
|
||||||
|
.deref()
|
||||||
|
.value(&mut cx);
|
||||||
|
builder.when_matched_update_all(Some(filter));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
|
||||||
|
builder.when_not_matched_insert_all();
|
||||||
|
}
|
||||||
|
if cx.argument::<JsBoolean>(4)?.value(&mut cx) {
|
||||||
|
let filter = cx.argument_opt(5).unwrap();
|
||||||
|
if filter.is_a::<JsNull, _>(&mut cx) {
|
||||||
|
builder.when_not_matched_by_source_delete(None);
|
||||||
|
} else {
|
||||||
|
let filter = filter
|
||||||
|
.downcast_or_throw::<JsString, _>(&mut cx)?
|
||||||
|
.deref()
|
||||||
|
.value(&mut cx);
|
||||||
|
builder.when_not_matched_by_source_delete(Some(filter));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let buffer = cx.argument::<JsBuffer>(6)?;
|
||||||
|
let (batches, schema) =
|
||||||
|
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
|
||||||
|
|
||||||
|
rt.spawn(async move {
|
||||||
|
let new_data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
|
||||||
|
let merge_insert_result = builder.execute(Box::new(new_data)).await;
|
||||||
|
|
||||||
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
|
merge_insert_result.or_throw(&mut cx)?;
|
||||||
|
Ok(cx.boxed(Self::from(table)))
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
Ok(promise)
|
Ok(promise)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
let table = js_table.table.clone();
|
let table = js_table.table.clone();
|
||||||
|
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
@@ -225,7 +294,7 @@ impl JsTable {
|
|||||||
.await;
|
.await;
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
update_result.or_throw(&mut cx)?;
|
update_result.or_throw(&mut cx)?;
|
||||||
Ok(cx.boxed(JsTable::from(table)))
|
Ok(cx.boxed(Self::from(table)))
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -233,7 +302,7 @@ impl JsTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
let table = js_table.table.clone();
|
let table = js_table.table.clone();
|
||||||
@@ -245,30 +314,33 @@ impl JsTable {
|
|||||||
.map(|val| val.value(&mut cx) as i64)
|
.map(|val| val.value(&mut cx) as i64)
|
||||||
.unwrap_or_else(|| 2 * 7 * 24 * 60); // 2 weeks
|
.unwrap_or_else(|| 2 * 7 * 24 * 60); // 2 weeks
|
||||||
let older_than = chrono::Duration::minutes(older_than);
|
let older_than = chrono::Duration::minutes(older_than);
|
||||||
let delete_unverified: bool = cx
|
let delete_unverified: Option<bool> = Some(
|
||||||
.argument_opt(1)
|
cx.argument_opt(1)
|
||||||
.and_then(|val| val.downcast::<JsBoolean, _>(&mut cx).ok())
|
.and_then(|val| val.downcast::<JsBoolean, _>(&mut cx).ok())
|
||||||
.map(|val| val.value(&mut cx))
|
.map(|val| val.value(&mut cx))
|
||||||
.unwrap_or_default();
|
.unwrap_or_default(),
|
||||||
|
);
|
||||||
|
|
||||||
rt.spawn(async move {
|
rt.spawn(async move {
|
||||||
let stats = table
|
let stats = table
|
||||||
.as_native()
|
.optimize(OptimizeAction::Prune {
|
||||||
.unwrap()
|
older_than,
|
||||||
.cleanup_old_versions(older_than, Some(delete_unverified))
|
delete_unverified,
|
||||||
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
let stats = stats.or_throw(&mut cx)?;
|
let stats = stats.or_throw(&mut cx)?;
|
||||||
|
|
||||||
|
let prune_stats = stats.prune.as_ref().expect("Prune stats missing");
|
||||||
let output_metrics = JsObject::new(&mut cx);
|
let output_metrics = JsObject::new(&mut cx);
|
||||||
let bytes_removed = cx.number(stats.bytes_removed as f64);
|
let bytes_removed = cx.number(prune_stats.bytes_removed as f64);
|
||||||
output_metrics.set(&mut cx, "bytesRemoved", bytes_removed)?;
|
output_metrics.set(&mut cx, "bytesRemoved", bytes_removed)?;
|
||||||
|
|
||||||
let old_versions = cx.number(stats.old_versions as f64);
|
let old_versions = cx.number(prune_stats.old_versions as f64);
|
||||||
output_metrics.set(&mut cx, "oldVersions", old_versions)?;
|
output_metrics.set(&mut cx, "oldVersions", old_versions)?;
|
||||||
|
|
||||||
let output_table = cx.boxed(JsTable::from(table));
|
let output_table = cx.boxed(Self::from(table));
|
||||||
|
|
||||||
let output = JsObject::new(&mut cx);
|
let output = JsObject::new(&mut cx);
|
||||||
output.set(&mut cx, "metrics", output_metrics)?;
|
output.set(&mut cx, "metrics", output_metrics)?;
|
||||||
@@ -281,7 +353,7 @@ impl JsTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
let table = js_table.table.clone();
|
let table = js_table.table.clone();
|
||||||
@@ -317,13 +389,15 @@ impl JsTable {
|
|||||||
|
|
||||||
rt.spawn(async move {
|
rt.spawn(async move {
|
||||||
let stats = table
|
let stats = table
|
||||||
.as_native()
|
.optimize(OptimizeAction::Compact {
|
||||||
.unwrap()
|
options,
|
||||||
.compact_files(options, None)
|
remap_options: None,
|
||||||
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
let stats = stats.or_throw(&mut cx)?;
|
let stats = stats.or_throw(&mut cx)?;
|
||||||
|
let stats = stats.compaction.as_ref().expect("Compact stats missing");
|
||||||
|
|
||||||
let output_metrics = JsObject::new(&mut cx);
|
let output_metrics = JsObject::new(&mut cx);
|
||||||
let fragments_removed = cx.number(stats.fragments_removed as f64);
|
let fragments_removed = cx.number(stats.fragments_removed as f64);
|
||||||
@@ -338,7 +412,7 @@ impl JsTable {
|
|||||||
let files_added = cx.number(stats.files_added as f64);
|
let files_added = cx.number(stats.files_added as f64);
|
||||||
output_metrics.set(&mut cx, "filesAdded", files_added)?;
|
output_metrics.set(&mut cx, "filesAdded", files_added)?;
|
||||||
|
|
||||||
let output_table = cx.boxed(JsTable::from(table));
|
let output_table = cx.boxed(Self::from(table));
|
||||||
|
|
||||||
let output = JsObject::new(&mut cx);
|
let output = JsObject::new(&mut cx);
|
||||||
output.set(&mut cx, "metrics", output_metrics)?;
|
output.set(&mut cx, "metrics", output_metrics)?;
|
||||||
@@ -351,7 +425,7 @@ impl JsTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
@@ -390,7 +464,7 @@ impl JsTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
|
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
@@ -438,7 +512,7 @@ impl JsTable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn js_schema(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
pub(crate) fn js_schema(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||||
let rt = runtime(&mut cx)?;
|
let rt = runtime(&mut cx)?;
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
let channel = cx.channel();
|
let channel = cx.channel();
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb"
|
name = "vectordb"
|
||||||
version = "0.4.4"
|
version = "0.4.8"
|
||||||
edition = "2021"
|
edition.workspace = true
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license = "Apache-2.0"
|
license.workspace = true
|
||||||
repository = "https://github.com/lancedb/lancedb"
|
repository.workspace = true
|
||||||
keywords = ["lancedb", "lance", "database", "search"]
|
keywords.workspace = true
|
||||||
categories = ["database-implementations"]
|
categories.workspace = true
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
|||||||
168
rust/vectordb/examples/simple.rs
Normal file
168
rust/vectordb/examples/simple.rs
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
// Copyright 2024 Lance Developers.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use arrow_array::types::Float32Type;
|
||||||
|
use arrow_array::{FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator};
|
||||||
|
use arrow_schema::{DataType, Field, Schema};
|
||||||
|
use futures::TryStreamExt;
|
||||||
|
|
||||||
|
use vectordb::Connection;
|
||||||
|
use vectordb::{connect, Result, Table, TableRef};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
if std::path::Path::new("data").exists() {
|
||||||
|
std::fs::remove_dir_all("data").unwrap();
|
||||||
|
}
|
||||||
|
// --8<-- [start:connect]
|
||||||
|
let uri = "data/sample-lancedb";
|
||||||
|
let db = connect(uri).await?;
|
||||||
|
// --8<-- [end:connect]
|
||||||
|
|
||||||
|
// --8<-- [start:list_names]
|
||||||
|
println!("{:?}", db.table_names().await?);
|
||||||
|
// --8<-- [end:list_names]
|
||||||
|
let tbl = create_table(db.clone()).await?;
|
||||||
|
create_index(tbl.as_ref()).await?;
|
||||||
|
let batches = search(tbl.as_ref()).await?;
|
||||||
|
println!("{:?}", batches);
|
||||||
|
|
||||||
|
create_empty_table(db.clone()).await.unwrap();
|
||||||
|
|
||||||
|
// --8<-- [start:delete]
|
||||||
|
tbl.delete("id > 24").await.unwrap();
|
||||||
|
// --8<-- [end:delete]
|
||||||
|
|
||||||
|
// --8<-- [start:drop_table]
|
||||||
|
db.drop_table("my_table").await.unwrap();
|
||||||
|
// --8<-- [end:drop_table]
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
async fn open_with_existing_tbl() -> Result<()> {
|
||||||
|
let uri = "data/sample-lancedb";
|
||||||
|
let db = connect(uri).await?;
|
||||||
|
// --8<-- [start:open_with_existing_file]
|
||||||
|
let _ = db
|
||||||
|
.open_table_with_params("my_table", Default::default())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// --8<-- [end:open_with_existing_file]
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_table(db: Arc<dyn Connection>) -> Result<TableRef> {
|
||||||
|
// --8<-- [start:create_table]
|
||||||
|
const TOTAL: usize = 1000;
|
||||||
|
const DIM: usize = 128;
|
||||||
|
|
||||||
|
let schema = Arc::new(Schema::new(vec![
|
||||||
|
Field::new("id", DataType::Int32, false),
|
||||||
|
Field::new(
|
||||||
|
"vector",
|
||||||
|
DataType::FixedSizeList(
|
||||||
|
Arc::new(Field::new("item", DataType::Float32, true)),
|
||||||
|
DIM as i32,
|
||||||
|
),
|
||||||
|
true,
|
||||||
|
),
|
||||||
|
]));
|
||||||
|
|
||||||
|
// Create a RecordBatch stream.
|
||||||
|
let batches = RecordBatchIterator::new(
|
||||||
|
vec![RecordBatch::try_new(
|
||||||
|
schema.clone(),
|
||||||
|
vec![
|
||||||
|
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
|
||||||
|
Arc::new(
|
||||||
|
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||||
|
(0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
|
||||||
|
DIM as i32,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.unwrap()]
|
||||||
|
.into_iter()
|
||||||
|
.map(Ok),
|
||||||
|
schema.clone(),
|
||||||
|
);
|
||||||
|
let tbl = db
|
||||||
|
.create_table("my_table", Box::new(batches), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// --8<-- [end:create_table]
|
||||||
|
|
||||||
|
let new_batches = RecordBatchIterator::new(
|
||||||
|
vec![RecordBatch::try_new(
|
||||||
|
schema.clone(),
|
||||||
|
vec![
|
||||||
|
Arc::new(Int32Array::from_iter_values(0..TOTAL as i32)),
|
||||||
|
Arc::new(
|
||||||
|
FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
|
||||||
|
(0..TOTAL).map(|_| Some(vec![Some(1.0); DIM])),
|
||||||
|
DIM as i32,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.unwrap()]
|
||||||
|
.into_iter()
|
||||||
|
.map(Ok),
|
||||||
|
schema.clone(),
|
||||||
|
);
|
||||||
|
// --8<-- [start:add]
|
||||||
|
tbl.add(Box::new(new_batches), None).await.unwrap();
|
||||||
|
// --8<-- [end:add]
|
||||||
|
|
||||||
|
Ok(tbl)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_empty_table(db: Arc<dyn Connection>) -> Result<TableRef> {
|
||||||
|
// --8<-- [start:create_empty_table]
|
||||||
|
let schema = Arc::new(Schema::new(vec![
|
||||||
|
Field::new("id", DataType::Int32, false),
|
||||||
|
Field::new("item", DataType::Utf8, true),
|
||||||
|
]));
|
||||||
|
let batches = RecordBatchIterator::new(vec![], schema.clone());
|
||||||
|
db.create_table("empty_table", Box::new(batches), None)
|
||||||
|
.await
|
||||||
|
// --8<-- [end:create_empty_table]
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_index(table: &dyn Table) -> Result<()> {
|
||||||
|
// --8<-- [start:create_index]
|
||||||
|
table
|
||||||
|
.create_index(&["vector"])
|
||||||
|
.ivf_pq()
|
||||||
|
.num_partitions(8)
|
||||||
|
.build()
|
||||||
|
.await
|
||||||
|
// --8<-- [end:create_index]
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn search(table: &dyn Table) -> Result<Vec<RecordBatch>> {
|
||||||
|
// --8<-- [start:search]
|
||||||
|
Ok(table
|
||||||
|
.search(&[1.0; 128])
|
||||||
|
.limit(2)
|
||||||
|
.execute_stream()
|
||||||
|
.await?
|
||||||
|
.try_collect::<Vec<_>>()
|
||||||
|
.await?)
|
||||||
|
// --8<-- [end:search]
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user