mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
63 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
245786fed7 | ||
|
|
edd9a043f8 | ||
|
|
38c09fc294 | ||
|
|
ebaa2dede5 | ||
|
|
ba7618a026 | ||
|
|
a6bcbd007b | ||
|
|
5af74b5aca | ||
|
|
8a52619bc0 | ||
|
|
314d4c93e5 | ||
|
|
c5471ee694 | ||
|
|
4605359d3b | ||
|
|
f1596122e6 | ||
|
|
3aa0c40168 | ||
|
|
677b7c1fcc | ||
|
|
8303a7197b | ||
|
|
5fa9bfc4a8 | ||
|
|
bf2e9d0088 | ||
|
|
f04590ddad | ||
|
|
62c5117def | ||
|
|
22c196b3e3 | ||
|
|
1f4ac71fa3 | ||
|
|
b5aad2d856 | ||
|
|
ca6f55b160 | ||
|
|
6f8cf1e068 | ||
|
|
e0277383a5 | ||
|
|
d6b408e26f | ||
|
|
2447372c1f | ||
|
|
f0298d8372 | ||
|
|
54693e6bec | ||
|
|
73b2977bff | ||
|
|
aec85f7875 | ||
|
|
51f92ecb3d | ||
|
|
5b60412d66 | ||
|
|
53d63966a9 | ||
|
|
5ba87575e7 | ||
|
|
cc5f2136a6 | ||
|
|
78e5fb5451 | ||
|
|
8104c5c18e | ||
|
|
4fbabdeec3 | ||
|
|
eb31d95fef | ||
|
|
3169c36525 | ||
|
|
1b990983b3 | ||
|
|
0c21f91c16 | ||
|
|
7e50c239eb | ||
|
|
24e8043150 | ||
|
|
990440385d | ||
|
|
a693a9d897 | ||
|
|
82936c77ef | ||
|
|
dddcddcaf9 | ||
|
|
a9727eb318 | ||
|
|
48d55bf952 | ||
|
|
d2e71c8b08 | ||
|
|
f53aace89c | ||
|
|
d982ee934a | ||
|
|
57605a2d86 | ||
|
|
738511c5f2 | ||
|
|
0b0f42537e | ||
|
|
e412194008 | ||
|
|
a9088224c5 | ||
|
|
688c57a0d8 | ||
|
|
12a98deded | ||
|
|
e4bb042918 | ||
|
|
04e1662681 |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.4.7
|
current_version = 0.4.11
|
||||||
commit = True
|
commit = True
|
||||||
message = Bump version: {current_version} → {new_version}
|
message = Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
@@ -9,4 +9,4 @@ tag_name = v{new_version}
|
|||||||
|
|
||||||
[bumpversion:file:rust/ffi/node/Cargo.toml]
|
[bumpversion:file:rust/ffi/node/Cargo.toml]
|
||||||
|
|
||||||
[bumpversion:file:rust/vectordb/Cargo.toml]
|
[bumpversion:file:rust/lancedb/Cargo.toml]
|
||||||
|
|||||||
40
.cargo/config.toml
Normal file
40
.cargo/config.toml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
[profile.release]
|
||||||
|
lto = "fat"
|
||||||
|
codegen-units = 1
|
||||||
|
|
||||||
|
[profile.release-with-debug]
|
||||||
|
inherits = "release"
|
||||||
|
debug = true
|
||||||
|
# Prioritize compile time over runtime performance
|
||||||
|
codegen-units = 16
|
||||||
|
lto = "thin"
|
||||||
|
|
||||||
|
[target.'cfg(all())']
|
||||||
|
rustflags = [
|
||||||
|
"-Wclippy::all",
|
||||||
|
"-Wclippy::style",
|
||||||
|
"-Wclippy::fallible_impl_from",
|
||||||
|
"-Wclippy::manual_let_else",
|
||||||
|
"-Wclippy::redundant_pub_crate",
|
||||||
|
"-Wclippy::string_add_assign",
|
||||||
|
"-Wclippy::string_add",
|
||||||
|
"-Wclippy::string_lit_as_bytes",
|
||||||
|
"-Wclippy::string_to_string",
|
||||||
|
"-Wclippy::use_self",
|
||||||
|
"-Dclippy::cargo",
|
||||||
|
"-Dclippy::dbg_macro",
|
||||||
|
# not too much we can do to avoid multiple crate versions
|
||||||
|
"-Aclippy::multiple-crate-versions",
|
||||||
|
"-Aclippy::wildcard_dependencies",
|
||||||
|
]
|
||||||
|
|
||||||
|
[target.x86_64-unknown-linux-gnu]
|
||||||
|
rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"]
|
||||||
|
|
||||||
|
[target.aarch64-apple-darwin]
|
||||||
|
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]
|
||||||
|
|
||||||
|
# Not all Windows systems have the C runtime installed, so this avoids library
|
||||||
|
# not found errors on systems that are missing it.
|
||||||
|
[target.x86_64-pc-windows-msvc]
|
||||||
|
rustflags = ["-Ctarget-feature=+crt-static"]
|
||||||
58
.github/workflows/build_linux_wheel/action.yml
vendored
Normal file
58
.github/workflows/build_linux_wheel/action.yml
vendored
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
# We create a composite action to be re-used both for testing and for releasing
|
||||||
|
name: build-linux-wheel
|
||||||
|
description: "Build a manylinux wheel for lance"
|
||||||
|
inputs:
|
||||||
|
python-minor-version:
|
||||||
|
description: "8, 9, 10, 11, 12"
|
||||||
|
required: true
|
||||||
|
args:
|
||||||
|
description: "--release"
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
|
arm-build:
|
||||||
|
description: "Build for arm64 instead of x86_64"
|
||||||
|
# Note: this does *not* mean the host is arm64, since we might be cross-compiling.
|
||||||
|
required: false
|
||||||
|
default: "false"
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: CONFIRM ARM BUILD
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "ARM BUILD: ${{ inputs.arm-build }}"
|
||||||
|
- name: Build x86_64 Manylinux wheel
|
||||||
|
if: ${{ inputs.arm-build == 'false' }}
|
||||||
|
uses: PyO3/maturin-action@v1
|
||||||
|
with:
|
||||||
|
command: build
|
||||||
|
working-directory: python
|
||||||
|
target: x86_64-unknown-linux-gnu
|
||||||
|
manylinux: "2_17"
|
||||||
|
args: ${{ inputs.args }}
|
||||||
|
before-script-linux: |
|
||||||
|
set -e
|
||||||
|
yum install -y openssl-devel \
|
||||||
|
&& curl -L https://github.com/protocolbuffers/protobuf/releases/download/v24.4/protoc-24.4-linux-$(uname -m).zip > /tmp/protoc.zip \
|
||||||
|
&& unzip /tmp/protoc.zip -d /usr/local \
|
||||||
|
&& rm /tmp/protoc.zip
|
||||||
|
- name: Build Arm Manylinux Wheel
|
||||||
|
if: ${{ inputs.arm-build == 'true' }}
|
||||||
|
uses: PyO3/maturin-action@v1
|
||||||
|
with:
|
||||||
|
command: build
|
||||||
|
working-directory: python
|
||||||
|
target: aarch64-unknown-linux-gnu
|
||||||
|
manylinux: "2_24"
|
||||||
|
args: ${{ inputs.args }}
|
||||||
|
before-script-linux: |
|
||||||
|
set -e
|
||||||
|
apt install -y unzip
|
||||||
|
if [ $(uname -m) = "x86_64" ]; then
|
||||||
|
PROTOC_ARCH="x86_64"
|
||||||
|
else
|
||||||
|
PROTOC_ARCH="aarch_64"
|
||||||
|
fi
|
||||||
|
curl -L https://github.com/protocolbuffers/protobuf/releases/download/v24.4/protoc-24.4-linux-$PROTOC_ARCH.zip > /tmp/protoc.zip \
|
||||||
|
&& unzip /tmp/protoc.zip -d /usr/local \
|
||||||
|
&& rm /tmp/protoc.zip
|
||||||
25
.github/workflows/build_mac_wheel/action.yml
vendored
Normal file
25
.github/workflows/build_mac_wheel/action.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# We create a composite action to be re-used both for testing and for releasing
|
||||||
|
name: build_wheel
|
||||||
|
description: "Build a lance wheel"
|
||||||
|
inputs:
|
||||||
|
python-minor-version:
|
||||||
|
description: "8, 9, 10, 11"
|
||||||
|
required: true
|
||||||
|
args:
|
||||||
|
description: "--release"
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Install macos dependency
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
brew install protobuf
|
||||||
|
- name: Build wheel
|
||||||
|
uses: PyO3/maturin-action@v1
|
||||||
|
with:
|
||||||
|
command: build
|
||||||
|
args: ${{ inputs.args }}
|
||||||
|
working-directory: python
|
||||||
|
interpreter: 3.${{ inputs.python-minor-version }}
|
||||||
33
.github/workflows/build_windows_wheel/action.yml
vendored
Normal file
33
.github/workflows/build_windows_wheel/action.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# We create a composite action to be re-used both for testing and for releasing
|
||||||
|
name: build_wheel
|
||||||
|
description: "Build a lance wheel"
|
||||||
|
inputs:
|
||||||
|
python-minor-version:
|
||||||
|
description: "8, 9, 10, 11"
|
||||||
|
required: true
|
||||||
|
args:
|
||||||
|
description: "--release"
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Install Protoc v21.12
|
||||||
|
working-directory: C:\
|
||||||
|
run: |
|
||||||
|
New-Item -Path 'C:\protoc' -ItemType Directory
|
||||||
|
Set-Location C:\protoc
|
||||||
|
Invoke-WebRequest https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip -OutFile C:\protoc\protoc.zip
|
||||||
|
7z x protoc.zip
|
||||||
|
Add-Content $env:GITHUB_PATH "C:\protoc\bin"
|
||||||
|
shell: powershell
|
||||||
|
- name: Build wheel
|
||||||
|
uses: PyO3/maturin-action@v1
|
||||||
|
with:
|
||||||
|
command: build
|
||||||
|
args: ${{ inputs.args }}
|
||||||
|
working-directory: python
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: windows-wheels
|
||||||
|
path: python\target\wheels
|
||||||
2
.github/workflows/cargo-publish.yml
vendored
2
.github/workflows/cargo-publish.yml
vendored
@@ -26,4 +26,4 @@ jobs:
|
|||||||
sudo apt install -y protobuf-compiler libssl-dev
|
sudo apt install -y protobuf-compiler libssl-dev
|
||||||
- name: Publish the package
|
- name: Publish the package
|
||||||
run: |
|
run: |
|
||||||
cargo publish -p vectordb --all-features --token ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
cargo publish -p lancedb --all-features --token ${{ secrets.CARGO_REGISTRY_TOKEN }}
|
||||||
|
|||||||
9
.github/workflows/docs_test.yml
vendored
9
.github/workflows/docs_test.yml
vendored
@@ -49,6 +49,9 @@ jobs:
|
|||||||
test-node:
|
test-node:
|
||||||
name: Test doc nodejs code
|
name: Test doc nodejs code
|
||||||
runs-on: "ubuntu-latest"
|
runs-on: "ubuntu-latest"
|
||||||
|
timeout-minutes: 45
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -66,6 +69,12 @@ jobs:
|
|||||||
uses: swatinem/rust-cache@v2
|
uses: swatinem/rust-cache@v2
|
||||||
- name: Install node dependencies
|
- name: Install node dependencies
|
||||||
run: |
|
run: |
|
||||||
|
sudo swapoff -a
|
||||||
|
sudo fallocate -l 8G /swapfile
|
||||||
|
sudo chmod 600 /swapfile
|
||||||
|
sudo mkswap /swapfile
|
||||||
|
sudo swapon /swapfile
|
||||||
|
sudo swapon --show
|
||||||
cd node
|
cd node
|
||||||
npm ci
|
npm ci
|
||||||
npm run build-release
|
npm run build-release
|
||||||
|
|||||||
17
.github/workflows/npm-publish.yml
vendored
17
.github/workflows/npm-publish.yml
vendored
@@ -80,10 +80,25 @@ jobs:
|
|||||||
- arch: x86_64
|
- arch: x86_64
|
||||||
runner: ubuntu-latest
|
runner: ubuntu-latest
|
||||||
- arch: aarch64
|
- arch: aarch64
|
||||||
runner: buildjet-4vcpu-ubuntu-2204-arm
|
# For successful fat LTO builds, we need a large runner to avoid OOM errors.
|
||||||
|
runner: buildjet-16vcpu-ubuntu-2204-arm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
# Buildjet aarch64 runners have only 1.5 GB RAM per core, vs 3.5 GB per core for
|
||||||
|
# x86_64 runners. To avoid OOM errors on ARM, we create a swap file.
|
||||||
|
- name: Configure aarch64 build
|
||||||
|
if: ${{ matrix.config.arch == 'aarch64' }}
|
||||||
|
run: |
|
||||||
|
free -h
|
||||||
|
sudo fallocate -l 16G /swapfile
|
||||||
|
sudo chmod 600 /swapfile
|
||||||
|
sudo mkswap /swapfile
|
||||||
|
sudo swapon /swapfile
|
||||||
|
echo "/swapfile swap swap defaults 0 0" >> sudo /etc/fstab
|
||||||
|
# print info
|
||||||
|
swapon --show
|
||||||
|
free -h
|
||||||
- name: Build Linux Artifacts
|
- name: Build Linux Artifacts
|
||||||
run: |
|
run: |
|
||||||
bash ci/build_linux_artifacts.sh ${{ matrix.config.arch }}
|
bash ci/build_linux_artifacts.sh ${{ matrix.config.arch }}
|
||||||
|
|||||||
101
.github/workflows/pypi-publish.yml
vendored
101
.github/workflows/pypi-publish.yml
vendored
@@ -2,30 +2,91 @@ name: PyPI Publish
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
release:
|
release:
|
||||||
types: [ published ]
|
types: [published]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
publish:
|
linux:
|
||||||
runs-on: ubuntu-latest
|
timeout-minutes: 60
|
||||||
# Only runs on tags that matches the python-make-release action
|
strategy:
|
||||||
if: startsWith(github.ref, 'refs/tags/python-v')
|
matrix:
|
||||||
defaults:
|
python-minor-version: ["8"]
|
||||||
run:
|
platform:
|
||||||
shell: bash
|
- x86_64
|
||||||
working-directory: python
|
- aarch64
|
||||||
|
runs-on: "ubuntu-22.04"
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: "3.8"
|
python-version: 3.${{ matrix.python-minor-version }}
|
||||||
- name: Build distribution
|
- uses: ./.github/workflows/build_linux_wheel
|
||||||
run: |
|
|
||||||
ls -la
|
|
||||||
pip install wheel setuptools --upgrade
|
|
||||||
python setup.py sdist bdist_wheel
|
|
||||||
- name: Publish
|
|
||||||
uses: pypa/gh-action-pypi-publish@v1.8.5
|
|
||||||
with:
|
with:
|
||||||
password: ${{ secrets.LANCEDB_PYPI_API_TOKEN }}
|
python-minor-version: ${{ matrix.python-minor-version }}
|
||||||
packages-dir: python/dist
|
args: "--release --strip"
|
||||||
|
arm-build: ${{ matrix.platform == 'aarch64' }}
|
||||||
|
- uses: ./.github/workflows/upload_wheel
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.LANCEDB_PYPI_API_TOKEN }}
|
||||||
|
repo: "pypi"
|
||||||
|
mac:
|
||||||
|
timeout-minutes: 60
|
||||||
|
runs-on: ${{ matrix.config.runner }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-minor-version: ["8"]
|
||||||
|
config:
|
||||||
|
- target: x86_64-apple-darwin
|
||||||
|
runner: macos-13
|
||||||
|
- target: aarch64-apple-darwin
|
||||||
|
runner: macos-14
|
||||||
|
env:
|
||||||
|
MACOSX_DEPLOYMENT_TARGET: 10.15
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
fetch-depth: 0
|
||||||
|
lfs: true
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: 3.12
|
||||||
|
- uses: ./.github/workflows/build_mac_wheel
|
||||||
|
with:
|
||||||
|
python-minor-version: ${{ matrix.python-minor-version }}
|
||||||
|
args: "--release --strip --target ${{ matrix.config.target }}"
|
||||||
|
- uses: ./.github/workflows/upload_wheel
|
||||||
|
with:
|
||||||
|
python-minor-version: ${{ matrix.python-minor-version }}
|
||||||
|
token: ${{ secrets.LANCEDB_PYPI_API_TOKEN }}
|
||||||
|
repo: "pypi"
|
||||||
|
windows:
|
||||||
|
timeout-minutes: 60
|
||||||
|
runs-on: windows-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-minor-version: ["8"]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
fetch-depth: 0
|
||||||
|
lfs: true
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: 3.${{ matrix.python-minor-version }}
|
||||||
|
- uses: ./.github/workflows/build_windows_wheel
|
||||||
|
with:
|
||||||
|
python-minor-version: ${{ matrix.python-minor-version }}
|
||||||
|
args: "--release --strip"
|
||||||
|
vcpkg_token: ${{ secrets.VCPKG_GITHUB_PACKAGES }}
|
||||||
|
- uses: ./.github/workflows/upload_wheel
|
||||||
|
with:
|
||||||
|
python-minor-version: ${{ matrix.python-minor-version }}
|
||||||
|
token: ${{ secrets.LANCEDB_PYPI_API_TOKEN }}
|
||||||
|
repo: "pypi"
|
||||||
|
|||||||
210
.github/workflows/python.yml
vendored
210
.github/workflows/python.yml
vendored
@@ -14,49 +14,133 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
linux:
|
lint:
|
||||||
|
name: "Lint"
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
python-minor-version: [ "8", "11" ]
|
|
||||||
runs-on: "ubuntu-22.04"
|
runs-on: "ubuntu-22.04"
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
working-directory: python
|
working-directory: python
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: 3.${{ matrix.python-minor-version }}
|
python-version: "3.11"
|
||||||
- name: Install lancedb
|
- name: Install ruff
|
||||||
run: |
|
run: |
|
||||||
pip install -e .[tests]
|
pip install ruff
|
||||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
- name: Format check
|
||||||
pip install pytest pytest-mock ruff
|
run: ruff format --check .
|
||||||
- name: Format check
|
- name: Lint
|
||||||
run: ruff format --check .
|
run: ruff .
|
||||||
- name: Lint
|
doctest:
|
||||||
run: ruff .
|
name: "Doctest"
|
||||||
- name: Run tests
|
timeout-minutes: 30
|
||||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
runs-on: "ubuntu-22.04"
|
||||||
- name: doctest
|
defaults:
|
||||||
run: pytest --doctest-modules lancedb
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: python
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
lfs: true
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
cache: "pip"
|
||||||
|
- name: Install protobuf
|
||||||
|
run: |
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install -y protobuf-compiler
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: python
|
||||||
|
- name: Install
|
||||||
|
run: |
|
||||||
|
pip install -e .[tests,dev,embeddings]
|
||||||
|
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||||
|
pip install mlx
|
||||||
|
- name: Doctest
|
||||||
|
run: pytest --doctest-modules python/lancedb
|
||||||
|
linux:
|
||||||
|
name: "Linux: python-3.${{ matrix.python-minor-version }}"
|
||||||
|
timeout-minutes: 30
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-minor-version: ["8", "11"]
|
||||||
|
runs-on: "ubuntu-22.04"
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: python
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
lfs: true
|
||||||
|
- name: Install protobuf
|
||||||
|
run: |
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install -y protobuf-compiler
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: 3.${{ matrix.python-minor-version }}
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: python
|
||||||
|
- uses: ./.github/workflows/build_linux_wheel
|
||||||
|
- uses: ./.github/workflows/run_tests
|
||||||
|
# Make sure wheels are not included in the Rust cache
|
||||||
|
- name: Delete wheels
|
||||||
|
run: rm -rf target/wheels
|
||||||
platform:
|
platform:
|
||||||
name: "Platform: ${{ matrix.config.name }}"
|
name: "Mac: ${{ matrix.config.name }}"
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
config:
|
config:
|
||||||
- name: x86 Mac
|
- name: x86
|
||||||
runner: macos-13
|
runner: macos-13
|
||||||
- name: Arm Mac
|
- name: Arm
|
||||||
runner: macos-14
|
runner: macos-14
|
||||||
- name: x86 Windows
|
runs-on: "${{ matrix.config.runner }}"
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: python
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
lfs: true
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
- uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: python
|
||||||
|
- uses: ./.github/workflows/build_mac_wheel
|
||||||
|
- uses: ./.github/workflows/run_tests
|
||||||
|
# Make sure wheels are not included in the Rust cache
|
||||||
|
- name: Delete wheels
|
||||||
|
run: rm -rf target/wheels
|
||||||
|
windows:
|
||||||
|
name: "Windows: ${{ matrix.config.name }}"
|
||||||
|
timeout-minutes: 30
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
config:
|
||||||
|
- name: x86
|
||||||
runner: windows-latest
|
runner: windows-latest
|
||||||
runs-on: "${{ matrix.config.runner }}"
|
runs-on: "${{ matrix.config.runner }}"
|
||||||
defaults:
|
defaults:
|
||||||
@@ -64,21 +148,22 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: python
|
working-directory: python
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
- name: Install lancedb
|
- uses: Swatinem/rust-cache@v2
|
||||||
run: |
|
with:
|
||||||
pip install -e .[tests]
|
workspaces: python
|
||||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
- uses: ./.github/workflows/build_windows_wheel
|
||||||
pip install pytest pytest-mock
|
- uses: ./.github/workflows/run_tests
|
||||||
- name: Run tests
|
# Make sure wheels are not included in the Rust cache
|
||||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
- name: Delete wheels
|
||||||
|
run: rm -rf target/wheels
|
||||||
pydantic1x:
|
pydantic1x:
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
runs-on: "ubuntu-22.04"
|
runs-on: "ubuntu-22.04"
|
||||||
@@ -87,21 +172,22 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: python
|
working-directory: python
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
lfs: true
|
lfs: true
|
||||||
- name: Set up Python
|
- name: Install dependencies
|
||||||
uses: actions/setup-python@v5
|
run: |
|
||||||
with:
|
sudo apt update
|
||||||
python-version: 3.9
|
sudo apt install -y protobuf-compiler
|
||||||
- name: Install lancedb
|
- name: Set up Python
|
||||||
run: |
|
uses: actions/setup-python@v5
|
||||||
pip install "pydantic<2"
|
with:
|
||||||
pip install -e .[tests]
|
python-version: 3.9
|
||||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
- name: Install lancedb
|
||||||
pip install pytest pytest-mock
|
run: |
|
||||||
- name: Run tests
|
pip install "pydantic<2"
|
||||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
pip install -e .[tests]
|
||||||
- name: doctest
|
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||||
run: pytest --doctest-modules lancedb
|
- name: Run tests
|
||||||
|
run: pytest -m "not slow" -x -v --durations=30 python/tests
|
||||||
|
|||||||
17
.github/workflows/run_tests/action.yml
vendored
Normal file
17
.github/workflows/run_tests/action.yml
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
name: run-tests
|
||||||
|
|
||||||
|
description: "Install lance wheel and run unit tests"
|
||||||
|
inputs:
|
||||||
|
python-minor-version:
|
||||||
|
required: true
|
||||||
|
description: "8 9 10 11 12"
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Install lancedb
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
pip3 install $(ls target/wheels/lancedb-*.whl)[tests,dev,embeddings]
|
||||||
|
- name: pytest
|
||||||
|
shell: bash
|
||||||
|
run: pytest -m "not slow" -x -v --durations=30 python/python/tests
|
||||||
29
.github/workflows/upload_wheel/action.yml
vendored
Normal file
29
.github/workflows/upload_wheel/action.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
name: upload-wheel
|
||||||
|
|
||||||
|
description: "Upload wheels to Pypi"
|
||||||
|
inputs:
|
||||||
|
os:
|
||||||
|
required: true
|
||||||
|
description: "ubuntu-22.04 or macos-13"
|
||||||
|
repo:
|
||||||
|
required: false
|
||||||
|
description: "pypi or testpypi"
|
||||||
|
default: "pypi"
|
||||||
|
token:
|
||||||
|
required: true
|
||||||
|
description: "release token for the repo"
|
||||||
|
|
||||||
|
runs:
|
||||||
|
using: "composite"
|
||||||
|
steps:
|
||||||
|
- name: Install dependencies
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install twine
|
||||||
|
- name: Publish wheel
|
||||||
|
env:
|
||||||
|
TWINE_USERNAME: __token__
|
||||||
|
TWINE_PASSWORD: ${{ inputs.token }}
|
||||||
|
shell: bash
|
||||||
|
run: twine upload --repository ${{ inputs.repo }} target/wheels/lancedb-*.whl
|
||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -22,6 +22,11 @@ python/dist
|
|||||||
|
|
||||||
**/.hypothesis
|
**/.hypothesis
|
||||||
|
|
||||||
|
# Compiled Dynamic libraries
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
*.dll
|
||||||
|
|
||||||
## Javascript
|
## Javascript
|
||||||
*.node
|
*.node
|
||||||
**/node_modules
|
**/node_modules
|
||||||
|
|||||||
15
Cargo.toml
15
Cargo.toml
@@ -1,20 +1,23 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = ["rust/ffi/node", "rust/vectordb", "nodejs"]
|
members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python"]
|
||||||
# Python package needs to be built by maturin.
|
# Python package needs to be built by maturin.
|
||||||
exclude = ["python"]
|
exclude = ["python"]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Lance Devs <dev@lancedb.com>"]
|
authors = ["LanceDB Devs <dev@lancedb.com>"]
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
repository = "https://github.com/lancedb/lancedb"
|
repository = "https://github.com/lancedb/lancedb"
|
||||||
|
description = "Serverless, low-latency vector database for AI applications"
|
||||||
|
keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||||
|
categories = ["database-implementations"]
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
|
lance = { "version" = "=0.10.1", "features" = ["dynamodb"] }
|
||||||
lance-index = { "version" = "=0.9.12" }
|
lance-index = { "version" = "=0.10.1" }
|
||||||
lance-linalg = { "version" = "=0.9.12" }
|
lance-linalg = { "version" = "=0.10.1" }
|
||||||
lance-testing = { "version" = "=0.9.12" }
|
lance-testing = { "version" = "=0.10.1" }
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "50.0", optional = false }
|
arrow = { version = "50.0", optional = false }
|
||||||
arrow-array = "50.0"
|
arrow-array = "50.0"
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ docker build \
|
|||||||
.
|
.
|
||||||
popd
|
popd
|
||||||
|
|
||||||
|
# We turn on memory swap to avoid OOM killer
|
||||||
docker run \
|
docker run \
|
||||||
-v $(pwd):/io -w /io \
|
-v $(pwd):/io -w /io \
|
||||||
|
--memory-swap=-1 \
|
||||||
lancedb-node-manylinux \
|
lancedb-node-manylinux \
|
||||||
bash ci/manylinux_node/build.sh $ARCH
|
bash ci/manylinux_node/build.sh $ARCH
|
||||||
|
|||||||
27
dockerfiles/Dockerfile
Normal file
27
dockerfiles/Dockerfile
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
#Simple base dockerfile that supports basic dependencies required to run lance with FTS and Hybrid Search
|
||||||
|
#Usage docker build -t lancedb:latest -f Dockerfile .
|
||||||
|
FROM python:3.10-slim-buster
|
||||||
|
|
||||||
|
# Install Rust
|
||||||
|
RUN apt-get update && apt-get install -y curl build-essential && \
|
||||||
|
curl https://sh.rustup.rs -sSf | sh -s -- -y
|
||||||
|
|
||||||
|
# Set the environment variable for Rust
|
||||||
|
ENV PATH="/root/.cargo/bin:${PATH}"
|
||||||
|
|
||||||
|
# Install protobuf compiler
|
||||||
|
RUN apt-get install -y protobuf-compiler && \
|
||||||
|
apt-get clean && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN apt-get -y update &&\
|
||||||
|
apt-get -y upgrade && \
|
||||||
|
apt-get -y install git
|
||||||
|
|
||||||
|
|
||||||
|
# Verify installations
|
||||||
|
RUN python --version && \
|
||||||
|
rustc --version && \
|
||||||
|
protoc --version
|
||||||
|
|
||||||
|
RUN pip install tantivy lancedb
|
||||||
@@ -57,6 +57,16 @@ plugins:
|
|||||||
- https://arrow.apache.org/docs/objects.inv
|
- https://arrow.apache.org/docs/objects.inv
|
||||||
- https://pandas.pydata.org/docs/objects.inv
|
- https://pandas.pydata.org/docs/objects.inv
|
||||||
- mkdocs-jupyter
|
- mkdocs-jupyter
|
||||||
|
- ultralytics:
|
||||||
|
verbose: True
|
||||||
|
enabled: True
|
||||||
|
default_image: "assets/lancedb_and_lance.png" # Default image for all pages
|
||||||
|
add_image: True # Automatically add meta image
|
||||||
|
add_keywords: True # Add page keywords in the header tag
|
||||||
|
add_share_buttons: True # Add social share buttons
|
||||||
|
add_authors: False # Display page authors
|
||||||
|
add_desc: False
|
||||||
|
add_dates: False
|
||||||
|
|
||||||
markdown_extensions:
|
markdown_extensions:
|
||||||
- admonition
|
- admonition
|
||||||
@@ -90,16 +100,18 @@ nav:
|
|||||||
- Building an ANN index: ann_indexes.md
|
- Building an ANN index: ann_indexes.md
|
||||||
- Vector Search: search.md
|
- Vector Search: search.md
|
||||||
- Full-text search: fts.md
|
- Full-text search: fts.md
|
||||||
- Hybrid search: hybrid_search.md
|
- Hybrid search:
|
||||||
|
- Overview: hybrid_search/hybrid_search.md
|
||||||
|
- Comparing Rerankers: hybrid_search/eval.md
|
||||||
|
- Airbnb financial data example: notebooks/hybrid_search.ipynb
|
||||||
- Filtering: sql.md
|
- Filtering: sql.md
|
||||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||||
- Configuring Storage: guides/storage.md
|
- Configuring Storage: guides/storage.md
|
||||||
- 🧬 Managing embeddings:
|
- 🧬 Managing embeddings:
|
||||||
- Overview: embeddings/index.md
|
- Overview: embeddings/index.md
|
||||||
- Explicit management: embeddings/embedding_explicit.md
|
- Embedding functions: embeddings/embedding_functions.md
|
||||||
- Implicit management: embeddings/embedding_functions.md
|
- Available models: embeddings/default_embedding_functions.md
|
||||||
- Available Functions: embeddings/default_embedding_functions.md
|
- User-defined embedding functions: embeddings/custom_embedding_function.md
|
||||||
- Custom Embedding Functions: embeddings/api.md
|
|
||||||
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
|
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
|
||||||
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
|
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
|
||||||
- 🔌 Integrations:
|
- 🔌 Integrations:
|
||||||
@@ -152,16 +164,18 @@ nav:
|
|||||||
- Building an ANN index: ann_indexes.md
|
- Building an ANN index: ann_indexes.md
|
||||||
- Vector Search: search.md
|
- Vector Search: search.md
|
||||||
- Full-text search: fts.md
|
- Full-text search: fts.md
|
||||||
- Hybrid search: hybrid_search.md
|
- Hybrid search:
|
||||||
|
- Overview: hybrid_search/hybrid_search.md
|
||||||
|
- Comparing Rerankers: hybrid_search/eval.md
|
||||||
|
- Airbnb financial data example: notebooks/hybrid_search.ipynb
|
||||||
- Filtering: sql.md
|
- Filtering: sql.md
|
||||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||||
- Configuring Storage: guides/storage.md
|
- Configuring Storage: guides/storage.md
|
||||||
- Managing Embeddings:
|
- Managing Embeddings:
|
||||||
- Overview: embeddings/index.md
|
- Overview: embeddings/index.md
|
||||||
- Explicit management: embeddings/embedding_explicit.md
|
- Embedding functions: embeddings/embedding_functions.md
|
||||||
- Implicit management: embeddings/embedding_functions.md
|
- Available models: embeddings/default_embedding_functions.md
|
||||||
- Available Functions: embeddings/default_embedding_functions.md
|
- User-defined embedding functions: embeddings/custom_embedding_function.md
|
||||||
- Custom Embedding Functions: embeddings/api.md
|
|
||||||
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
|
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
|
||||||
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
|
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
|
||||||
- Integrations:
|
- Integrations:
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ mkdocs==1.5.3
|
|||||||
mkdocs-jupyter==0.24.1
|
mkdocs-jupyter==0.24.1
|
||||||
mkdocs-material==9.5.3
|
mkdocs-material==9.5.3
|
||||||
mkdocstrings[python]==0.20.0
|
mkdocstrings[python]==0.20.0
|
||||||
pydantic
|
pydantic
|
||||||
|
mkdocs-ultralytics-plugin==0.0.44
|
||||||
@@ -17,6 +17,7 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
from lancedb.embeddings import register
|
from lancedb.embeddings import register
|
||||||
|
from lancedb.util import attempt_import_or_raise
|
||||||
|
|
||||||
@register("sentence-transformers")
|
@register("sentence-transformers")
|
||||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||||
@@ -81,7 +82,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
open_clip = self.safe_import("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
|
open_clip = attempt_import_or_raise("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
|
||||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||||
self.name, pretrained=self.pretrained
|
self.name, pretrained=self.pretrained
|
||||||
)
|
)
|
||||||
@@ -109,14 +110,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
if isinstance(query, str):
|
if isinstance(query, str):
|
||||||
return [self.generate_text_embeddings(query)]
|
return [self.generate_text_embeddings(query)]
|
||||||
else:
|
else:
|
||||||
PIL = self.safe_import("PIL", "pillow")
|
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||||
if isinstance(query, PIL.Image.Image):
|
if isinstance(query, PIL.Image.Image):
|
||||||
return [self.generate_image_embedding(query)]
|
return [self.generate_image_embedding(query)]
|
||||||
else:
|
else:
|
||||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||||
|
|
||||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||||
torch = self.safe_import("torch")
|
torch = attempt_import_or_raise("torch")
|
||||||
text = self.sanitize_input(text)
|
text = self.sanitize_input(text)
|
||||||
text = self._tokenizer(text)
|
text = self._tokenizer(text)
|
||||||
text.to(self.device)
|
text.to(self.device)
|
||||||
@@ -175,7 +176,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
The image to embed. If the image is a str, it is treated as a uri.
|
The image to embed. If the image is a str, it is treated as a uri.
|
||||||
If the image is bytes, it is treated as the raw image bytes.
|
If the image is bytes, it is treated as the raw image bytes.
|
||||||
"""
|
"""
|
||||||
torch = self.safe_import("torch")
|
torch = attempt_import_or_raise("torch")
|
||||||
# TODO handle retry and errors for https
|
# TODO handle retry and errors for https
|
||||||
image = self._to_pil(image)
|
image = self._to_pil(image)
|
||||||
image = self._preprocess(image).unsqueeze(0)
|
image = self._preprocess(image).unsqueeze(0)
|
||||||
@@ -183,7 +184,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
return self._encode_and_normalize_image(image)
|
return self._encode_and_normalize_image(image)
|
||||||
|
|
||||||
def _to_pil(self, image: Union[str, bytes]):
|
def _to_pil(self, image: Union[str, bytes]):
|
||||||
PIL = self.safe_import("PIL", "pillow")
|
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||||
if isinstance(image, bytes):
|
if isinstance(image, bytes):
|
||||||
return PIL.Image.open(io.BytesIO(image))
|
return PIL.Image.open(io.BytesIO(image))
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL.Image.Image):
|
||||||
@@ -9,6 +9,9 @@ Contains the text embedding functions registered by default.
|
|||||||
### Sentence transformers
|
### Sentence transformers
|
||||||
Allows you to set parameters when registering a `sentence-transformers` object.
|
Allows you to set parameters when registering a `sentence-transformers` object.
|
||||||
|
|
||||||
|
!!! info
|
||||||
|
Sentence transformer embeddings are normalized by default. It is recommended to use normalized embeddings for similarity search.
|
||||||
|
|
||||||
| Parameter | Type | Default Value | Description |
|
| Parameter | Type | Default Value | Description |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |
|
| `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |
|
||||||
|
|||||||
@@ -1,141 +0,0 @@
|
|||||||
In this workflow, you define your own embedding function and pass it as a callable to LanceDB, invoking it in your code to generate the embeddings. Let's look at some examples.
|
|
||||||
|
|
||||||
### Hugging Face
|
|
||||||
|
|
||||||
!!! note
|
|
||||||
Currently, the Hugging Face method is only supported in the Python SDK.
|
|
||||||
|
|
||||||
=== "Python"
|
|
||||||
The most popular open source option is to use the [sentence-transformers](https://www.sbert.net/)
|
|
||||||
library, which can be installed via pip.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install sentence-transformers
|
|
||||||
```
|
|
||||||
|
|
||||||
The example below shows how to use the `paraphrase-albert-small-v2` model to generate embeddings
|
|
||||||
for a given document.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
name="paraphrase-albert-small-v2"
|
|
||||||
model = SentenceTransformer(name)
|
|
||||||
|
|
||||||
# used for both training and querying
|
|
||||||
def embed_func(batch):
|
|
||||||
return [model.encode(sentence) for sentence in batch]
|
|
||||||
```
|
|
||||||
|
|
||||||
### OpenAI
|
|
||||||
|
|
||||||
Another popular alternative is to use an external API like OpenAI's [embeddings API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
|
|
||||||
|
|
||||||
=== "Python"
|
|
||||||
```python
|
|
||||||
import openai
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Configuring the environment variable OPENAI_API_KEY
|
|
||||||
if "OPENAI_API_KEY" not in os.environ:
|
|
||||||
# OR set the key here as a variable
|
|
||||||
openai.api_key = "sk-..."
|
|
||||||
|
|
||||||
# verify that the API key is working
|
|
||||||
assert len(openai.Model.list()["data"]) > 0
|
|
||||||
|
|
||||||
def embed_func(c):
|
|
||||||
rs = openai.Embedding.create(input=c, engine="text-embedding-ada-002")
|
|
||||||
return [record["embedding"] for record in rs["data"]]
|
|
||||||
```
|
|
||||||
|
|
||||||
=== "JavaScript"
|
|
||||||
```javascript
|
|
||||||
const lancedb = require("vectordb");
|
|
||||||
|
|
||||||
// You need to provide an OpenAI API key
|
|
||||||
const apiKey = "sk-..."
|
|
||||||
// The embedding function will create embeddings for the 'text' column
|
|
||||||
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Applying an embedding function to data
|
|
||||||
|
|
||||||
=== "Python"
|
|
||||||
Using an embedding function, you can apply it to raw data
|
|
||||||
to generate embeddings for each record.
|
|
||||||
|
|
||||||
Say you have a pandas DataFrame with a `text` column that you want embedded,
|
|
||||||
you can use the `with_embeddings` function to generate embeddings and add them to
|
|
||||||
an existing table.
|
|
||||||
|
|
||||||
```python
|
|
||||||
import pandas as pd
|
|
||||||
from lancedb.embeddings import with_embeddings
|
|
||||||
|
|
||||||
df = pd.DataFrame(
|
|
||||||
[
|
|
||||||
{"text": "pepperoni"},
|
|
||||||
{"text": "pineapple"}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
data = with_embeddings(embed_func, df)
|
|
||||||
|
|
||||||
# The output is used to create / append to a table
|
|
||||||
# db.create_table("my_table", data=data)
|
|
||||||
```
|
|
||||||
|
|
||||||
If your data is in a different column, you can specify the `column` kwarg to `with_embeddings`.
|
|
||||||
|
|
||||||
By default, LanceDB calls the function with batches of 1000 rows. This can be configured
|
|
||||||
using the `batch_size` parameter to `with_embeddings`.
|
|
||||||
|
|
||||||
LanceDB automatically wraps the function with retry and rate-limit logic to ensure the OpenAI
|
|
||||||
API call is reliable.
|
|
||||||
|
|
||||||
=== "JavaScript"
|
|
||||||
Using an embedding function, you can apply it to raw data
|
|
||||||
to generate embeddings for each record.
|
|
||||||
|
|
||||||
Simply pass the embedding function created above and LanceDB will use it to generate
|
|
||||||
embeddings for your data.
|
|
||||||
|
|
||||||
```javascript
|
|
||||||
const db = await lancedb.connect("data/sample-lancedb");
|
|
||||||
const data = [
|
|
||||||
{ text: "pepperoni"},
|
|
||||||
{ text: "pineapple"}
|
|
||||||
]
|
|
||||||
|
|
||||||
const table = await db.createTable("vectors", data, embedding)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Querying using an embedding function
|
|
||||||
|
|
||||||
!!! warning
|
|
||||||
At query time, you **must** use the same embedding function you used to vectorize your data.
|
|
||||||
If you use a different embedding function, the embeddings will not reside in the same vector
|
|
||||||
space and the results will be nonsensical.
|
|
||||||
|
|
||||||
=== "Python"
|
|
||||||
```python
|
|
||||||
query = "What's the best pizza topping?"
|
|
||||||
query_vector = embed_func([query])[0]
|
|
||||||
results = (
|
|
||||||
tbl.search(query_vector)
|
|
||||||
.limit(10)
|
|
||||||
.to_pandas()
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
|
||||||
|
|
||||||
=== "JavaScript"
|
|
||||||
```javascript
|
|
||||||
const results = await table
|
|
||||||
.search("What's the best pizza topping?")
|
|
||||||
.limit(10)
|
|
||||||
.execute()
|
|
||||||
```
|
|
||||||
|
|
||||||
The above snippet returns an array of records with the top 10 nearest neighbors to the query.
|
|
||||||
@@ -3,61 +3,126 @@ Representing multi-modal data as vector embeddings is becoming a standard practi
|
|||||||
For this purpose, LanceDB introduces an **embedding functions API**, that allow you simply set up once, during the configuration stage of your project. After this, the table remembers it, effectively making the embedding functions *disappear in the background* so you don't have to worry about manually passing callables, and instead, simply focus on the rest of your data engineering pipeline.
|
For this purpose, LanceDB introduces an **embedding functions API**, that allow you simply set up once, during the configuration stage of your project. After this, the table remembers it, effectively making the embedding functions *disappear in the background* so you don't have to worry about manually passing callables, and instead, simply focus on the rest of your data engineering pipeline.
|
||||||
|
|
||||||
!!! warning
|
!!! warning
|
||||||
Using the implicit embeddings management approach means that you can forget about the manually passing around embedding
|
Using the embedding function registry means that you don't have to explicitly generate the embeddings yourself.
|
||||||
functions in your code, as long as you don't intend to change it at a later time. If your embedding function changes,
|
However, if your embedding function changes, you'll have to re-configure your table with the new embedding function
|
||||||
you'll have to re-configure your table with the new embedding function and regenerate the embeddings.
|
and regenerate the embeddings. In the future, we plan to support the ability to change the embedding function via
|
||||||
|
table metadata and have LanceDB automatically take care of regenerating the embeddings.
|
||||||
|
|
||||||
|
|
||||||
## 1. Define the embedding function
|
## 1. Define the embedding function
|
||||||
We have some pre-defined embedding functions in the global registry, with more coming soon. Here's let's an implementation of CLIP as example.
|
|
||||||
```
|
|
||||||
registry = EmbeddingFunctionRegistry.get_instance()
|
|
||||||
clip = registry.get("open-clip").create()
|
|
||||||
|
|
||||||
```
|
=== "Python"
|
||||||
You can also define your own embedding function by implementing the `EmbeddingFunction` abstract base interface. It subclasses Pydantic Model which can be utilized to write complex schemas simply as we'll see next!
|
In the LanceDB python SDK, we define a global embedding function registry with
|
||||||
|
many different embedding models and even more coming soon.
|
||||||
|
Here's let's an implementation of CLIP as example.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.embeddings import get_registry
|
||||||
|
|
||||||
|
registry = get_registry()
|
||||||
|
clip = registry.get("open-clip").create()
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also define your own embedding function by implementing the `EmbeddingFunction`
|
||||||
|
abstract base interface. It subclasses Pydantic Model which can be utilized to write complex schemas simply as we'll see next!
|
||||||
|
|
||||||
|
=== "JavaScript""
|
||||||
|
In the TypeScript SDK, the choices are more limited. For now, only the OpenAI
|
||||||
|
embedding function is available.
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const lancedb = require("vectordb");
|
||||||
|
|
||||||
|
// You need to provide an OpenAI API key
|
||||||
|
const apiKey = "sk-..."
|
||||||
|
// The embedding function will create embeddings for the 'text' column
|
||||||
|
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
|
||||||
|
```
|
||||||
|
|
||||||
## 2. Define the data model or schema
|
## 2. Define the data model or schema
|
||||||
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how:
|
|
||||||
|
|
||||||
```python
|
=== "Python"
|
||||||
class Pets(LanceModel):
|
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how:
|
||||||
vector: Vector(clip.ndims) = clip.VectorField()
|
|
||||||
image_uri: str = clip.SourceField()
|
|
||||||
```
|
|
||||||
|
|
||||||
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
|
```python
|
||||||
|
class Pets(LanceModel):
|
||||||
|
vector: Vector(clip.ndims) = clip.VectorField()
|
||||||
|
image_uri: str = clip.SourceField()
|
||||||
|
```
|
||||||
|
|
||||||
## 3. Create LanceDB table
|
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
|
||||||
Now that we have chosen/defined our embedding function and the schema, we can create the table:
|
|
||||||
|
|
||||||
```python
|
=== "JavaScript"
|
||||||
db = lancedb.connect("~/lancedb")
|
|
||||||
table = db.create_table("pets", schema=Pets)
|
|
||||||
|
|
||||||
```
|
For the TypeScript SDK, a schema can be inferred from input data, or an explicit
|
||||||
|
Arrow schema can be provided.
|
||||||
|
|
||||||
That's it! We've provided all the information needed to embed the source and query inputs. We can now forget about the model and dimension details and start to build our VectorDB pipeline.
|
## 3. Create table and add data
|
||||||
|
|
||||||
## 4. Ingest lots of data and query your table
|
Now that we have chosen/defined our embedding function and the schema,
|
||||||
Any new or incoming data can just be added and it'll be vectorized automatically.
|
we can create the table and ingest data without needing to explicitly generate
|
||||||
|
the embeddings at all:
|
||||||
|
|
||||||
```python
|
=== "Python"
|
||||||
table.add([{"image_uri": u} for u in uris])
|
```python
|
||||||
```
|
db = lancedb.connect("~/lancedb")
|
||||||
|
table = db.create_table("pets", schema=Pets)
|
||||||
|
|
||||||
Our OpenCLIP query embedding function supports querying via both text and images:
|
table.add([{"image_uri": u} for u in uris])
|
||||||
|
```
|
||||||
|
|
||||||
```python
|
=== "JavaScript"
|
||||||
result = table.search("dog")
|
|
||||||
```
|
|
||||||
|
|
||||||
Let's query an image:
|
```javascript
|
||||||
|
const db = await lancedb.connect("data/sample-lancedb");
|
||||||
|
const data = [
|
||||||
|
{ text: "pepperoni"},
|
||||||
|
{ text: "pineapple"}
|
||||||
|
]
|
||||||
|
|
||||||
```python
|
const table = await db.createTable("vectors", data, embedding)
|
||||||
p = Path("path/to/images/samoyed_100.jpg")
|
```
|
||||||
query_image = Image.open(p)
|
|
||||||
table.search(query_image)
|
## 4. Querying your table
|
||||||
```
|
Not only can you forget about the embeddings during ingestion, you also don't
|
||||||
|
need to worry about it when you query the table:
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
Our OpenCLIP query embedding function supports querying via both text and images:
|
||||||
|
|
||||||
|
```python
|
||||||
|
results = (
|
||||||
|
table.search("dog")
|
||||||
|
.limit(10)
|
||||||
|
.to_pandas()
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or we can search using an image:
|
||||||
|
|
||||||
|
```python
|
||||||
|
p = Path("path/to/images/samoyed_100.jpg")
|
||||||
|
query_image = Image.open(p)
|
||||||
|
results = (
|
||||||
|
table.search(query_image)
|
||||||
|
.limit(10)
|
||||||
|
.to_pandas()
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Both of the above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||||
|
|
||||||
|
=== "JavaScript"
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const results = await table
|
||||||
|
.search("What's the best pizza topping?")
|
||||||
|
.limit(10)
|
||||||
|
.execute()
|
||||||
|
```
|
||||||
|
|
||||||
|
The above snippet returns an array of records with the top 10 nearest neighbors to the query.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -100,4 +165,5 @@ rs[2].image
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
Now that you have the basic idea about implicit management via embedding functions, let's dive deeper into a [custom API](./api.md) that you can use to implement your own embedding functions.
|
Now that you have the basic idea about LanceDB embedding functions and the embedding function registry,
|
||||||
|
let's dive deeper into defining your own [custom functions](./custom_embedding_function.md).
|
||||||
@@ -1,8 +1,14 @@
|
|||||||
Due to the nature of vector embeddings, they can be used to represent any kind of data, from text to images to audio. This makes them a very powerful tool for machine learning practitioners. However, there's no one-size-fits-all solution for generating embeddings - there are many different libraries and APIs (both commercial and open source) that can be used to generate embeddings from structured/unstructured data.
|
Due to the nature of vector embeddings, they can be used to represent any kind of data, from text to images to audio.
|
||||||
|
This makes them a very powerful tool for machine learning practitioners.
|
||||||
|
However, there's no one-size-fits-all solution for generating embeddings - there are many different libraries and APIs
|
||||||
|
(both commercial and open source) that can be used to generate embeddings from structured/unstructured data.
|
||||||
|
|
||||||
LanceDB supports 2 methods of vectorizing your raw data into embeddings.
|
LanceDB supports 3 methods of working with embeddings.
|
||||||
|
|
||||||
1. **Explicit**: By manually calling LanceDB's `with_embedding` function to vectorize your data via an `embed_func` of your choice
|
1. You can manually generate embeddings for the data and queries. This is done outside of LanceDB.
|
||||||
2. **Implicit**: Allow LanceDB to embed the data and queries in the background as they come in, by using the table's `EmbeddingRegistry` information
|
2. You can use the built-in [embedding functions](./embedding_functions.md) to embed the data and queries in the background.
|
||||||
|
3. For python users, you can define your own [custom embedding function](./custom_embedding_function.md)
|
||||||
|
that extends the default embedding functions.
|
||||||
|
|
||||||
See the [explicit](embedding_explicit.md) and [implicit](embedding_functions.md) embedding sections for more details.
|
For python users, there is also a legacy [with_embeddings API](./legacy.md).
|
||||||
|
It is retained for compatibility and will be removed in a future version.
|
||||||
99
docs/src/embeddings/legacy.md
Normal file
99
docs/src/embeddings/legacy.md
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
The legacy `with_embeddings` API is for Python only and is deprecated.
|
||||||
|
|
||||||
|
### Hugging Face
|
||||||
|
|
||||||
|
The most popular open source option is to use the [sentence-transformers](https://www.sbert.net/)
|
||||||
|
library, which can be installed via pip.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install sentence-transformers
|
||||||
|
```
|
||||||
|
|
||||||
|
The example below shows how to use the `paraphrase-albert-small-v2` model to generate embeddings
|
||||||
|
for a given document.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
name="paraphrase-albert-small-v2"
|
||||||
|
model = SentenceTransformer(name)
|
||||||
|
|
||||||
|
# used for both training and querying
|
||||||
|
def embed_func(batch):
|
||||||
|
return [model.encode(sentence) for sentence in batch]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### OpenAI
|
||||||
|
|
||||||
|
Another popular alternative is to use an external API like OpenAI's [embeddings API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
|
||||||
|
|
||||||
|
```python
|
||||||
|
import openai
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Configuring the environment variable OPENAI_API_KEY
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
# OR set the key here as a variable
|
||||||
|
openai.api_key = "sk-..."
|
||||||
|
|
||||||
|
client = openai.OpenAI()
|
||||||
|
|
||||||
|
def embed_func(c):
|
||||||
|
rs = client.embeddings.create(input=c, model="text-embedding-ada-002")
|
||||||
|
return [record.embedding for record in rs["data"]]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Applying an embedding function to data
|
||||||
|
|
||||||
|
Using an embedding function, you can apply it to raw data
|
||||||
|
to generate embeddings for each record.
|
||||||
|
|
||||||
|
Say you have a pandas DataFrame with a `text` column that you want embedded,
|
||||||
|
you can use the `with_embeddings` function to generate embeddings and add them to
|
||||||
|
an existing table.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pandas as pd
|
||||||
|
from lancedb.embeddings import with_embeddings
|
||||||
|
|
||||||
|
df = pd.DataFrame(
|
||||||
|
[
|
||||||
|
{"text": "pepperoni"},
|
||||||
|
{"text": "pineapple"}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
data = with_embeddings(embed_func, df)
|
||||||
|
|
||||||
|
# The output is used to create / append to a table
|
||||||
|
tbl = db.create_table("my_table", data=data)
|
||||||
|
```
|
||||||
|
|
||||||
|
If your data is in a different column, you can specify the `column` kwarg to `with_embeddings`.
|
||||||
|
|
||||||
|
By default, LanceDB calls the function with batches of 1000 rows. This can be configured
|
||||||
|
using the `batch_size` parameter to `with_embeddings`.
|
||||||
|
|
||||||
|
LanceDB automatically wraps the function with retry and rate-limit logic to ensure the OpenAI
|
||||||
|
API call is reliable.
|
||||||
|
|
||||||
|
## Querying using an embedding function
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
At query time, you **must** use the same embedding function you used to vectorize your data.
|
||||||
|
If you use a different embedding function, the embeddings will not reside in the same vector
|
||||||
|
space and the results will be nonsensical.
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
```python
|
||||||
|
query = "What's the best pizza topping?"
|
||||||
|
query_vector = embed_func([query])[0]
|
||||||
|
results = (
|
||||||
|
tbl.search(query_vector)
|
||||||
|
.limit(10)
|
||||||
|
.to_pandas()
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
import sys
|
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|||||||
@@ -69,3 +69,19 @@ MinIO supports an S3 compatible API. In order to connect to a MinIO instance, yo
|
|||||||
- Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
|
- Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
|
||||||
- Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential
|
- Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential
|
||||||
- Call `lancedb.connect("s3://minio_bucket_name")`
|
- Call `lancedb.connect("s3://minio_bucket_name")`
|
||||||
|
|
||||||
|
### Where can I find benchmarks for LanceDB?
|
||||||
|
|
||||||
|
Refer to this [post](https://blog.lancedb.com/benchmarking-lancedb-92b01032874a) for recent benchmarks.
|
||||||
|
|
||||||
|
### How much data can LanceDB practically manage without effecting performance?
|
||||||
|
|
||||||
|
We target good performance on ~10-50 billion rows and ~10-30 TB of data.
|
||||||
|
|
||||||
|
### Does LanceDB support concurrent operations?
|
||||||
|
|
||||||
|
LanceDB can handle concurrent reads very well, and can scale horizontally. The main constraint is how well the [storage layer](https://lancedb.github.io/lancedb/concepts/storage/) you've chosen scales. For writes, we support concurrent writing, though too many concurrent writers can lead to failing writes as there is a limited number of times a writer retries a commit
|
||||||
|
|
||||||
|
!!! info "Multiprocessing with LanceDB"
|
||||||
|
|
||||||
|
For multiprocessing you should probably not use ```fork``` as lance is multi-threaded internally and ```fork``` and multi-thread do not work well.[Refer to this discussion](https://discuss.python.org/t/concerns-regarding-deprecation-of-fork-with-alive-threads/33555)
|
||||||
|
|||||||
@@ -636,6 +636,70 @@ The `values` parameter is used to provide the new values for the columns as lite
|
|||||||
|
|
||||||
When rows are updated, they are moved out of the index. The row will still show up in ANN queries, but the query will not be as fast as it would be if the row was in the index. If you update a large proportion of rows, consider rebuilding the index afterwards.
|
When rows are updated, they are moved out of the index. The row will still show up in ANN queries, but the query will not be as fast as it would be if the row was in the index. If you update a large proportion of rows, consider rebuilding the index afterwards.
|
||||||
|
|
||||||
|
## Consistency
|
||||||
|
|
||||||
|
In LanceDB OSS, users can set the `read_consistency_interval` parameter on connections to achieve different levels of read consistency. This parameter determines how frequently the database synchronizes with the underlying storage system to check for updates made by other processes. If another process updates a table, the database will not see the changes until the next synchronization.
|
||||||
|
|
||||||
|
There are three possible settings for `read_consistency_interval`:
|
||||||
|
|
||||||
|
1. **Unset (default)**: The database does not check for updates to tables made by other processes. This provides the best query performance, but means that clients may not see the most up-to-date data. This setting is suitable for applications where the data does not change during the lifetime of the table reference.
|
||||||
|
2. **Zero seconds (Strong consistency)**: The database checks for updates on every read. This provides the strongest consistency guarantees, ensuring that all clients see the latest committed data. However, it has the most overhead. This setting is suitable when consistency matters more than having high QPS.
|
||||||
|
3. **Custom interval (Eventual consistency)**: The database checks for updates at a custom interval, such as every 5 seconds. This provides eventual consistency, allowing for some lag between write and read operations. Performance wise, this is a middle ground between strong consistency and no consistency check. This setting is suitable for applications where immediate consistency is not critical, but clients should see updated data eventually.
|
||||||
|
|
||||||
|
!!! tip "Consistency in LanceDB Cloud"
|
||||||
|
|
||||||
|
This is only tune-able in LanceDB OSS. In LanceDB Cloud, readers are always eventually consistent.
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
To set strong consistency, use `timedelta(0)`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from datetime import timedelta
|
||||||
|
db = lancedb.connect("./.lancedb",. read_consistency_interval=timedelta(0))
|
||||||
|
table = db.open_table("my_table")
|
||||||
|
```
|
||||||
|
|
||||||
|
For eventual consistency, use a custom `timedelta`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from datetime import timedelta
|
||||||
|
db = lancedb.connect("./.lancedb", read_consistency_interval=timedelta(seconds=5))
|
||||||
|
table = db.open_table("my_table")
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, a `Table` will never check for updates from other writers. To manually check for updates you can use `checkout_latest`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
db = lancedb.connect("./.lancedb")
|
||||||
|
table = db.open_table("my_table")
|
||||||
|
|
||||||
|
# (Other writes happen to my_table from another process)
|
||||||
|
|
||||||
|
# Check for updates
|
||||||
|
table.checkout_latest()
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "JavaScript/Typescript"
|
||||||
|
|
||||||
|
To set strong consistency, use `0`:
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const db = await lancedb.connect({ uri: "./.lancedb", readConsistencyInterval: 0 });
|
||||||
|
const table = await db.openTable("my_table");
|
||||||
|
```
|
||||||
|
|
||||||
|
For eventual consistency, specify the update interval as seconds:
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const db = await lancedb.connect({ uri: "./.lancedb", readConsistencyInterval: 5 });
|
||||||
|
const table = await db.openTable("my_table");
|
||||||
|
```
|
||||||
|
|
||||||
|
<!-- Node doesn't yet support the version time travel: https://github.com/lancedb/lancedb/issues/1007
|
||||||
|
Once it does, we can show manual consistency check for Node as well.
|
||||||
|
-->
|
||||||
|
|
||||||
## What's next?
|
## What's next?
|
||||||
|
|
||||||
Learn the best practices on creating an ANN index and getting the most out of it.
|
Learn the best practices on creating an ANN index and getting the most out of it.
|
||||||
49
docs/src/hybrid_search/eval.md
Normal file
49
docs/src/hybrid_search/eval.md
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Hybrid Search
|
||||||
|
|
||||||
|
Hybrid Search is a broad (often misused) term. It can mean anything from combining multiple methods for searching, to applying ranking methods to better sort the results. In this blog, we use the definition of "hybrid search" to mean using a combination of keyword-based and vector search.
|
||||||
|
|
||||||
|
## The challenge of (re)ranking search results
|
||||||
|
Once you have a group of the most relevant search results from multiple search sources, you'd likely standardize the score and rank them accordingly. This process can also be seen as another independent step - reranking.
|
||||||
|
There are two approaches for reranking search results from multiple sources.
|
||||||
|
* <b>Score-based</b>: Calculate final relevance scores based on a weighted linear combination of individual search algorithm scores. Example - Weighted linear combination of semantic search & keyword-based search results.
|
||||||
|
* <b>Relevance-based</b>: Discards the existing scores and calculates the relevance of each search result - query pair. Example - Cross Encoder models
|
||||||
|
|
||||||
|
Even though there are many strategies for reranking search results, none works for all cases. Moreover, evaluating them itself is a challenge. Also, reranking can be dataset, application specific so it's hard to generalize.
|
||||||
|
|
||||||
|
### Example evaluation of hybrid search with Reranking
|
||||||
|
|
||||||
|
Here's some evaluation numbers from experiment comparing these re-rankers on about 800 queries. It is modified version of an evaluation script from [llama-index](https://github.com/run-llama/finetune-embedding/blob/main/evaluate.ipynb) that measures hit-rate at top-k.
|
||||||
|
|
||||||
|
<b> With OpenAI ada2 embedding </b>
|
||||||
|
|
||||||
|
Vector Search baseline - `0.64`
|
||||||
|
|
||||||
|
| Reranker | Top-3 | Top-5 | Top-10 |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| Linear Combination | `0.73` | `0.74` | `0.85` |
|
||||||
|
| Cross Encoder | `0.71` | `0.70` | `0.77` |
|
||||||
|
| Cohere | `0.81` | `0.81` | `0.85` |
|
||||||
|
| ColBERT | `0.68` | `0.68` | `0.73` |
|
||||||
|
|
||||||
|
<p>
|
||||||
|
<img src="https://github.com/AyushExel/assets/assets/15766192/d57b1780-ef27-414c-a5c3-73bee7808a45">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<b> With OpenAI embedding-v3-small </b>
|
||||||
|
|
||||||
|
Vector Search baseline - `0.59`
|
||||||
|
|
||||||
|
| Reranker | Top-3 | Top-5 | Top-10 |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| Linear Combination | `0.68` | `0.70` | `0.84` |
|
||||||
|
| Cross Encoder | `0.72` | `0.72` | `0.79` |
|
||||||
|
| Cohere | `0.79` | `0.79` | `0.84` |
|
||||||
|
| ColBERT | `0.70` | `0.70` | `0.76` |
|
||||||
|
|
||||||
|
<p>
|
||||||
|
<img src="https://github.com/AyushExel/assets/assets/15766192/259adfd2-6ec6-4df6-a77d-1456598970dd">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
### Conclusion
|
||||||
|
|
||||||
|
The results show that the reranking methods are able to improve the search results. However, the improvement is not consistent across all rerankers. The choice of reranker depends on the dataset and the application. It is also important to note that the reranking methods are not a replacement for the search methods. They are complementary and should be used together to get the best results. The speed to recall tradeoff is also an important factor to consider when choosing the reranker.
|
||||||
@@ -1,22 +1,29 @@
|
|||||||
# Hybrid Search
|
# Hybrid Search
|
||||||
|
|
||||||
LanceDB supports both semantic and keyword-based search. In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques.
|
LanceDB supports both semantic and keyword-based search (also termed full-text search, or FTS). In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques.
|
||||||
|
|
||||||
## Hybrid search in LanceDB
|
## Hybrid search in LanceDB
|
||||||
You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic .
|
You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic .
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
import os
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
|
import openai
|
||||||
from lancedb.embeddings import get_registry
|
from lancedb.embeddings import get_registry
|
||||||
from lancedb.pydanatic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
db = lancedb.connect("~/.lancedb")
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
|
||||||
# Ingest embedding function in LanceDB table
|
# Ingest embedding function in LanceDB table
|
||||||
|
# Configuring the environment variable OPENAI_API_KEY
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
# OR set the key here as a variable
|
||||||
|
openai.api_key = "sk-..."
|
||||||
embeddings = get_registry().get("openai").create()
|
embeddings = get_registry().get("openai").create()
|
||||||
|
|
||||||
class Documents(LanceModel):
|
class Documents(LanceModel):
|
||||||
vector: Vector(embeddings.ndims) = embeddings.VectorField()
|
vector: Vector(embeddings.ndims()) = embeddings.VectorField()
|
||||||
text: str = embeddings.SourceField()
|
text: str = embeddings.SourceField()
|
||||||
|
|
||||||
table = db.create_table("documents", schema=Documents)
|
table = db.create_table("documents", schema=Documents)
|
||||||
@@ -31,17 +38,19 @@ data = [
|
|||||||
# ingest docs with auto-vectorization
|
# ingest docs with auto-vectorization
|
||||||
table.add(data)
|
table.add(data)
|
||||||
|
|
||||||
|
# Create a fts index before the hybrid search
|
||||||
|
table.create_fts_index("text")
|
||||||
# hybrid search with default re-ranker
|
# hybrid search with default re-ranker
|
||||||
results = table.search("flower moon", query_type="hybrid").to_pandas()
|
results = table.search("flower moon", query_type="hybrid").to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
By default, LanceDB uses `LinearCombinationReranker(weights=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
|
By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
|
||||||
|
|
||||||
|
|
||||||
### `rerank()` arguments
|
### `rerank()` arguments
|
||||||
* `normalize`: `str`, default `"score"`:
|
* `normalize`: `str`, default `"score"`:
|
||||||
The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
|
The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
|
||||||
* `reranker`: `Reranker`, default `LinearCombinationReranker(weights=0.7)`.
|
* `reranker`: `Reranker`, default `LinearCombinationReranker(weight=0.7)`.
|
||||||
The reranker to use. If not specified, the default reranker is used.
|
The reranker to use. If not specified, the default reranker is used.
|
||||||
|
|
||||||
|
|
||||||
@@ -55,12 +64,12 @@ This is the default re-ranker used by LanceDB. It combines the results of semant
|
|||||||
```python
|
```python
|
||||||
from lancedb.rerankers import LinearCombinationReranker
|
from lancedb.rerankers import LinearCombinationReranker
|
||||||
|
|
||||||
reranker = LinearCombinationReranker(weights=0.3) # Use 0.3 as the weight for vector search
|
reranker = LinearCombinationReranker(weight=0.3) # Use 0.3 as the weight for vector search
|
||||||
|
|
||||||
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
Arguments
|
### Arguments
|
||||||
----------------
|
----------------
|
||||||
* `weight`: `float`, default `0.7`:
|
* `weight`: `float`, default `0.7`:
|
||||||
The weight to use for the semantic search score. The weight for the full-text search score is `1 - weights`.
|
The weight to use for the semantic search score. The weight for the full-text search score is `1 - weights`.
|
||||||
@@ -82,9 +91,9 @@ reranker = CohereReranker()
|
|||||||
results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
Arguments
|
### Arguments
|
||||||
----------------
|
----------------
|
||||||
* `model_name`` : str, default `"rerank-english-v2.0"``
|
* `model_name` : str, default `"rerank-english-v2.0"`
|
||||||
The name of the cross encoder model to use. Available cohere models are:
|
The name of the cross encoder model to use. Available cohere models are:
|
||||||
- rerank-english-v2.0
|
- rerank-english-v2.0
|
||||||
- rerank-multilingual-v2.0
|
- rerank-multilingual-v2.0
|
||||||
@@ -108,7 +117,7 @@ results = table.search("harmony hall", query_type="hybrid").rerank(reranker=rera
|
|||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
Arguments
|
### Arguments
|
||||||
----------------
|
----------------
|
||||||
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
|
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
|
||||||
The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
|
The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
|
||||||
@@ -121,6 +130,61 @@ Arguments
|
|||||||
Only returns `_relevance_score`. Does not support `return_score = "all"`.
|
Only returns `_relevance_score`. Does not support `return_score = "all"`.
|
||||||
|
|
||||||
|
|
||||||
|
### ColBERT Reranker
|
||||||
|
This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertrReranker()` to the `rerank()` method.
|
||||||
|
|
||||||
|
ColBERT reranker model calculates relevance of given docs against the query and don't take existing fts and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for `text` column to rerank the results. But you can specify the column name to use as input to the cross encoder model as described below.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.rerankers import ColbertReranker
|
||||||
|
|
||||||
|
reranker = ColbertReranker()
|
||||||
|
|
||||||
|
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Arguments
|
||||||
|
----------------
|
||||||
|
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
|
||||||
|
The name of the cross encoder model to use.
|
||||||
|
* `column` : `str`, default `"text"`
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
* `return_score` : `str`, default `"relevance"`
|
||||||
|
options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
Only returns `_relevance_score`. Does not support `return_score = "all"`.
|
||||||
|
|
||||||
|
### OpenAI Reranker
|
||||||
|
This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
|
||||||
|
|
||||||
|
!!! Note
|
||||||
|
This prompts chat model to rerank results which is not a dedicated reranker model. This should be treated as experimental.
|
||||||
|
|
||||||
|
!!! Tip
|
||||||
|
- You might run out of token limit so set the search `limits` based on your token limit.
|
||||||
|
- It is recommended to use gpt-4-turbo-preview, the default model, older models might lead to undesired behaviour
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.rerankers import OpenaiReranker
|
||||||
|
|
||||||
|
reranker = OpenaiReranker()
|
||||||
|
|
||||||
|
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Arguments
|
||||||
|
----------------
|
||||||
|
* `model_name` : `str`, default `"gpt-4-turbo-preview"`
|
||||||
|
The name of the cross encoder model to use.
|
||||||
|
* `column` : `str`, default `"text"`
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
* `return_score` : `str`, default `"relevance"`
|
||||||
|
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||||
|
* `api_key` : `str`, default `None`
|
||||||
|
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||||
|
|
||||||
|
|
||||||
## Building Custom Rerankers
|
## Building Custom Rerankers
|
||||||
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
|
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
|
||||||
|
|
||||||
@@ -137,7 +201,7 @@ class MyReranker(Reranker):
|
|||||||
self.param1 = param1
|
self.param1 = param1
|
||||||
self.param2 = param2
|
self.param2 = param2
|
||||||
|
|
||||||
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
|
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
|
||||||
# Use the built-in merging function
|
# Use the built-in merging function
|
||||||
combined_result = self.merge_results(vector_results, fts_results)
|
combined_result = self.merge_results(vector_results, fts_results)
|
||||||
|
|
||||||
@@ -149,24 +213,30 @@ class MyReranker(Reranker):
|
|||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
You can also accept additional arguments like a filter along with fts and vector search results
|
### Example of a Custom Reranker
|
||||||
|
For the sake of simplicity let's build custom reranker that just enchances the Cohere Reranker by accepting a filter query, and accept other CohereReranker params as kwags.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
||||||
from lancedb.rerankers import Reranker
|
from typing import List, Union
|
||||||
import pyarrow as pa
|
import pandas as pd
|
||||||
|
from lancedb.rerankers import CohereReranker
|
||||||
|
|
||||||
class MyReranker(Reranker):
|
class MofidifiedCohereReranker(CohereReranker):
|
||||||
...
|
def __init__(self, filters: Union[str, List[str]], **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
|
filters = filters if isinstance(filters, list) else [filters]
|
||||||
# Use the built-in merging function
|
self.filters = filters
|
||||||
combined_result = self.merge_results(vector_results, fts_results)
|
|
||||||
|
|
||||||
# Do something with the combined results & filter
|
|
||||||
# ...
|
|
||||||
|
|
||||||
# Return the combined results
|
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table)-> pa.Table:
|
||||||
return combined_result
|
combined_result = super().rerank_hybrid(query, vector_results, fts_results)
|
||||||
|
df = combined_result.to_pandas()
|
||||||
|
for filter in self.filters:
|
||||||
|
df = df.query("not text.str.contains(@filter)")
|
||||||
|
|
||||||
|
return pa.Table.from_pandas(df)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
!!! tip
|
||||||
|
The `vector_results` and `fts_results` are pyarrow tables. You can convert them to pandas dataframes using `to_pandas()` method and perform any operations you want. After you are done, you can convert the dataframe back to pyarrow table using `pa.Table.from_pandas()` method and return it.
|
||||||
@@ -290,7 +290,7 @@
|
|||||||
"from lancedb.pydantic import LanceModel, Vector\n",
|
"from lancedb.pydantic import LanceModel, Vector\n",
|
||||||
"\n",
|
"\n",
|
||||||
"class Pets(LanceModel):\n",
|
"class Pets(LanceModel):\n",
|
||||||
" vector: Vector(clip.ndims) = clip.VectorField()\n",
|
" vector: Vector(clip.ndims()) = clip.VectorField()\n",
|
||||||
" image_uri: str = clip.SourceField()\n",
|
" image_uri: str = clip.SourceField()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" @property\n",
|
" @property\n",
|
||||||
@@ -360,7 +360,7 @@
|
|||||||
" table = db.create_table(\"pets\", schema=Pets)\n",
|
" table = db.create_table(\"pets\", schema=Pets)\n",
|
||||||
" # use a sampling of 1000 images\n",
|
" # use a sampling of 1000 images\n",
|
||||||
" p = Path(\"~/Downloads/images\").expanduser()\n",
|
" p = Path(\"~/Downloads/images\").expanduser()\n",
|
||||||
" uris = [str(f) for f in p.iterdir()]\n",
|
" uris = [str(f) for f in p.glob(\"*.jpg\")]\n",
|
||||||
" uris = sample(uris, 1000)\n",
|
" uris = sample(uris, 1000)\n",
|
||||||
" table.add(pd.DataFrame({\"image_uri\": uris}))"
|
" table.add(pd.DataFrame({\"image_uri\": uris}))"
|
||||||
]
|
]
|
||||||
@@ -543,7 +543,7 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from PIL import Image\n",
|
"from PIL import Image\n",
|
||||||
"p = Path(\"/Users/changshe/Downloads/images/samoyed_100.jpg\")\n",
|
"p = Path(\"~/Downloads/images/samoyed_100.jpg\").expanduser()\n",
|
||||||
"query_image = Image.open(p)\n",
|
"query_image = Image.open(p)\n",
|
||||||
"query_image"
|
"query_image"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -23,10 +23,8 @@ from multiprocessing import Pool
|
|||||||
import lance
|
import lance
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from PIL import Image
|
|
||||||
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
|
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
|
||||||
|
|
||||||
import lancedb
|
|
||||||
|
|
||||||
MODEL_ID = "openai/clip-vit-base-patch32"
|
MODEL_ID = "openai/clip-vit-base-patch32"
|
||||||
|
|
||||||
|
|||||||
1122
docs/src/notebooks/hybrid_search.ipynb
Normal file
1122
docs/src/notebooks/hybrid_search.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,9 @@
|
|||||||
# DuckDB
|
# DuckDB
|
||||||
|
|
||||||
LanceDB is very well-integrated with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This integration is done via [Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow) .
|
In Python, LanceDB tables can also be queried with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This means you can write complex SQL queries to analyze your data in LanceDB.
|
||||||
|
|
||||||
|
This integration is done via [Apache Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow), which provides zero-copy data sharing between LanceDB and DuckDB. DuckDB is capable of passing down column selections and basic filters to LanceDB, reducing the amount of data that needs to be scanned to perform your query. Finally, the integration allows streaming data from LanceDB tables, allowing you to aggregate tables that won't fit into memory. All of this uses the same mechanism described in DuckDB's blog post *[DuckDB quacks Arrow](https://duckdb.org/2021/12/03/duck-arrow.html)*.
|
||||||
|
|
||||||
|
|
||||||
We can demonstrate this by first installing `duckdb` and `lancedb`.
|
We can demonstrate this by first installing `duckdb` and `lancedb`.
|
||||||
|
|
||||||
@@ -19,14 +22,15 @@ data = [
|
|||||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
|
||||||
]
|
]
|
||||||
table = db.create_table("pd_table", data=data)
|
table = db.create_table("pd_table", data=data)
|
||||||
arrow_table = table.to_arrow()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
DuckDB can directly query the `pyarrow.Table` object:
|
To query the table, first call `to_lance` to convert the table to a "dataset", which is an object that can be queried by DuckDB. Then all you need to do is reference that dataset by the same name in your SQL query.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import duckdb
|
import duckdb
|
||||||
|
|
||||||
|
arrow_table = table.to_lance()
|
||||||
|
|
||||||
duckdb.query("SELECT * FROM arrow_table")
|
duckdb.query("SELECT * FROM arrow_table")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ excluded_globs = [
|
|||||||
"../src/concepts/*.md",
|
"../src/concepts/*.md",
|
||||||
"../src/ann_indexes.md",
|
"../src/ann_indexes.md",
|
||||||
"../src/basic.md",
|
"../src/basic.md",
|
||||||
"../src/hybrid_search.md",
|
"../src/hybrid_search/hybrid_search.md",
|
||||||
]
|
]
|
||||||
|
|
||||||
python_prefix = "py"
|
python_prefix = "py"
|
||||||
|
|||||||
44
node/package-lock.json
generated
44
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.4.7",
|
"version": "0.4.11",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.4.7",
|
"version": "0.4.11",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
@@ -53,11 +53,11 @@
|
|||||||
"uuid": "^9.0.0"
|
"uuid": "^9.0.0"
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.4.7",
|
"@lancedb/vectordb-darwin-arm64": "0.4.11",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.4.7",
|
"@lancedb/vectordb-darwin-x64": "0.4.11",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.4.11",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.4.7",
|
"@lancedb/vectordb-linux-x64-gnu": "0.4.11",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.4.7"
|
"@lancedb/vectordb-win32-x64-msvc": "0.4.11"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@75lb/deep-merge": {
|
"node_modules/@75lb/deep-merge": {
|
||||||
@@ -329,9 +329,9 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||||
"version": "0.4.7",
|
"version": "0.4.11",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.7.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.11.tgz",
|
||||||
"integrity": "sha512-kACOIytgjBfX8NRwjPKe311XRN3lbSN13B7avT5htMd3kYm3AnnMag9tZhlwoO7lIuvGaXhy7mApygJrjhfJ4g==",
|
"integrity": "sha512-JDOKmFnuJPFkA7ZmrzBJolROwSjWr7yMvAbi40uLBc25YbbVezodd30u2EFtIwWwtk1GqNYRZ49FZOElKYeC/Q==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"arm64"
|
"arm64"
|
||||||
],
|
],
|
||||||
@@ -341,9 +341,9 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||||
"version": "0.4.7",
|
"version": "0.4.11",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.7.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.11.tgz",
|
||||||
"integrity": "sha512-vb74iK5uPWCwz5E60r3yWp/R/HSg54/Z9AZWYckYXqsPv4w/nfbkM5iZhfRqqR/9uE6JClWJKOtjbk7b8CFRFg==",
|
"integrity": "sha512-iy6r+8tp2v1EFgJV52jusXtxgO6NY6SkpOdX41xPqN2mQWMkfUAR9Xtks1mgknjPOIKH4MRc8ZS0jcW/UWmilQ==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
@@ -353,9 +353,9 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||||
"version": "0.4.7",
|
"version": "0.4.11",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.7.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.11.tgz",
|
||||||
"integrity": "sha512-jHp7THm6S9sB8RaCxGoZXLAwGAUHnawUUilB1K3mvQsRdfB2bBs0f7wDehW+PDhr+Iog4LshaWbcnoQEUJWR+Q==",
|
"integrity": "sha512-5K6IVcTMuH0SZBjlqB5Gg39WC889FpTwIWKufxzQMMXrzxo5J3lKUHVoR28RRlNhDF2d9kZXBEyCpIfDFsV9iQ==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"arm64"
|
"arm64"
|
||||||
],
|
],
|
||||||
@@ -365,9 +365,9 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||||
"version": "0.4.7",
|
"version": "0.4.11",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.7.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.11.tgz",
|
||||||
"integrity": "sha512-LKbVe6Wrp/AGqCCjKliNDmYoeTNgY/wfb2DTLjrx41Jko/04ywLrJ6xSEAn3XD5RDCO5u3fyUdXHHHv5a3VAAQ==",
|
"integrity": "sha512-hF9ZChsdqKqqnivOzd9mE7lC3PmhZadXtwThi2RrsPiOLoEaGDfmr6Ni3amVQnB3bR8YEJtTxdQxe0NC4uW/8g==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
@@ -377,9 +377,9 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||||
"version": "0.4.7",
|
"version": "0.4.11",
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.7.tgz",
|
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.11.tgz",
|
||||||
"integrity": "sha512-C5ln4+wafeY1Sm4PeV0Ios9lUaQVVip5Mjl9XU7ngioSEMEuXI/XMVfIdVfDPppVNXPeQxg33wLA272uw88D1Q==",
|
"integrity": "sha512-0+9ut1ccKoqIyGxsVixwx3771Z+DXpl5WfSmOeA8kf3v3jlOg2H+0YUahiXLDid2ju+yeLPrAUYm7A1gKHVhew==",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.4.7",
|
"version": "0.4.11",
|
||||||
"description": " Serverless, low-latency vector database for AI applications",
|
"description": " Serverless, low-latency vector database for AI applications",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"types": "dist/index.d.ts",
|
"types": "dist/index.d.ts",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"tsc": "tsc -b",
|
"tsc": "tsc -b",
|
||||||
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json",
|
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb-node index.node -- cargo build --message-format=json",
|
||||||
"build-release": "npm run build -- --release",
|
"build-release": "npm run build -- --release",
|
||||||
"test": "npm run tsc && mocha -recursive dist/test",
|
"test": "npm run tsc && mocha -recursive dist/test",
|
||||||
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
||||||
@@ -61,11 +61,13 @@
|
|||||||
"uuid": "^9.0.0"
|
"uuid": "^9.0.0"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@apache-arrow/ts": "^14.0.2",
|
|
||||||
"@neon-rs/load": "^0.0.74",
|
"@neon-rs/load": "^0.0.74",
|
||||||
"apache-arrow": "^14.0.2",
|
|
||||||
"axios": "^1.4.0"
|
"axios": "^1.4.0"
|
||||||
},
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"@apache-arrow/ts": "^14.0.2",
|
||||||
|
"apache-arrow": "^14.0.2"
|
||||||
|
},
|
||||||
"os": [
|
"os": [
|
||||||
"darwin",
|
"darwin",
|
||||||
"linux",
|
"linux",
|
||||||
@@ -85,10 +87,10 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.4.7",
|
"@lancedb/vectordb-darwin-arm64": "0.4.11",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.4.7",
|
"@lancedb/vectordb-darwin-x64": "0.4.11",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.4.11",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.4.7",
|
"@lancedb/vectordb-linux-x64-gnu": "0.4.11",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.4.7"
|
"@lancedb/vectordb-win32-x64-msvc": "0.4.11"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -14,8 +14,6 @@
|
|||||||
|
|
||||||
import {
|
import {
|
||||||
Field,
|
Field,
|
||||||
type FixedSizeListBuilder,
|
|
||||||
Float32,
|
|
||||||
makeBuilder,
|
makeBuilder,
|
||||||
RecordBatchFileWriter,
|
RecordBatchFileWriter,
|
||||||
Utf8,
|
Utf8,
|
||||||
@@ -26,14 +24,19 @@ import {
|
|||||||
Table as ArrowTable,
|
Table as ArrowTable,
|
||||||
RecordBatchStreamWriter,
|
RecordBatchStreamWriter,
|
||||||
List,
|
List,
|
||||||
Float64,
|
|
||||||
RecordBatch,
|
RecordBatch,
|
||||||
makeData,
|
makeData,
|
||||||
Struct,
|
Struct,
|
||||||
type Float
|
type Float,
|
||||||
|
DataType,
|
||||||
|
Binary,
|
||||||
|
Float32
|
||||||
} from 'apache-arrow'
|
} from 'apache-arrow'
|
||||||
import { type EmbeddingFunction } from './index'
|
import { type EmbeddingFunction } from './index'
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Options to control how a column should be converted to a vector array
|
||||||
|
*/
|
||||||
export class VectorColumnOptions {
|
export class VectorColumnOptions {
|
||||||
/** Vector column type. */
|
/** Vector column type. */
|
||||||
type: Float = new Float32()
|
type: Float = new Float32()
|
||||||
@@ -45,14 +48,50 @@ export class VectorColumnOptions {
|
|||||||
|
|
||||||
/** Options to control the makeArrowTable call. */
|
/** Options to control the makeArrowTable call. */
|
||||||
export class MakeArrowTableOptions {
|
export class MakeArrowTableOptions {
|
||||||
/** Provided schema. */
|
/*
|
||||||
|
* Schema of the data.
|
||||||
|
*
|
||||||
|
* If this is not provided then the data type will be inferred from the
|
||||||
|
* JS type. Integer numbers will become int64, floating point numbers
|
||||||
|
* will become float64 and arrays will become variable sized lists with
|
||||||
|
* the data type inferred from the first element in the array.
|
||||||
|
*
|
||||||
|
* The schema must be specified if there are no records (e.g. to make
|
||||||
|
* an empty table)
|
||||||
|
*/
|
||||||
schema?: Schema
|
schema?: Schema
|
||||||
|
|
||||||
/** Vector columns */
|
/*
|
||||||
|
* Mapping from vector column name to expected type
|
||||||
|
*
|
||||||
|
* Lance expects vector columns to be fixed size list arrays (i.e. tensors)
|
||||||
|
* However, `makeArrowTable` will not infer this by default (it creates
|
||||||
|
* variable size list arrays). This field can be used to indicate that a column
|
||||||
|
* should be treated as a vector column and converted to a fixed size list.
|
||||||
|
*
|
||||||
|
* The keys should be the names of the vector columns. The value specifies the
|
||||||
|
* expected data type of the vector columns.
|
||||||
|
*
|
||||||
|
* If `schema` is provided then this field is ignored.
|
||||||
|
*
|
||||||
|
* By default, the column named "vector" will be assumed to be a float32
|
||||||
|
* vector column.
|
||||||
|
*/
|
||||||
vectorColumns: Record<string, VectorColumnOptions> = {
|
vectorColumns: Record<string, VectorColumnOptions> = {
|
||||||
vector: new VectorColumnOptions()
|
vector: new VectorColumnOptions()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If true then string columns will be encoded with dictionary encoding
|
||||||
|
*
|
||||||
|
* Set this to true if your string columns tend to repeat the same values
|
||||||
|
* often. For more precise control use the `schema` property to specify the
|
||||||
|
* data type for individual columns.
|
||||||
|
*
|
||||||
|
* If `schema` is provided then this property is ignored.
|
||||||
|
*/
|
||||||
|
dictionaryEncodeStrings: boolean = false
|
||||||
|
|
||||||
constructor (values?: Partial<MakeArrowTableOptions>) {
|
constructor (values?: Partial<MakeArrowTableOptions>) {
|
||||||
Object.assign(this, values)
|
Object.assign(this, values)
|
||||||
}
|
}
|
||||||
@@ -62,8 +101,29 @@ export class MakeArrowTableOptions {
|
|||||||
* An enhanced version of the {@link makeTable} function from Apache Arrow
|
* An enhanced version of the {@link makeTable} function from Apache Arrow
|
||||||
* that supports nested fields and embeddings columns.
|
* that supports nested fields and embeddings columns.
|
||||||
*
|
*
|
||||||
|
* This function converts an array of Record<String, any> (row-major JS objects)
|
||||||
|
* to an Arrow Table (a columnar structure)
|
||||||
|
*
|
||||||
* Note that it currently does not support nulls.
|
* Note that it currently does not support nulls.
|
||||||
*
|
*
|
||||||
|
* If a schema is provided then it will be used to determine the resulting array
|
||||||
|
* types. Fields will also be reordered to fit the order defined by the schema.
|
||||||
|
*
|
||||||
|
* If a schema is not provided then the types will be inferred and the field order
|
||||||
|
* will be controlled by the order of properties in the first record.
|
||||||
|
*
|
||||||
|
* If the input is empty then a schema must be provided to create an empty table.
|
||||||
|
*
|
||||||
|
* When a schema is not specified then data types will be inferred. The inference
|
||||||
|
* rules are as follows:
|
||||||
|
*
|
||||||
|
* - boolean => Bool
|
||||||
|
* - number => Float64
|
||||||
|
* - String => Utf8
|
||||||
|
* - Buffer => Binary
|
||||||
|
* - Record<String, any> => Struct
|
||||||
|
* - Array<any> => List
|
||||||
|
*
|
||||||
* @param data input data
|
* @param data input data
|
||||||
* @param options options to control the makeArrowTable call.
|
* @param options options to control the makeArrowTable call.
|
||||||
*
|
*
|
||||||
@@ -86,8 +146,10 @@ export class MakeArrowTableOptions {
|
|||||||
* ], { schema });
|
* ], { schema });
|
||||||
* ```
|
* ```
|
||||||
*
|
*
|
||||||
* It guesses the vector columns if the schema is not provided. For example,
|
* By default it assumes that the column named `vector` is a vector column
|
||||||
* by default it assumes that the column named `vector` is a vector column.
|
* and it will be converted into a fixed size list array of type float32.
|
||||||
|
* The `vectorColumns` option can be used to support other vector column
|
||||||
|
* names and data types.
|
||||||
*
|
*
|
||||||
* ```ts
|
* ```ts
|
||||||
*
|
*
|
||||||
@@ -134,211 +196,304 @@ export function makeArrowTable (
|
|||||||
data: Array<Record<string, any>>,
|
data: Array<Record<string, any>>,
|
||||||
options?: Partial<MakeArrowTableOptions>
|
options?: Partial<MakeArrowTableOptions>
|
||||||
): ArrowTable {
|
): ArrowTable {
|
||||||
if (data.length === 0) {
|
if (data.length === 0 && (options?.schema === undefined || options?.schema === null)) {
|
||||||
throw new Error('At least one record needs to be provided')
|
throw new Error('At least one record or a schema needs to be provided')
|
||||||
}
|
}
|
||||||
|
|
||||||
const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
|
const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
|
||||||
const columns: Record<string, Vector> = {}
|
const columns: Record<string, Vector> = {}
|
||||||
// TODO: sample dataset to find missing columns
|
// TODO: sample dataset to find missing columns
|
||||||
const columnNames = Object.keys(data[0])
|
// Prefer the field ordering of the schema, if present
|
||||||
|
const columnNames = ((options?.schema) != null) ? (options?.schema?.names as string[]) : Object.keys(data[0])
|
||||||
for (const colName of columnNames) {
|
for (const colName of columnNames) {
|
||||||
const values = data.map((datum) => datum[colName])
|
if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) {
|
||||||
let vector: Vector
|
// The field is present in the schema, but not in the data, skip it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Extract a single column from the records (transpose from row-major to col-major)
|
||||||
|
let values = data.map((datum) => datum[colName])
|
||||||
|
|
||||||
|
// By default (type === undefined) arrow will infer the type from the JS type
|
||||||
|
let type
|
||||||
if (opt.schema !== undefined) {
|
if (opt.schema !== undefined) {
|
||||||
// Explicit schema is provided, highest priority
|
// If there is a schema provided, then use that for the type instead
|
||||||
vector = vectorFromArray(
|
type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
|
||||||
values,
|
if (DataType.isInt(type) && type.bitWidth === 64) {
|
||||||
opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
|
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
|
||||||
)
|
values = values.map((v) => {
|
||||||
|
if (v === null) {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return BigInt(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// Otherwise, check to see if this column is one of the vector columns
|
||||||
|
// defined by opt.vectorColumns and, if so, use the fixed size list type
|
||||||
const vectorColumnOptions = opt.vectorColumns[colName]
|
const vectorColumnOptions = opt.vectorColumns[colName]
|
||||||
if (vectorColumnOptions !== undefined) {
|
if (vectorColumnOptions !== undefined) {
|
||||||
const fslType = new FixedSizeList(
|
type = newVectorType(values[0].length, vectorColumnOptions.type)
|
||||||
values[0].length,
|
|
||||||
new Field('item', vectorColumnOptions.type, false)
|
|
||||||
)
|
|
||||||
vector = vectorFromArray(values, fslType)
|
|
||||||
} else {
|
|
||||||
// Normal case
|
|
||||||
vector = vectorFromArray(values)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
columns[colName] = vector
|
|
||||||
|
try {
|
||||||
|
// Convert an Array of JS values to an arrow vector
|
||||||
|
columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings)
|
||||||
|
} catch (error: unknown) {
|
||||||
|
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||||
|
throw Error(`Could not convert column "${colName}" to Arrow: ${error}`)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return new ArrowTable(columns)
|
if (opt.schema != null) {
|
||||||
|
// `new ArrowTable(columns)` infers a schema which may sometimes have
|
||||||
|
// incorrect nullability (it assumes nullable=true if there are 0 rows)
|
||||||
|
//
|
||||||
|
// `new ArrowTable(schema, columns)` will also fail because it will create a
|
||||||
|
// batch with an inferred schema and then complain that the batch schema
|
||||||
|
// does not match the provided schema.
|
||||||
|
//
|
||||||
|
// To work around this we first create a table with the wrong schema and
|
||||||
|
// then patch the schema of the batches so we can use
|
||||||
|
// `new ArrowTable(schema, batches)` which does not do any schema inference
|
||||||
|
const firstTable = new ArrowTable(columns)
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
|
const batchesFixed = firstTable.batches.map(batch => new RecordBatch(opt.schema!, batch.data))
|
||||||
|
return new ArrowTable(opt.schema, batchesFixed)
|
||||||
|
} else {
|
||||||
|
return new ArrowTable(columns)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
|
/**
|
||||||
|
* Create an empty Arrow table with the provided schema
|
||||||
|
*/
|
||||||
|
export function makeEmptyTable (schema: Schema): ArrowTable {
|
||||||
|
return makeArrowTable([], { schema })
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to convert Array<Array<any>> to a variable sized list array
|
||||||
|
function makeListVector (lists: any[][]): Vector<any> {
|
||||||
|
if (lists.length === 0 || lists[0].length === 0) {
|
||||||
|
throw Error('Cannot infer list vector from empty array or empty list')
|
||||||
|
}
|
||||||
|
const sampleList = lists[0]
|
||||||
|
let inferredType
|
||||||
|
try {
|
||||||
|
const sampleVector = makeVector(sampleList)
|
||||||
|
inferredType = sampleVector.type
|
||||||
|
} catch (error: unknown) {
|
||||||
|
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||||
|
throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const listBuilder = makeBuilder({
|
||||||
|
type: new List(new Field('item', inferredType, true))
|
||||||
|
})
|
||||||
|
for (const list of lists) {
|
||||||
|
listBuilder.append(list)
|
||||||
|
}
|
||||||
|
return listBuilder.finish().toVector()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to convert an Array of JS values to an Arrow Vector
|
||||||
|
function makeVector (values: any[], type?: DataType, stringAsDictionary?: boolean): Vector<any> {
|
||||||
|
if (type !== undefined) {
|
||||||
|
// No need for inference, let Arrow create it
|
||||||
|
return vectorFromArray(values, type)
|
||||||
|
}
|
||||||
|
if (values.length === 0) {
|
||||||
|
throw Error('makeVector requires at least one value or the type must be specfied')
|
||||||
|
}
|
||||||
|
const sampleValue = values.find(val => val !== null && val !== undefined)
|
||||||
|
if (sampleValue === undefined) {
|
||||||
|
throw Error('makeVector cannot infer the type if all values are null or undefined')
|
||||||
|
}
|
||||||
|
if (Array.isArray(sampleValue)) {
|
||||||
|
// Default Arrow inference doesn't handle list types
|
||||||
|
return makeListVector(values)
|
||||||
|
} else if (Buffer.isBuffer(sampleValue)) {
|
||||||
|
// Default Arrow inference doesn't handle Buffer
|
||||||
|
return vectorFromArray(values, new Binary())
|
||||||
|
} else if (!(stringAsDictionary ?? false) && (typeof sampleValue === 'string' || sampleValue instanceof String)) {
|
||||||
|
// If the type is string then don't use Arrow's default inference unless dictionaries are requested
|
||||||
|
// because it will always use dictionary encoding for strings
|
||||||
|
return vectorFromArray(values, new Utf8())
|
||||||
|
} else {
|
||||||
|
// Convert a JS array of values to an arrow vector
|
||||||
|
return vectorFromArray(values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<ArrowTable> {
|
||||||
|
if (embeddings == null) {
|
||||||
|
return table
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert from ArrowTable to Record<String, Vector>
|
||||||
|
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
|
||||||
|
const name = table.schema.fields[idx].name
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||||
|
const vec = table.getChildAt(idx)!
|
||||||
|
return [name, vec]
|
||||||
|
})
|
||||||
|
const newColumns = Object.fromEntries(colEntries)
|
||||||
|
|
||||||
|
const sourceColumn = newColumns[embeddings.sourceColumn]
|
||||||
|
const destColumn = embeddings.destColumn ?? 'vector'
|
||||||
|
const innerDestType = embeddings.embeddingDataType ?? new Float32()
|
||||||
|
if (sourceColumn === undefined) {
|
||||||
|
throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (table.numRows === 0) {
|
||||||
|
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
|
||||||
|
// We have an empty table and it already has the embedding column so no work needs to be done
|
||||||
|
// Note: we don't return an error like we did below because this is a common occurrence. For example,
|
||||||
|
// if we call convertToTable with 0 records and a schema that includes the embedding
|
||||||
|
return table
|
||||||
|
}
|
||||||
|
if (embeddings.embeddingDimension !== undefined) {
|
||||||
|
const destType = newVectorType(embeddings.embeddingDimension, innerDestType)
|
||||||
|
newColumns[destColumn] = makeVector([], destType)
|
||||||
|
} else if (schema != null) {
|
||||||
|
const destField = schema.fields.find(f => f.name === destColumn)
|
||||||
|
if (destField != null) {
|
||||||
|
newColumns[destColumn] = makeVector([], destField.type)
|
||||||
|
} else {
|
||||||
|
throw new Error(`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Error('Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
|
||||||
|
throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`)
|
||||||
|
}
|
||||||
|
if (table.batches.length > 1) {
|
||||||
|
throw new Error('Internal error: `makeArrowTable` unexpectedly created a table with more than one batch')
|
||||||
|
}
|
||||||
|
const values = sourceColumn.toArray()
|
||||||
|
const vectors = await embeddings.embed(values as T[])
|
||||||
|
if (vectors.length !== values.length) {
|
||||||
|
throw new Error('Embedding function did not return an embedding for each input element')
|
||||||
|
}
|
||||||
|
const destType = newVectorType(vectors[0].length, innerDestType)
|
||||||
|
newColumns[destColumn] = makeVector(vectors, destType)
|
||||||
|
}
|
||||||
|
|
||||||
|
const newTable = new ArrowTable(newColumns)
|
||||||
|
if (schema != null) {
|
||||||
|
if (schema.fields.find(f => f.name === destColumn) === undefined) {
|
||||||
|
throw new Error(`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`)
|
||||||
|
}
|
||||||
|
return alignTable(newTable, schema)
|
||||||
|
}
|
||||||
|
return newTable
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Convert an Array of records into an Arrow Table, optionally applying an
|
||||||
|
* embeddings function to it.
|
||||||
|
*
|
||||||
|
* This function calls `makeArrowTable` first to create the Arrow Table.
|
||||||
|
* Any provided `makeTableOptions` (e.g. a schema) will be passed on to
|
||||||
|
* that call.
|
||||||
|
*
|
||||||
|
* The embedding function will be passed a column of values (based on the
|
||||||
|
* `sourceColumn` of the embedding function) and expects to receive back
|
||||||
|
* number[][] which will be converted into a fixed size list column. By
|
||||||
|
* default this will be a fixed size list of Float32 but that can be
|
||||||
|
* customized by the `embeddingDataType` property of the embedding function.
|
||||||
|
*
|
||||||
|
* If a schema is provided in `makeTableOptions` then it should include the
|
||||||
|
* embedding columns. If no schema is provded then embedding columns will
|
||||||
|
* be placed at the end of the table, after all of the input columns.
|
||||||
|
*/
|
||||||
export async function convertToTable<T> (
|
export async function convertToTable<T> (
|
||||||
data: Array<Record<string, unknown>>,
|
data: Array<Record<string, unknown>>,
|
||||||
embeddings?: EmbeddingFunction<T>
|
embeddings?: EmbeddingFunction<T>,
|
||||||
|
makeTableOptions?: Partial<MakeArrowTableOptions>
|
||||||
): Promise<ArrowTable> {
|
): Promise<ArrowTable> {
|
||||||
if (data.length === 0) {
|
const table = makeArrowTable(data, makeTableOptions)
|
||||||
throw new Error('At least one record needs to be provided')
|
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema)
|
||||||
}
|
|
||||||
|
|
||||||
const columns = Object.keys(data[0])
|
|
||||||
const records: Record<string, Vector> = {}
|
|
||||||
|
|
||||||
for (const columnsKey of columns) {
|
|
||||||
if (columnsKey === 'vector') {
|
|
||||||
const vectorSize = (data[0].vector as any[]).length
|
|
||||||
const listBuilder = newVectorBuilder(vectorSize)
|
|
||||||
for (const datum of data) {
|
|
||||||
if ((datum[columnsKey] as any[]).length !== vectorSize) {
|
|
||||||
throw new Error(`Invalid vector size, expected ${vectorSize}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
listBuilder.append(datum[columnsKey])
|
|
||||||
}
|
|
||||||
records[columnsKey] = listBuilder.finish().toVector()
|
|
||||||
} else {
|
|
||||||
const values = []
|
|
||||||
for (const datum of data) {
|
|
||||||
values.push(datum[columnsKey])
|
|
||||||
}
|
|
||||||
|
|
||||||
if (columnsKey === embeddings?.sourceColumn) {
|
|
||||||
const vectors = await embeddings.embed(values as T[])
|
|
||||||
records.vector = vectorFromArray(
|
|
||||||
vectors,
|
|
||||||
newVectorType(vectors[0].length)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (typeof values[0] === 'string') {
|
|
||||||
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
|
|
||||||
records[columnsKey] = vectorFromArray(values, new Utf8())
|
|
||||||
} else if (Array.isArray(values[0])) {
|
|
||||||
const elementType = getElementType(values[0])
|
|
||||||
let innerType
|
|
||||||
if (elementType === 'string') {
|
|
||||||
innerType = new Utf8()
|
|
||||||
} else if (elementType === 'number') {
|
|
||||||
innerType = new Float64()
|
|
||||||
} else {
|
|
||||||
// TODO: pass in schema if it exists, else keep going to the next element
|
|
||||||
throw new Error(`Unsupported array element type ${elementType}`)
|
|
||||||
}
|
|
||||||
const listBuilder = makeBuilder({
|
|
||||||
type: new List(new Field('item', innerType, true))
|
|
||||||
})
|
|
||||||
for (const value of values) {
|
|
||||||
listBuilder.append(value)
|
|
||||||
}
|
|
||||||
records[columnsKey] = listBuilder.finish().toVector()
|
|
||||||
} else {
|
|
||||||
// TODO if this is a struct field then recursively align the subfields
|
|
||||||
records[columnsKey] = vectorFromArray(values)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return new ArrowTable(records)
|
|
||||||
}
|
|
||||||
|
|
||||||
function getElementType (arr: any[]): string {
|
|
||||||
if (arr.length === 0) {
|
|
||||||
return 'undefined'
|
|
||||||
}
|
|
||||||
|
|
||||||
return typeof arr[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a new Arrow ListBuilder that stores a Vector column
|
|
||||||
function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
|
|
||||||
return makeBuilder({
|
|
||||||
type: newVectorType(dim)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates the Arrow Type for a Vector column with dimension `dim`
|
// Creates the Arrow Type for a Vector column with dimension `dim`
|
||||||
function newVectorType (dim: number): FixedSizeList<Float32> {
|
function newVectorType <T extends Float> (dim: number, innerType: T): FixedSizeList<T> {
|
||||||
// Somewhere we always default to have the elements nullable, so we need to set it to true
|
// Somewhere we always default to have the elements nullable, so we need to set it to true
|
||||||
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
|
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
|
||||||
const children = new Field<Float32>('item', new Float32(), true)
|
const children = new Field<T>('item', innerType, true)
|
||||||
return new FixedSizeList(dim, children)
|
return new FixedSizeList(dim, children)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Array of records into Arrow IPC format
|
/**
|
||||||
|
* Serialize an Array of records into a buffer using the Arrow IPC File serialization
|
||||||
|
*
|
||||||
|
* This function will call `convertToTable` and pass on `embeddings` and `schema`
|
||||||
|
*
|
||||||
|
* `schema` is required if data is empty
|
||||||
|
*/
|
||||||
export async function fromRecordsToBuffer<T> (
|
export async function fromRecordsToBuffer<T> (
|
||||||
data: Array<Record<string, unknown>>,
|
data: Array<Record<string, unknown>>,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunction<T>,
|
||||||
schema?: Schema
|
schema?: Schema
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
let table = await convertToTable(data, embeddings)
|
const table = await convertToTable(data, embeddings, { schema })
|
||||||
if (schema !== undefined) {
|
|
||||||
table = alignTable(table, schema)
|
|
||||||
}
|
|
||||||
const writer = RecordBatchFileWriter.writeAll(table)
|
const writer = RecordBatchFileWriter.writeAll(table)
|
||||||
return Buffer.from(await writer.toUint8Array())
|
return Buffer.from(await writer.toUint8Array())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Array of records into Arrow IPC stream format
|
/**
|
||||||
|
* Serialize an Array of records into a buffer using the Arrow IPC Stream serialization
|
||||||
|
*
|
||||||
|
* This function will call `convertToTable` and pass on `embeddings` and `schema`
|
||||||
|
*
|
||||||
|
* `schema` is required if data is empty
|
||||||
|
*/
|
||||||
export async function fromRecordsToStreamBuffer<T> (
|
export async function fromRecordsToStreamBuffer<T> (
|
||||||
data: Array<Record<string, unknown>>,
|
data: Array<Record<string, unknown>>,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunction<T>,
|
||||||
schema?: Schema
|
schema?: Schema
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
let table = await convertToTable(data, embeddings)
|
const table = await convertToTable(data, embeddings, { schema })
|
||||||
if (schema !== undefined) {
|
|
||||||
table = alignTable(table, schema)
|
|
||||||
}
|
|
||||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||||
return Buffer.from(await writer.toUint8Array())
|
return Buffer.from(await writer.toUint8Array())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Arrow Table into Arrow IPC format
|
/**
|
||||||
|
* Serialize an Arrow Table into a buffer using the Arrow IPC File serialization
|
||||||
|
*
|
||||||
|
* This function will apply `embeddings` to the table in a manner similar to
|
||||||
|
* `convertToTable`.
|
||||||
|
*
|
||||||
|
* `schema` is required if the table is empty
|
||||||
|
*/
|
||||||
export async function fromTableToBuffer<T> (
|
export async function fromTableToBuffer<T> (
|
||||||
table: ArrowTable,
|
table: ArrowTable,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunction<T>,
|
||||||
schema?: Schema
|
schema?: Schema
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
if (embeddings !== undefined) {
|
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
|
||||||
const source = table.getChild(embeddings.sourceColumn)
|
const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings)
|
||||||
|
|
||||||
if (source === null) {
|
|
||||||
throw new Error(
|
|
||||||
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
|
||||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
|
||||||
table = table.assign(new ArrowTable({ vector: column }))
|
|
||||||
}
|
|
||||||
if (schema !== undefined) {
|
|
||||||
table = alignTable(table, schema)
|
|
||||||
}
|
|
||||||
const writer = RecordBatchFileWriter.writeAll(table)
|
|
||||||
return Buffer.from(await writer.toUint8Array())
|
return Buffer.from(await writer.toUint8Array())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts an Arrow Table into Arrow IPC stream format
|
/**
|
||||||
|
* Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
|
||||||
|
*
|
||||||
|
* This function will apply `embeddings` to the table in a manner similar to
|
||||||
|
* `convertToTable`.
|
||||||
|
*
|
||||||
|
* `schema` is required if the table is empty
|
||||||
|
*/
|
||||||
export async function fromTableToStreamBuffer<T> (
|
export async function fromTableToStreamBuffer<T> (
|
||||||
table: ArrowTable,
|
table: ArrowTable,
|
||||||
embeddings?: EmbeddingFunction<T>,
|
embeddings?: EmbeddingFunction<T>,
|
||||||
schema?: Schema
|
schema?: Schema
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
if (embeddings !== undefined) {
|
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
|
||||||
const source = table.getChild(embeddings.sourceColumn)
|
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings)
|
||||||
|
|
||||||
if (source === null) {
|
|
||||||
throw new Error(
|
|
||||||
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
|
||||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
|
||||||
table = table.assign(new ArrowTable({ vector: column }))
|
|
||||||
}
|
|
||||||
if (schema !== undefined) {
|
|
||||||
table = alignTable(table, schema)
|
|
||||||
}
|
|
||||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
|
||||||
return Buffer.from(await writer.toUint8Array())
|
return Buffer.from(await writer.toUint8Array())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,18 +12,53 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
import { type Float } from 'apache-arrow'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An embedding function that automatically creates vector representation for a given column.
|
* An embedding function that automatically creates vector representation for a given column.
|
||||||
*/
|
*/
|
||||||
export interface EmbeddingFunction<T> {
|
export interface EmbeddingFunction<T> {
|
||||||
/**
|
/**
|
||||||
* The name of the column that will be used as input for the Embedding Function.
|
* The name of the column that will be used as input for the Embedding Function.
|
||||||
*/
|
*/
|
||||||
sourceColumn: string
|
sourceColumn: string
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a vector representation for the given values.
|
* The data type of the embedding
|
||||||
*/
|
*
|
||||||
|
* The embedding function should return `number`. This will be converted into
|
||||||
|
* an Arrow float array. By default this will be Float32 but this property can
|
||||||
|
* be used to control the conversion.
|
||||||
|
*/
|
||||||
|
embeddingDataType?: Float
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The dimension of the embedding
|
||||||
|
*
|
||||||
|
* This is optional, normally this can be determined by looking at the results of
|
||||||
|
* `embed`. If this is not specified, and there is an attempt to apply the embedding
|
||||||
|
* to an empty table, then that process will fail.
|
||||||
|
*/
|
||||||
|
embeddingDimension?: number
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The name of the column that will contain the embedding
|
||||||
|
*
|
||||||
|
* By default this is "vector"
|
||||||
|
*/
|
||||||
|
destColumn?: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Should the source column be excluded from the resulting table
|
||||||
|
*
|
||||||
|
* By default the source column is included. Set this to true and
|
||||||
|
* only the embedding will be stored.
|
||||||
|
*/
|
||||||
|
excludeSource?: boolean
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a vector representation for the given values.
|
||||||
|
*/
|
||||||
embed: (data: T[]) => Promise<number[][]>
|
embed: (data: T[]) => Promise<number[][]>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,14 +42,17 @@ const {
|
|||||||
tableCompactFiles,
|
tableCompactFiles,
|
||||||
tableListIndices,
|
tableListIndices,
|
||||||
tableIndexStats,
|
tableIndexStats,
|
||||||
tableSchema
|
tableSchema,
|
||||||
|
tableAddColumns,
|
||||||
|
tableAlterColumns,
|
||||||
|
tableDropColumns
|
||||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||||
} = require('../native.js')
|
} = require('../native.js')
|
||||||
|
|
||||||
export { Query }
|
export { Query }
|
||||||
export type { EmbeddingFunction }
|
export type { EmbeddingFunction }
|
||||||
export { OpenAIEmbeddingFunction } from './embedding/openai'
|
export { OpenAIEmbeddingFunction } from './embedding/openai'
|
||||||
export { makeArrowTable, type MakeArrowTableOptions } from './arrow'
|
export { convertToTable, makeArrowTable, type MakeArrowTableOptions } from './arrow'
|
||||||
|
|
||||||
const defaultAwsRegion = 'us-west-2'
|
const defaultAwsRegion = 'us-west-2'
|
||||||
|
|
||||||
@@ -96,6 +99,19 @@ export interface ConnectionOptions {
|
|||||||
* This is useful for local testing.
|
* This is useful for local testing.
|
||||||
*/
|
*/
|
||||||
hostOverride?: string
|
hostOverride?: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* (For LanceDB OSS only): The interval, in seconds, at which to check for
|
||||||
|
* updates to the table from other processes. If None, then consistency is not
|
||||||
|
* checked. For performance reasons, this is the default. For strong
|
||||||
|
* consistency, set this to zero seconds. Then every read will check for
|
||||||
|
* updates from other processes. As a compromise, you can set this to a
|
||||||
|
* non-zero value for eventual consistency. If more than that interval
|
||||||
|
* has passed since the last check, then the table will be checked for updates.
|
||||||
|
* Note: this consistency only applies to read operations. Write operations are
|
||||||
|
* always consistent.
|
||||||
|
*/
|
||||||
|
readConsistencyInterval?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
function getAwsArgs (opts: ConnectionOptions): any[] {
|
function getAwsArgs (opts: ConnectionOptions): any[] {
|
||||||
@@ -181,7 +197,8 @@ export async function connect (
|
|||||||
opts.awsCredentials?.accessKeyId,
|
opts.awsCredentials?.accessKeyId,
|
||||||
opts.awsCredentials?.secretKey,
|
opts.awsCredentials?.secretKey,
|
||||||
opts.awsCredentials?.sessionToken,
|
opts.awsCredentials?.sessionToken,
|
||||||
opts.awsRegion
|
opts.awsRegion,
|
||||||
|
opts.readConsistencyInterval
|
||||||
)
|
)
|
||||||
return new LocalConnection(db, opts)
|
return new LocalConnection(db, opts)
|
||||||
}
|
}
|
||||||
@@ -372,7 +389,7 @@ export interface Table<T = number[]> {
|
|||||||
/**
|
/**
|
||||||
* Returns the number of rows in this table.
|
* Returns the number of rows in this table.
|
||||||
*/
|
*/
|
||||||
countRows: () => Promise<number>
|
countRows: (filter?: string) => Promise<number>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Delete rows from this table.
|
* Delete rows from this table.
|
||||||
@@ -486,6 +503,59 @@ export interface Table<T = number[]> {
|
|||||||
filter(value: string): Query<T>
|
filter(value: string): Query<T>
|
||||||
|
|
||||||
schema: Promise<Schema>
|
schema: Promise<Schema>
|
||||||
|
|
||||||
|
// TODO: Support BatchUDF
|
||||||
|
/**
|
||||||
|
* Add new columns with defined values.
|
||||||
|
*
|
||||||
|
* @param newColumnTransforms pairs of column names and the SQL expression to use
|
||||||
|
* to calculate the value of the new column. These
|
||||||
|
* expressions will be evaluated for each row in the
|
||||||
|
* table, and can reference existing columns in the table.
|
||||||
|
*/
|
||||||
|
addColumns(newColumnTransforms: Array<{ name: string, valueSql: string }>): Promise<void>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Alter the name or nullability of columns.
|
||||||
|
*
|
||||||
|
* @param columnAlterations One or more alterations to apply to columns.
|
||||||
|
*/
|
||||||
|
alterColumns(columnAlterations: ColumnAlteration[]): Promise<void>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Drop one or more columns from the dataset
|
||||||
|
*
|
||||||
|
* This is a metadata-only operation and does not remove the data from the
|
||||||
|
* underlying storage. In order to remove the data, you must subsequently
|
||||||
|
* call ``compact_files`` to rewrite the data without the removed columns and
|
||||||
|
* then call ``cleanup_files`` to remove the old files.
|
||||||
|
*
|
||||||
|
* @param columnNames The names of the columns to drop. These can be nested
|
||||||
|
* column references (e.g. "a.b.c") or top-level column
|
||||||
|
* names (e.g. "a").
|
||||||
|
*/
|
||||||
|
dropColumns(columnNames: string[]): Promise<void>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A definition of a column alteration. The alteration changes the column at
|
||||||
|
* `path` to have the new name `name`, to be nullable if `nullable` is true,
|
||||||
|
* and to have the data type `data_type`. At least one of `rename` or `nullable`
|
||||||
|
* must be provided.
|
||||||
|
*/
|
||||||
|
export interface ColumnAlteration {
|
||||||
|
/**
|
||||||
|
* The path to the column to alter. This is a dot-separated path to the column.
|
||||||
|
* If it is a top-level column then it is just the name of the column. If it is
|
||||||
|
* a nested column then it is the path to the column, e.g. "a.b.c" for a column
|
||||||
|
* `c` nested inside a column `b` nested inside a column `a`.
|
||||||
|
*/
|
||||||
|
path: string
|
||||||
|
rename?: string
|
||||||
|
/**
|
||||||
|
* Set the new nullability. Note that a nullable column cannot be made non-nullable.
|
||||||
|
*/
|
||||||
|
nullable?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UpdateArgs {
|
export interface UpdateArgs {
|
||||||
@@ -525,8 +595,19 @@ export interface MergeInsertArgs {
|
|||||||
* If there are multiple matches then the behavior is undefined.
|
* If there are multiple matches then the behavior is undefined.
|
||||||
* Currently this causes multiple copies of the row to be created
|
* Currently this causes multiple copies of the row to be created
|
||||||
* but that behavior is subject to change.
|
* but that behavior is subject to change.
|
||||||
|
*
|
||||||
|
* Optionally, a filter can be specified. This should be an SQL
|
||||||
|
* filter where fields with the prefix "target." refer to fields
|
||||||
|
* in the target table (old data) and fields with the prefix
|
||||||
|
* "source." refer to fields in the source table (new data). For
|
||||||
|
* example, the filter "target.lastUpdated < source.lastUpdated" will
|
||||||
|
* only update matched rows when the incoming `lastUpdated` value is
|
||||||
|
* newer.
|
||||||
|
*
|
||||||
|
* Rows that do not match the filter will not be updated. Rows that
|
||||||
|
* do not match the filter do become "not matched" rows.
|
||||||
*/
|
*/
|
||||||
whenMatchedUpdateAll?: boolean
|
whenMatchedUpdateAll?: string | boolean
|
||||||
/**
|
/**
|
||||||
* If true then rows that exist only in the source table (new data)
|
* If true then rows that exist only in the source table (new data)
|
||||||
* will be inserted into the target table.
|
* will be inserted into the target table.
|
||||||
@@ -840,8 +921,8 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
/**
|
/**
|
||||||
* Returns the number of rows in this table.
|
* Returns the number of rows in this table.
|
||||||
*/
|
*/
|
||||||
async countRows (): Promise<number> {
|
async countRows (filter?: string): Promise<number> {
|
||||||
return tableCountRows.call(this._tbl)
|
return tableCountRows.call(this._tbl, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -885,7 +966,14 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
|
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
|
||||||
const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
|
let whenMatchedUpdateAll = false
|
||||||
|
let whenMatchedUpdateAllFilt = null
|
||||||
|
if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
|
||||||
|
whenMatchedUpdateAll = true
|
||||||
|
if (args.whenMatchedUpdateAll !== true) {
|
||||||
|
whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
|
||||||
|
}
|
||||||
|
}
|
||||||
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
|
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
|
||||||
let whenNotMatchedBySourceDelete = false
|
let whenNotMatchedBySourceDelete = false
|
||||||
let whenNotMatchedBySourceDeleteFilt = null
|
let whenNotMatchedBySourceDeleteFilt = null
|
||||||
@@ -909,6 +997,7 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
this._tbl,
|
this._tbl,
|
||||||
on,
|
on,
|
||||||
whenMatchedUpdateAll,
|
whenMatchedUpdateAll,
|
||||||
|
whenMatchedUpdateAllFilt,
|
||||||
whenNotMatchedInsertAll,
|
whenNotMatchedInsertAll,
|
||||||
whenNotMatchedBySourceDelete,
|
whenNotMatchedBySourceDelete,
|
||||||
whenNotMatchedBySourceDeleteFilt,
|
whenNotMatchedBySourceDeleteFilt,
|
||||||
@@ -995,6 +1084,18 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async addColumns (newColumnTransforms: Array<{ name: string, valueSql: string }>): Promise<void> {
|
||||||
|
return tableAddColumns.call(this._tbl, newColumnTransforms)
|
||||||
|
}
|
||||||
|
|
||||||
|
async alterColumns (columnAlterations: ColumnAlteration[]): Promise<void> {
|
||||||
|
return tableAlterColumns.call(this._tbl, columnAlterations)
|
||||||
|
}
|
||||||
|
|
||||||
|
async dropColumns (columnNames: string[]): Promise<void> {
|
||||||
|
return tableDropColumns.call(this._tbl, columnNames)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface CleanupStats {
|
export interface CleanupStats {
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ import {
|
|||||||
type UpdateArgs,
|
type UpdateArgs,
|
||||||
type UpdateSqlArgs,
|
type UpdateSqlArgs,
|
||||||
makeArrowTable,
|
makeArrowTable,
|
||||||
type MergeInsertArgs
|
type MergeInsertArgs,
|
||||||
|
type ColumnAlteration
|
||||||
} from '../index'
|
} from '../index'
|
||||||
import { Query } from '../query'
|
import { Query } from '../query'
|
||||||
|
|
||||||
@@ -286,8 +287,11 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
|||||||
const queryParams: any = {
|
const queryParams: any = {
|
||||||
on
|
on
|
||||||
}
|
}
|
||||||
if (args.whenMatchedUpdateAll ?? false) {
|
if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
|
||||||
queryParams.when_matched_update_all = 'true'
|
queryParams.when_matched_update_all = 'true'
|
||||||
|
if (typeof args.whenMatchedUpdateAll === 'string') {
|
||||||
|
queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
queryParams.when_matched_update_all = 'false'
|
queryParams.when_matched_update_all = 'false'
|
||||||
}
|
}
|
||||||
@@ -471,4 +475,16 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
|||||||
numUnindexedRows: results.data.num_unindexed_rows
|
numUnindexedRows: results.data.num_unindexed_rows
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async addColumns (newColumnTransforms: Array<{ name: string, valueSql: string }>): Promise<void> {
|
||||||
|
throw new Error('Add columns is not yet supported in LanceDB Cloud.')
|
||||||
|
}
|
||||||
|
|
||||||
|
async alterColumns (columnAlterations: ColumnAlteration[]): Promise<void> {
|
||||||
|
throw new Error('Alter columns is not yet supported in LanceDB Cloud.')
|
||||||
|
}
|
||||||
|
|
||||||
|
async dropColumns (columnNames: string[]): Promise<void> {
|
||||||
|
throw new Error('Drop columns is not yet supported in LanceDB Cloud.')
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,9 +13,10 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { describe } from 'mocha'
|
import { describe } from 'mocha'
|
||||||
import { assert } from 'chai'
|
import { assert, expect, use as chaiUse } from 'chai'
|
||||||
|
import * as chaiAsPromised from 'chai-as-promised'
|
||||||
|
|
||||||
import { fromTableToBuffer, makeArrowTable } from '../arrow'
|
import { convertToTable, fromTableToBuffer, makeArrowTable, makeEmptyTable } from '../arrow'
|
||||||
import {
|
import {
|
||||||
Field,
|
Field,
|
||||||
FixedSizeList,
|
FixedSizeList,
|
||||||
@@ -24,21 +25,79 @@ import {
|
|||||||
Int32,
|
Int32,
|
||||||
tableFromIPC,
|
tableFromIPC,
|
||||||
Schema,
|
Schema,
|
||||||
Float64
|
Float64,
|
||||||
|
type Table,
|
||||||
|
Binary,
|
||||||
|
Bool,
|
||||||
|
Utf8,
|
||||||
|
Struct,
|
||||||
|
List,
|
||||||
|
DataType,
|
||||||
|
Dictionary,
|
||||||
|
Int64
|
||||||
} from 'apache-arrow'
|
} from 'apache-arrow'
|
||||||
|
import { type EmbeddingFunction } from '../embedding/embedding_function'
|
||||||
|
|
||||||
describe('Apache Arrow tables', function () {
|
chaiUse(chaiAsPromised)
|
||||||
it('customized schema', async function () {
|
|
||||||
|
function sampleRecords (): Array<Record<string, any>> {
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
binary: Buffer.alloc(5),
|
||||||
|
boolean: false,
|
||||||
|
number: 7,
|
||||||
|
string: 'hello',
|
||||||
|
struct: { x: 0, y: 0 },
|
||||||
|
list: ['anime', 'action', 'comedy']
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper method to verify various ways to create a table
|
||||||
|
async function checkTableCreation (tableCreationMethod: (records: any, recordsReversed: any, schema: Schema) => Promise<Table>): Promise<void> {
|
||||||
|
const records = sampleRecords()
|
||||||
|
const recordsReversed = [{
|
||||||
|
list: ['anime', 'action', 'comedy'],
|
||||||
|
struct: { x: 0, y: 0 },
|
||||||
|
string: 'hello',
|
||||||
|
number: 7,
|
||||||
|
boolean: false,
|
||||||
|
binary: Buffer.alloc(5)
|
||||||
|
}]
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field('binary', new Binary(), false),
|
||||||
|
new Field('boolean', new Bool(), false),
|
||||||
|
new Field('number', new Float64(), false),
|
||||||
|
new Field('string', new Utf8(), false),
|
||||||
|
new Field('struct', new Struct([
|
||||||
|
new Field('x', new Float64(), false),
|
||||||
|
new Field('y', new Float64(), false)
|
||||||
|
])),
|
||||||
|
new Field('list', new List(new Field('item', new Utf8(), false)), false)
|
||||||
|
])
|
||||||
|
|
||||||
|
const table = await tableCreationMethod(records, recordsReversed, schema)
|
||||||
|
schema.fields.forEach((field, idx) => {
|
||||||
|
const actualField = table.schema.fields[idx]
|
||||||
|
assert.isFalse(actualField.nullable)
|
||||||
|
assert.equal(table.getChild(field.name)?.type.toString(), field.type.toString())
|
||||||
|
assert.equal(table.getChildAt(idx)?.type.toString(), field.type.toString())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('The function makeArrowTable', function () {
|
||||||
|
it('will use data types from a provided schema instead of inference', async function () {
|
||||||
const schema = new Schema([
|
const schema = new Schema([
|
||||||
new Field('a', new Int32()),
|
new Field('a', new Int32()),
|
||||||
new Field('b', new Float32()),
|
new Field('b', new Float32()),
|
||||||
new Field('c', new FixedSizeList(3, new Field('item', new Float16())))
|
new Field('c', new FixedSizeList(3, new Field('item', new Float16()))),
|
||||||
|
new Field('d', new Int64())
|
||||||
])
|
])
|
||||||
const table = makeArrowTable(
|
const table = makeArrowTable(
|
||||||
[
|
[
|
||||||
{ a: 1, b: 2, c: [1, 2, 3] },
|
{ a: 1, b: 2, c: [1, 2, 3], d: 9 },
|
||||||
{ a: 4, b: 5, c: [4, 5, 6] },
|
{ a: 4, b: 5, c: [4, 5, 6], d: 10 },
|
||||||
{ a: 7, b: 8, c: [7, 8, 9] }
|
{ a: 7, b: 8, c: [7, 8, 9], d: null }
|
||||||
],
|
],
|
||||||
{ schema }
|
{ schema }
|
||||||
)
|
)
|
||||||
@@ -52,13 +111,13 @@ describe('Apache Arrow tables', function () {
|
|||||||
assert.deepEqual(actualSchema, schema)
|
assert.deepEqual(actualSchema, schema)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('default vector column', async function () {
|
it('will assume the column `vector` is FixedSizeList<Float32> by default', async function () {
|
||||||
const schema = new Schema([
|
const schema = new Schema([
|
||||||
new Field('a', new Float64()),
|
new Field('a', new Float64()),
|
||||||
new Field('b', new Float64()),
|
new Field('b', new Float64()),
|
||||||
new Field(
|
new Field(
|
||||||
'vector',
|
'vector',
|
||||||
new FixedSizeList(3, new Field('item', new Float32()))
|
new FixedSizeList(3, new Field('item', new Float32(), true))
|
||||||
)
|
)
|
||||||
])
|
])
|
||||||
const table = makeArrowTable([
|
const table = makeArrowTable([
|
||||||
@@ -76,12 +135,12 @@ describe('Apache Arrow tables', function () {
|
|||||||
assert.deepEqual(actualSchema, schema)
|
assert.deepEqual(actualSchema, schema)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('2 vector columns', async function () {
|
it('can support multiple vector columns', async function () {
|
||||||
const schema = new Schema([
|
const schema = new Schema([
|
||||||
new Field('a', new Float64()),
|
new Field('a', new Float64()),
|
||||||
new Field('b', new Float64()),
|
new Field('b', new Float64()),
|
||||||
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
|
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16(), true))),
|
||||||
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
|
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16(), true)))
|
||||||
])
|
])
|
||||||
const table = makeArrowTable(
|
const table = makeArrowTable(
|
||||||
[
|
[
|
||||||
@@ -105,4 +164,157 @@ describe('Apache Arrow tables', function () {
|
|||||||
const actualSchema = actual.schema
|
const actualSchema = actual.schema
|
||||||
assert.deepEqual(actualSchema, schema)
|
assert.deepEqual(actualSchema, schema)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('will allow different vector column types', async function () {
|
||||||
|
const table = makeArrowTable(
|
||||||
|
[
|
||||||
|
{ fp16: [1], fp32: [1], fp64: [1] }
|
||||||
|
],
|
||||||
|
{
|
||||||
|
vectorColumns: {
|
||||||
|
fp16: { type: new Float16() },
|
||||||
|
fp32: { type: new Float32() },
|
||||||
|
fp64: { type: new Float64() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.equal(table.getChild('fp16')?.type.children[0].type.toString(), new Float16().toString())
|
||||||
|
assert.equal(table.getChild('fp32')?.type.children[0].type.toString(), new Float32().toString())
|
||||||
|
assert.equal(table.getChild('fp64')?.type.children[0].type.toString(), new Float64().toString())
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will use dictionary encoded strings if asked', async function () {
|
||||||
|
const table = makeArrowTable([{ str: 'hello' }])
|
||||||
|
assert.isTrue(DataType.isUtf8(table.getChild('str')?.type))
|
||||||
|
|
||||||
|
const tableWithDict = makeArrowTable([{ str: 'hello' }], { dictionaryEncodeStrings: true })
|
||||||
|
assert.isTrue(DataType.isDictionary(tableWithDict.getChild('str')?.type))
|
||||||
|
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field('str', new Dictionary(new Utf8(), new Int32()))
|
||||||
|
])
|
||||||
|
|
||||||
|
const tableWithDict2 = makeArrowTable([{ str: 'hello' }], { schema })
|
||||||
|
assert.isTrue(DataType.isDictionary(tableWithDict2.getChild('str')?.type))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will infer data types correctly', async function () {
|
||||||
|
await checkTableCreation(async (records) => makeArrowTable(records))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will allow a schema to be provided', async function () {
|
||||||
|
await checkTableCreation(async (records, _, schema) => makeArrowTable(records, { schema }))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will use the field order of any provided schema', async function () {
|
||||||
|
await checkTableCreation(async (_, recordsReversed, schema) => makeArrowTable(recordsReversed, { schema }))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will make an empty table', async function () {
|
||||||
|
await checkTableCreation(async (_, __, schema) => makeArrowTable([], { schema }))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
class DummyEmbedding implements EmbeddingFunction<string> {
|
||||||
|
public readonly sourceColumn = 'string'
|
||||||
|
public readonly embeddingDimension = 2
|
||||||
|
public readonly embeddingDataType = new Float16()
|
||||||
|
|
||||||
|
async embed (data: string[]): Promise<number[][]> {
|
||||||
|
return data.map(
|
||||||
|
() => [0.0, 0.0]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
|
||||||
|
public readonly sourceColumn = 'string'
|
||||||
|
|
||||||
|
async embed (data: string[]): Promise<number[][]> {
|
||||||
|
return data.map(
|
||||||
|
() => [0.0, 0.0]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('convertToTable', function () {
|
||||||
|
it('will infer data types correctly', async function () {
|
||||||
|
await checkTableCreation(async (records) => await convertToTable(records))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will allow a schema to be provided', async function () {
|
||||||
|
await checkTableCreation(async (records, _, schema) => await convertToTable(records, undefined, { schema }))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will use the field order of any provided schema', async function () {
|
||||||
|
await checkTableCreation(async (_, recordsReversed, schema) => await convertToTable(recordsReversed, undefined, { schema }))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will make an empty table', async function () {
|
||||||
|
await checkTableCreation(async (_, __, schema) => await convertToTable([], undefined, { schema }))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will apply embeddings', async function () {
|
||||||
|
const records = sampleRecords()
|
||||||
|
const table = await convertToTable(records, new DummyEmbedding())
|
||||||
|
assert.isTrue(DataType.isFixedSizeList(table.getChild('vector')?.type))
|
||||||
|
assert.equal(table.getChild('vector')?.type.children[0].type.toString(), new Float16().toString())
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will fail if missing the embedding source column', async function () {
|
||||||
|
return await expect(convertToTable([{ id: 1 }], new DummyEmbedding())).to.be.rejectedWith("'string' was not present")
|
||||||
|
})
|
||||||
|
|
||||||
|
it('use embeddingDimension if embedding missing from table', async function () {
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field('string', new Utf8(), false)
|
||||||
|
])
|
||||||
|
// Simulate getting an empty Arrow table (minus embedding) from some other source
|
||||||
|
// In other words, we aren't starting with records
|
||||||
|
const table = makeEmptyTable(schema)
|
||||||
|
|
||||||
|
// If the embedding specifies the dimension we are fine
|
||||||
|
await fromTableToBuffer(table, new DummyEmbedding())
|
||||||
|
|
||||||
|
// We can also supply a schema and should be ok
|
||||||
|
const schemaWithEmbedding = new Schema([
|
||||||
|
new Field('string', new Utf8(), false),
|
||||||
|
new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false)
|
||||||
|
])
|
||||||
|
await fromTableToBuffer(table, new DummyEmbeddingWithNoDimension(), schemaWithEmbedding)
|
||||||
|
|
||||||
|
// Otherwise we will get an error
|
||||||
|
return await expect(fromTableToBuffer(table, new DummyEmbeddingWithNoDimension())).to.be.rejectedWith('does not specify `embeddingDimension`')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will apply embeddings to an empty table', async function () {
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field('string', new Utf8(), false),
|
||||||
|
new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false)
|
||||||
|
])
|
||||||
|
const table = await convertToTable([], new DummyEmbedding(), { schema })
|
||||||
|
assert.isTrue(DataType.isFixedSizeList(table.getChild('vector')?.type))
|
||||||
|
assert.equal(table.getChild('vector')?.type.children[0].type.toString(), new Float16().toString())
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will complain if embeddings present but schema missing embedding column', async function () {
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field('string', new Utf8(), false)
|
||||||
|
])
|
||||||
|
return await expect(convertToTable([], new DummyEmbedding(), { schema })).to.be.rejectedWith('column vector was missing')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('will provide a nice error if run twice', async function () {
|
||||||
|
const records = sampleRecords()
|
||||||
|
const table = await convertToTable(records, new DummyEmbedding())
|
||||||
|
// fromTableToBuffer will try and apply the embeddings again
|
||||||
|
return await expect(fromTableToBuffer(table, new DummyEmbedding())).to.be.rejectedWith('already existed')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('makeEmptyTable', function () {
|
||||||
|
it('will make an empty table', async function () {
|
||||||
|
await checkTableCreation(async (_, __, schema) => makeEmptyTable(schema))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -37,8 +37,10 @@ import {
|
|||||||
Utf8,
|
Utf8,
|
||||||
Table as ArrowTable,
|
Table as ArrowTable,
|
||||||
vectorFromArray,
|
vectorFromArray,
|
||||||
|
Float64,
|
||||||
Float32,
|
Float32,
|
||||||
Float16
|
Float16,
|
||||||
|
Int64
|
||||||
} from 'apache-arrow'
|
} from 'apache-arrow'
|
||||||
|
|
||||||
const expect = chai.expect
|
const expect = chai.expect
|
||||||
@@ -196,7 +198,7 @@ describe('LanceDB client', function () {
|
|||||||
const table = await con.openTable('vectors')
|
const table = await con.openTable('vectors')
|
||||||
const results = await table
|
const results = await table
|
||||||
.search([0.1, 0.1])
|
.search([0.1, 0.1])
|
||||||
.select(['is_active'])
|
.select(['is_active', 'vector'])
|
||||||
.execute()
|
.execute()
|
||||||
assert.equal(results.length, 2)
|
assert.equal(results.length, 2)
|
||||||
// vector and _distance are always returned
|
// vector and _distance are always returned
|
||||||
@@ -294,6 +296,7 @@ describe('LanceDB client', function () {
|
|||||||
})
|
})
|
||||||
assert.equal(table.name, 'vectors')
|
assert.equal(table.name, 'vectors')
|
||||||
assert.equal(await table.countRows(), 10)
|
assert.equal(await table.countRows(), 10)
|
||||||
|
assert.equal(await table.countRows('vector IS NULL'), 0)
|
||||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -369,6 +372,7 @@ describe('LanceDB client', function () {
|
|||||||
const table = await con.createTable('f16', data)
|
const table = await con.createTable('f16', data)
|
||||||
assert.equal(table.name, 'f16')
|
assert.equal(table.name, 'f16')
|
||||||
assert.equal(await table.countRows(), total)
|
assert.equal(await table.countRows(), total)
|
||||||
|
assert.equal(await table.countRows('id < 5'), 5)
|
||||||
assert.deepEqual(await con.tableNames(), ['f16'])
|
assert.deepEqual(await con.tableNames(), ['f16'])
|
||||||
assert.deepEqual(await table.schema, schema)
|
assert.deepEqual(await table.schema, schema)
|
||||||
|
|
||||||
@@ -538,26 +542,36 @@ describe('LanceDB client', function () {
|
|||||||
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
|
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
|
||||||
const table = await con.createTable('my_table', data)
|
const table = await con.createTable('my_table', data)
|
||||||
|
|
||||||
|
// insert if not exists
|
||||||
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
|
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
|
||||||
await table.mergeInsert('id', newData, {
|
await table.mergeInsert('id', newData, {
|
||||||
whenNotMatchedInsertAll: true
|
whenNotMatchedInsertAll: true
|
||||||
})
|
})
|
||||||
assert.equal(await table.countRows(), 3)
|
assert.equal(await table.countRows(), 3)
|
||||||
assert.equal((await table.filter('age = 2').execute()).length, 1)
|
assert.equal(await table.countRows('age = 2'), 1)
|
||||||
|
|
||||||
newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
|
// conditional update
|
||||||
|
newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
|
||||||
|
await table.mergeInsert('id', newData, {
|
||||||
|
whenMatchedUpdateAll: 'target.age = 1'
|
||||||
|
})
|
||||||
|
assert.equal(await table.countRows(), 3)
|
||||||
|
assert.equal(await table.countRows('age = 1'), 1)
|
||||||
|
assert.equal(await table.countRows('age = 3'), 1)
|
||||||
|
|
||||||
|
newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
|
||||||
await table.mergeInsert('id', newData, {
|
await table.mergeInsert('id', newData, {
|
||||||
whenNotMatchedInsertAll: true,
|
whenNotMatchedInsertAll: true,
|
||||||
whenMatchedUpdateAll: true
|
whenMatchedUpdateAll: true
|
||||||
})
|
})
|
||||||
assert.equal(await table.countRows(), 4)
|
assert.equal(await table.countRows(), 4)
|
||||||
assert.equal((await table.filter('age = 3').execute()).length, 2)
|
assert.equal((await table.filter('age = 4').execute()).length, 2)
|
||||||
|
|
||||||
newData = [{ id: 5, age: 4 }]
|
newData = [{ id: 5, age: 5 }]
|
||||||
await table.mergeInsert('id', newData, {
|
await table.mergeInsert('id', newData, {
|
||||||
whenNotMatchedInsertAll: true,
|
whenNotMatchedInsertAll: true,
|
||||||
whenMatchedUpdateAll: true,
|
whenMatchedUpdateAll: true,
|
||||||
whenNotMatchedBySourceDelete: 'age < 3'
|
whenNotMatchedBySourceDelete: 'age < 4'
|
||||||
})
|
})
|
||||||
assert.equal(await table.countRows(), 3)
|
assert.equal(await table.countRows(), 3)
|
||||||
|
|
||||||
@@ -1045,3 +1059,63 @@ describe('Compact and cleanup', function () {
|
|||||||
assert.equal(await table.countRows(), 3)
|
assert.equal(await table.countRows(), 3)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('schema evolution', function () {
|
||||||
|
// Create a new sample table
|
||||||
|
it('can add a new column to the schema', async function () {
|
||||||
|
const dir = await track().mkdir('lancejs')
|
||||||
|
const con = await lancedb.connect(dir)
|
||||||
|
const table = await con.createTable('vectors', [
|
||||||
|
{ id: 1n, vector: [0.1, 0.2] }
|
||||||
|
])
|
||||||
|
|
||||||
|
await table.addColumns([{ name: 'price', valueSql: 'cast(10.0 as float)' }])
|
||||||
|
|
||||||
|
const expectedSchema = new Schema([
|
||||||
|
new Field('id', new Int64()),
|
||||||
|
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true))),
|
||||||
|
new Field('price', new Float32())
|
||||||
|
])
|
||||||
|
expect(await table.schema).to.deep.equal(expectedSchema)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('can alter the columns in the schema', async function () {
|
||||||
|
const dir = await track().mkdir('lancejs')
|
||||||
|
const con = await lancedb.connect(dir)
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field('id', new Int64(), false),
|
||||||
|
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true))),
|
||||||
|
new Field('price', new Float64(), false)
|
||||||
|
])
|
||||||
|
const table = await con.createTable('vectors', [
|
||||||
|
{ id: 1n, vector: [0.1, 0.2], price: 10.0 }
|
||||||
|
])
|
||||||
|
expect(await table.schema).to.deep.equal(schema)
|
||||||
|
|
||||||
|
await table.alterColumns([
|
||||||
|
{ path: 'id', rename: 'new_id' },
|
||||||
|
{ path: 'price', nullable: true }
|
||||||
|
])
|
||||||
|
|
||||||
|
const expectedSchema = new Schema([
|
||||||
|
new Field('new_id', new Int64(), false),
|
||||||
|
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true))),
|
||||||
|
new Field('price', new Float64(), true)
|
||||||
|
])
|
||||||
|
expect(await table.schema).to.deep.equal(expectedSchema)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('can drop a column from the schema', async function () {
|
||||||
|
const dir = await track().mkdir('lancejs')
|
||||||
|
const con = await lancedb.connect(dir)
|
||||||
|
const table = await con.createTable('vectors', [
|
||||||
|
{ id: 1n, vector: [0.1, 0.2] }
|
||||||
|
])
|
||||||
|
await table.dropColumns(['vector'])
|
||||||
|
|
||||||
|
const expectedSchema = new Schema([
|
||||||
|
new Field('id', new Int64(), false)
|
||||||
|
])
|
||||||
|
expect(await table.schema).to.deep.equal(expectedSchema)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -9,6 +9,6 @@
|
|||||||
"declaration": true,
|
"declaration": true,
|
||||||
"outDir": "./dist",
|
"outDir": "./dist",
|
||||||
"strict": true,
|
"strict": true,
|
||||||
// "esModuleInterop": true,
|
"sourceMap": true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -18,5 +18,5 @@ module.exports = {
|
|||||||
"@typescript-eslint/method-signature-style": "off",
|
"@typescript-eslint/method-signature-style": "off",
|
||||||
"@typescript-eslint/no-explicit-any": "off",
|
"@typescript-eslint/no-explicit-any": "off",
|
||||||
},
|
},
|
||||||
ignorePatterns: ["node_modules/", "dist/", "build/", "vectordb/native.*"],
|
ignorePatterns: ["node_modules/", "dist/", "build/", "lancedb/native.*"],
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb-nodejs"
|
name = "lancedb-nodejs"
|
||||||
edition = "2021"
|
edition.workspace = true
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
|
description.workspace = true
|
||||||
repository.workspace = true
|
repository.workspace = true
|
||||||
|
keywords.workspace = true
|
||||||
|
categories.workspace = true
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
crate-type = ["cdylib"]
|
crate-type = ["cdylib"]
|
||||||
@@ -13,16 +16,15 @@ arrow-ipc.workspace = true
|
|||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
lance-linalg.workspace = true
|
lance-linalg.workspace = true
|
||||||
lance.workspace = true
|
lance.workspace = true
|
||||||
vectordb = { path = "../rust/vectordb" }
|
lancedb = { path = "../rust/lancedb" }
|
||||||
napi = { version = "2.14", default-features = false, features = [
|
napi = { version = "2.15", default-features = false, features = [
|
||||||
"napi7",
|
"napi7",
|
||||||
"async"
|
"async"
|
||||||
] }
|
] }
|
||||||
napi-derive = "2.14"
|
napi-derive = "2"
|
||||||
|
|
||||||
|
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||||
|
lzma-sys = { version = "*", features = ["static"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
napi-build = "2.1"
|
napi-build = "2.1"
|
||||||
|
|
||||||
[profile.release]
|
|
||||||
lto = true
|
|
||||||
strip = "symbols"
|
|
||||||
|
|||||||
@@ -12,8 +12,9 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { makeArrowTable, toBuffer } from "../vectordb/arrow";
|
import { makeArrowTable, toBuffer } from "../lancedb/arrow";
|
||||||
import {
|
import {
|
||||||
|
Int64,
|
||||||
Field,
|
Field,
|
||||||
FixedSizeList,
|
FixedSizeList,
|
||||||
Float16,
|
Float16,
|
||||||
@@ -104,3 +105,16 @@ test("2 vector columns", function () {
|
|||||||
const actualSchema = actual.schema;
|
const actualSchema = actual.schema;
|
||||||
expect(actualSchema.toString()).toEqual(schema.toString());
|
expect(actualSchema.toString()).toEqual(schema.toString());
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("handles int64", function() {
|
||||||
|
// https://github.com/lancedb/lancedb/issues/960
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field("x", new Int64(), true)
|
||||||
|
]);
|
||||||
|
const table = makeArrowTable([
|
||||||
|
{ x: 1 },
|
||||||
|
{ x: 2 },
|
||||||
|
{ x: 3 }
|
||||||
|
], { schema });
|
||||||
|
expect(table.schema).toEqual(schema);
|
||||||
|
})
|
||||||
@@ -29,6 +29,6 @@ test("open database", async () => {
|
|||||||
const tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
|
const tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
|
||||||
expect(await db.tableNames()).toStrictEqual(["test"]);
|
expect(await db.tableNames()).toStrictEqual(["test"]);
|
||||||
|
|
||||||
const schema = tbl.schema;
|
const schema = await tbl.schema();
|
||||||
expect(schema).toEqual(new Schema([new Field("id", new Float64(), true)]));
|
expect(schema).toEqual(new Schema([new Field("id", new Float64(), true)]));
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import * as path from "path";
|
|||||||
import * as fs from "fs";
|
import * as fs from "fs";
|
||||||
|
|
||||||
import { connect } from "../dist";
|
import { connect } from "../dist";
|
||||||
import { Schema, Field, Float32, Int32, FixedSizeList } from "apache-arrow";
|
import { Schema, Field, Float32, Int32, FixedSizeList, Int64, Float64 } from "apache-arrow";
|
||||||
import { makeArrowTable } from "../dist/arrow";
|
import { makeArrowTable } from "../dist/arrow";
|
||||||
|
|
||||||
describe("Test creating index", () => {
|
describe("Test creating index", () => {
|
||||||
@@ -181,3 +181,102 @@ describe("Test creating index", () => {
|
|||||||
// TODO: check index type.
|
// TODO: check index type.
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe("Read consistency interval", () => {
|
||||||
|
let tmpDir: string;
|
||||||
|
beforeEach(() => {
|
||||||
|
tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "read-consistency-"));
|
||||||
|
});
|
||||||
|
|
||||||
|
// const intervals = [undefined, 0, 0.1];
|
||||||
|
const intervals = [0];
|
||||||
|
test.each(intervals)("read consistency interval %p", async (interval) => {
|
||||||
|
const db = await connect({ uri: tmpDir });
|
||||||
|
const table = await db.createTable("my_table", [{ id: 1 }]);
|
||||||
|
|
||||||
|
const db2 = await connect({ uri: tmpDir, readConsistencyInterval: interval });
|
||||||
|
const table2 = await db2.openTable("my_table");
|
||||||
|
expect(await table2.countRows()).toEqual(await table.countRows());
|
||||||
|
|
||||||
|
await table.add([{ id: 2 }]);
|
||||||
|
|
||||||
|
if (interval === undefined) {
|
||||||
|
expect(await table2.countRows()).toEqual(1n);
|
||||||
|
// TODO: once we implement time travel we can uncomment this part of the test.
|
||||||
|
// await table2.checkout_latest();
|
||||||
|
// expect(await table2.countRows()).toEqual(2);
|
||||||
|
} else if (interval === 0) {
|
||||||
|
expect(await table2.countRows()).toEqual(2n);
|
||||||
|
} else {
|
||||||
|
// interval == 0.1
|
||||||
|
expect(await table2.countRows()).toEqual(1n);
|
||||||
|
await new Promise(r => setTimeout(r, 100));
|
||||||
|
expect(await table2.countRows()).toEqual(2n);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
describe('schema evolution', function () {
|
||||||
|
let tmpDir: string;
|
||||||
|
beforeEach(() => {
|
||||||
|
tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), "schema-evolution-"));
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create a new sample table
|
||||||
|
it('can add a new column to the schema', async function () {
|
||||||
|
const con = await connect(tmpDir)
|
||||||
|
const table = await con.createTable('vectors', [
|
||||||
|
{ id: 1n, vector: [0.1, 0.2] }
|
||||||
|
])
|
||||||
|
|
||||||
|
await table.addColumns([{ name: 'price', valueSql: 'cast(10.0 as float)' }])
|
||||||
|
|
||||||
|
const expectedSchema = new Schema([
|
||||||
|
new Field('id', new Int64(), true),
|
||||||
|
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), true),
|
||||||
|
new Field('price', new Float32(), false)
|
||||||
|
])
|
||||||
|
expect(await table.schema()).toEqual(expectedSchema)
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can alter the columns in the schema', async function () {
|
||||||
|
const con = await connect(tmpDir)
|
||||||
|
const schema = new Schema([
|
||||||
|
new Field('id', new Int64(), true),
|
||||||
|
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), true),
|
||||||
|
new Field('price', new Float64(), false)
|
||||||
|
])
|
||||||
|
const table = await con.createTable('vectors', [
|
||||||
|
{ id: 1n, vector: [0.1, 0.2] }
|
||||||
|
])
|
||||||
|
// Can create a non-nullable column only through addColumns at the moment.
|
||||||
|
await table.addColumns([{ name: 'price', valueSql: 'cast(10.0 as double)' }])
|
||||||
|
expect(await table.schema()).toEqual(schema)
|
||||||
|
|
||||||
|
await table.alterColumns([
|
||||||
|
{ path: 'id', rename: 'new_id' },
|
||||||
|
{ path: 'price', nullable: true }
|
||||||
|
])
|
||||||
|
|
||||||
|
const expectedSchema = new Schema([
|
||||||
|
new Field('new_id', new Int64(), true),
|
||||||
|
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), true),
|
||||||
|
new Field('price', new Float64(), true)
|
||||||
|
])
|
||||||
|
expect(await table.schema()).toEqual(expectedSchema)
|
||||||
|
});
|
||||||
|
|
||||||
|
it('can drop a column from the schema', async function () {
|
||||||
|
const con = await connect(tmpDir)
|
||||||
|
const table = await con.createTable('vectors', [
|
||||||
|
{ id: 1n, vector: [0.1, 0.2] }
|
||||||
|
])
|
||||||
|
await table.dropColumns(['vector'])
|
||||||
|
|
||||||
|
const expectedSchema = new Schema([
|
||||||
|
new Field('id', new Int64(), true)
|
||||||
|
])
|
||||||
|
expect(await table.schema()).toEqual(expectedSchema)
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -2,4 +2,6 @@
|
|||||||
module.exports = {
|
module.exports = {
|
||||||
preset: 'ts-jest',
|
preset: 'ts-jest',
|
||||||
testEnvironment: 'node',
|
testEnvironment: 'node',
|
||||||
};
|
moduleDirectories: ["node_modules", "./dist"],
|
||||||
|
moduleFileExtensions: ["js", "ts"],
|
||||||
|
};
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import {
|
import {
|
||||||
|
Int64,
|
||||||
Field,
|
Field,
|
||||||
FixedSizeList,
|
FixedSizeList,
|
||||||
Float,
|
Float,
|
||||||
@@ -23,6 +24,7 @@ import {
|
|||||||
Vector,
|
Vector,
|
||||||
vectorFromArray,
|
vectorFromArray,
|
||||||
tableToIPC,
|
tableToIPC,
|
||||||
|
DataType,
|
||||||
} from "apache-arrow";
|
} from "apache-arrow";
|
||||||
|
|
||||||
/** Data type accepted by NodeJS SDK */
|
/** Data type accepted by NodeJS SDK */
|
||||||
@@ -137,15 +139,18 @@ export function makeArrowTable(
|
|||||||
const columnNames = Object.keys(data[0]);
|
const columnNames = Object.keys(data[0]);
|
||||||
for (const colName of columnNames) {
|
for (const colName of columnNames) {
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
|
||||||
const values = data.map((datum) => datum[colName]);
|
let values = data.map((datum) => datum[colName]);
|
||||||
let vector: Vector;
|
let vector: Vector;
|
||||||
|
|
||||||
if (opt.schema !== undefined) {
|
if (opt.schema !== undefined) {
|
||||||
// Explicit schema is provided, highest priority
|
// Explicit schema is provided, highest priority
|
||||||
vector = vectorFromArray(
|
const fieldType: DataType | undefined = opt.schema.fields.filter((f) => f.name === colName)[0]?.type as DataType;
|
||||||
values,
|
if (fieldType instanceof Int64) {
|
||||||
opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
|
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
|
||||||
);
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
|
||||||
|
values = values.map((v) => BigInt(v));
|
||||||
|
}
|
||||||
|
vector = vectorFromArray(values, fieldType);
|
||||||
} else {
|
} else {
|
||||||
const vectorColumnOptions = opt.vectorColumns[colName];
|
const vectorColumnOptions = opt.vectorColumns[colName];
|
||||||
if (vectorColumnOptions !== undefined) {
|
if (vectorColumnOptions !== undefined) {
|
||||||
@@ -53,12 +53,12 @@ export async function connect(
|
|||||||
opts = Object.assign(
|
opts = Object.assign(
|
||||||
{
|
{
|
||||||
uri: "",
|
uri: "",
|
||||||
apiKey: "",
|
apiKey: undefined,
|
||||||
hostOverride: "",
|
hostOverride: undefined,
|
||||||
},
|
},
|
||||||
args
|
args
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
const nativeConn = await NativeConnection.new(opts.uri);
|
const nativeConn = await NativeConnection.new(opts);
|
||||||
return new Connection(nativeConn);
|
return new Connection(nativeConn);
|
||||||
}
|
}
|
||||||
@@ -12,10 +12,54 @@ export const enum MetricType {
|
|||||||
Cosine = 1,
|
Cosine = 1,
|
||||||
Dot = 2
|
Dot = 2
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* A definition of a column alteration. The alteration changes the column at
|
||||||
|
* `path` to have the new name `name`, to be nullable if `nullable` is true,
|
||||||
|
* and to have the data type `data_type`. At least one of `rename` or `nullable`
|
||||||
|
* must be provided.
|
||||||
|
*/
|
||||||
|
export interface ColumnAlteration {
|
||||||
|
/**
|
||||||
|
* The path to the column to alter. This is a dot-separated path to the column.
|
||||||
|
* If it is a top-level column then it is just the name of the column. If it is
|
||||||
|
* a nested column then it is the path to the column, e.g. "a.b.c" for a column
|
||||||
|
* `c` nested inside a column `b` nested inside a column `a`.
|
||||||
|
*/
|
||||||
|
path: string
|
||||||
|
/**
|
||||||
|
* The new name of the column. If not provided then the name will not be changed.
|
||||||
|
* This must be distinct from the names of all other columns in the table.
|
||||||
|
*/
|
||||||
|
rename?: string
|
||||||
|
/** Set the new nullability. Note that a nullable column cannot be made non-nullable. */
|
||||||
|
nullable?: boolean
|
||||||
|
}
|
||||||
|
/** A definition of a new column to add to a table. */
|
||||||
|
export interface AddColumnsSql {
|
||||||
|
/** The name of the new column. */
|
||||||
|
name: string
|
||||||
|
/**
|
||||||
|
* The values to populate the new column with, as a SQL expression.
|
||||||
|
* The expression can reference other columns in the table.
|
||||||
|
*/
|
||||||
|
valueSql: string
|
||||||
|
}
|
||||||
export interface ConnectionOptions {
|
export interface ConnectionOptions {
|
||||||
uri: string
|
uri: string
|
||||||
apiKey?: string
|
apiKey?: string
|
||||||
hostOverride?: string
|
hostOverride?: string
|
||||||
|
/**
|
||||||
|
* (For LanceDB OSS only): The interval, in seconds, at which to check for
|
||||||
|
* updates to the table from other processes. If None, then consistency is not
|
||||||
|
* checked. For performance reasons, this is the default. For strong
|
||||||
|
* consistency, set this to zero seconds. Then every read will check for
|
||||||
|
* updates from other processes. As a compromise, you can set this to a
|
||||||
|
* non-zero value for eventual consistency. If more than that interval
|
||||||
|
* has passed since the last check, then the table will be checked for updates.
|
||||||
|
* Note: this consistency only applies to read operations. Write operations are
|
||||||
|
* always consistent.
|
||||||
|
*/
|
||||||
|
readConsistencyInterval?: number
|
||||||
}
|
}
|
||||||
/** Write mode for writing a table. */
|
/** Write mode for writing a table. */
|
||||||
export const enum WriteMode {
|
export const enum WriteMode {
|
||||||
@@ -30,7 +74,7 @@ export interface WriteOptions {
|
|||||||
export function connect(options: ConnectionOptions): Promise<Connection>
|
export function connect(options: ConnectionOptions): Promise<Connection>
|
||||||
export class Connection {
|
export class Connection {
|
||||||
/** Create a new Connection instance from the given URI. */
|
/** Create a new Connection instance from the given URI. */
|
||||||
static new(uri: string): Promise<Connection>
|
static new(options: ConnectionOptions): Promise<Connection>
|
||||||
/** List all tables in the dataset. */
|
/** List all tables in the dataset. */
|
||||||
tableNames(): Promise<Array<string>>
|
tableNames(): Promise<Array<string>>
|
||||||
/**
|
/**
|
||||||
@@ -71,10 +115,13 @@ export class Query {
|
|||||||
}
|
}
|
||||||
export class Table {
|
export class Table {
|
||||||
/** Return Schema as empty Arrow IPC file. */
|
/** Return Schema as empty Arrow IPC file. */
|
||||||
schema(): Buffer
|
schema(): Promise<Buffer>
|
||||||
add(buf: Buffer): Promise<void>
|
add(buf: Buffer): Promise<void>
|
||||||
countRows(): Promise<bigint>
|
countRows(filter?: string | undefined | null): Promise<bigint>
|
||||||
delete(predicate: string): Promise<void>
|
delete(predicate: string): Promise<void>
|
||||||
createIndex(): IndexBuilder
|
createIndex(): IndexBuilder
|
||||||
query(): Query
|
query(): Query
|
||||||
|
addColumns(transforms: Array<AddColumnsSql>): Promise<void>
|
||||||
|
alterColumns(alterations: Array<ColumnAlteration>): Promise<void>
|
||||||
|
dropColumns(columns: Array<string>): Promise<void>
|
||||||
}
|
}
|
||||||
@@ -32,24 +32,24 @@ switch (platform) {
|
|||||||
case 'android':
|
case 'android':
|
||||||
switch (arch) {
|
switch (arch) {
|
||||||
case 'arm64':
|
case 'arm64':
|
||||||
localFileExisted = existsSync(join(__dirname, 'vectordb-nodejs.android-arm64.node'))
|
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm64.node'))
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.android-arm64.node')
|
nativeBinding = require('./lancedb-nodejs.android-arm64.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-android-arm64')
|
nativeBinding = require('lancedb-android-arm64')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case 'arm':
|
case 'arm':
|
||||||
localFileExisted = existsSync(join(__dirname, 'vectordb-nodejs.android-arm-eabi.node'))
|
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.android-arm-eabi.node'))
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.android-arm-eabi.node')
|
nativeBinding = require('./lancedb-nodejs.android-arm-eabi.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-android-arm-eabi')
|
nativeBinding = require('lancedb-android-arm-eabi')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -63,13 +63,13 @@ switch (platform) {
|
|||||||
switch (arch) {
|
switch (arch) {
|
||||||
case 'x64':
|
case 'x64':
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.win32-x64-msvc.node')
|
join(__dirname, 'lancedb-nodejs.win32-x64-msvc.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.win32-x64-msvc.node')
|
nativeBinding = require('./lancedb-nodejs.win32-x64-msvc.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-win32-x64-msvc')
|
nativeBinding = require('lancedb-win32-x64-msvc')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -77,13 +77,13 @@ switch (platform) {
|
|||||||
break
|
break
|
||||||
case 'ia32':
|
case 'ia32':
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.win32-ia32-msvc.node')
|
join(__dirname, 'lancedb-nodejs.win32-ia32-msvc.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.win32-ia32-msvc.node')
|
nativeBinding = require('./lancedb-nodejs.win32-ia32-msvc.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-win32-ia32-msvc')
|
nativeBinding = require('lancedb-win32-ia32-msvc')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -91,13 +91,13 @@ switch (platform) {
|
|||||||
break
|
break
|
||||||
case 'arm64':
|
case 'arm64':
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.win32-arm64-msvc.node')
|
join(__dirname, 'lancedb-nodejs.win32-arm64-msvc.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.win32-arm64-msvc.node')
|
nativeBinding = require('./lancedb-nodejs.win32-arm64-msvc.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-win32-arm64-msvc')
|
nativeBinding = require('lancedb-win32-arm64-msvc')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -108,23 +108,23 @@ switch (platform) {
|
|||||||
}
|
}
|
||||||
break
|
break
|
||||||
case 'darwin':
|
case 'darwin':
|
||||||
localFileExisted = existsSync(join(__dirname, 'vectordb-nodejs.darwin-universal.node'))
|
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-universal.node'))
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.darwin-universal.node')
|
nativeBinding = require('./lancedb-nodejs.darwin-universal.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-darwin-universal')
|
nativeBinding = require('lancedb-darwin-universal')
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
} catch {}
|
} catch {}
|
||||||
switch (arch) {
|
switch (arch) {
|
||||||
case 'x64':
|
case 'x64':
|
||||||
localFileExisted = existsSync(join(__dirname, 'vectordb-nodejs.darwin-x64.node'))
|
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.darwin-x64.node'))
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.darwin-x64.node')
|
nativeBinding = require('./lancedb-nodejs.darwin-x64.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-darwin-x64')
|
nativeBinding = require('lancedb-darwin-x64')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -132,13 +132,13 @@ switch (platform) {
|
|||||||
break
|
break
|
||||||
case 'arm64':
|
case 'arm64':
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.darwin-arm64.node')
|
join(__dirname, 'lancedb-nodejs.darwin-arm64.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.darwin-arm64.node')
|
nativeBinding = require('./lancedb-nodejs.darwin-arm64.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-darwin-arm64')
|
nativeBinding = require('lancedb-darwin-arm64')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -152,12 +152,12 @@ switch (platform) {
|
|||||||
if (arch !== 'x64') {
|
if (arch !== 'x64') {
|
||||||
throw new Error(`Unsupported architecture on FreeBSD: ${arch}`)
|
throw new Error(`Unsupported architecture on FreeBSD: ${arch}`)
|
||||||
}
|
}
|
||||||
localFileExisted = existsSync(join(__dirname, 'vectordb-nodejs.freebsd-x64.node'))
|
localFileExisted = existsSync(join(__dirname, 'lancedb-nodejs.freebsd-x64.node'))
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.freebsd-x64.node')
|
nativeBinding = require('./lancedb-nodejs.freebsd-x64.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-freebsd-x64')
|
nativeBinding = require('lancedb-freebsd-x64')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -168,26 +168,26 @@ switch (platform) {
|
|||||||
case 'x64':
|
case 'x64':
|
||||||
if (isMusl()) {
|
if (isMusl()) {
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.linux-x64-musl.node')
|
join(__dirname, 'lancedb-nodejs.linux-x64-musl.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.linux-x64-musl.node')
|
nativeBinding = require('./lancedb-nodejs.linux-x64-musl.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-linux-x64-musl')
|
nativeBinding = require('lancedb-linux-x64-musl')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.linux-x64-gnu.node')
|
join(__dirname, 'lancedb-nodejs.linux-x64-gnu.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.linux-x64-gnu.node')
|
nativeBinding = require('./lancedb-nodejs.linux-x64-gnu.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-linux-x64-gnu')
|
nativeBinding = require('lancedb-linux-x64-gnu')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -197,26 +197,26 @@ switch (platform) {
|
|||||||
case 'arm64':
|
case 'arm64':
|
||||||
if (isMusl()) {
|
if (isMusl()) {
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.linux-arm64-musl.node')
|
join(__dirname, 'lancedb-nodejs.linux-arm64-musl.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.linux-arm64-musl.node')
|
nativeBinding = require('./lancedb-nodejs.linux-arm64-musl.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-linux-arm64-musl')
|
nativeBinding = require('lancedb-linux-arm64-musl')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.linux-arm64-gnu.node')
|
join(__dirname, 'lancedb-nodejs.linux-arm64-gnu.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.linux-arm64-gnu.node')
|
nativeBinding = require('./lancedb-nodejs.linux-arm64-gnu.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-linux-arm64-gnu')
|
nativeBinding = require('lancedb-linux-arm64-gnu')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -225,13 +225,13 @@ switch (platform) {
|
|||||||
break
|
break
|
||||||
case 'arm':
|
case 'arm':
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.linux-arm-gnueabihf.node')
|
join(__dirname, 'lancedb-nodejs.linux-arm-gnueabihf.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.linux-arm-gnueabihf.node')
|
nativeBinding = require('./lancedb-nodejs.linux-arm-gnueabihf.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-linux-arm-gnueabihf')
|
nativeBinding = require('lancedb-linux-arm-gnueabihf')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -240,26 +240,26 @@ switch (platform) {
|
|||||||
case 'riscv64':
|
case 'riscv64':
|
||||||
if (isMusl()) {
|
if (isMusl()) {
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.linux-riscv64-musl.node')
|
join(__dirname, 'lancedb-nodejs.linux-riscv64-musl.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.linux-riscv64-musl.node')
|
nativeBinding = require('./lancedb-nodejs.linux-riscv64-musl.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-linux-riscv64-musl')
|
nativeBinding = require('lancedb-linux-riscv64-musl')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.linux-riscv64-gnu.node')
|
join(__dirname, 'lancedb-nodejs.linux-riscv64-gnu.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.linux-riscv64-gnu.node')
|
nativeBinding = require('./lancedb-nodejs.linux-riscv64-gnu.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-linux-riscv64-gnu')
|
nativeBinding = require('lancedb-linux-riscv64-gnu')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -268,13 +268,13 @@ switch (platform) {
|
|||||||
break
|
break
|
||||||
case 's390x':
|
case 's390x':
|
||||||
localFileExisted = existsSync(
|
localFileExisted = existsSync(
|
||||||
join(__dirname, 'vectordb-nodejs.linux-s390x-gnu.node')
|
join(__dirname, 'lancedb-nodejs.linux-s390x-gnu.node')
|
||||||
)
|
)
|
||||||
try {
|
try {
|
||||||
if (localFileExisted) {
|
if (localFileExisted) {
|
||||||
nativeBinding = require('./vectordb-nodejs.linux-s390x-gnu.node')
|
nativeBinding = require('./lancedb-nodejs.linux-s390x-gnu.node')
|
||||||
} else {
|
} else {
|
||||||
nativeBinding = require('vectordb-linux-s390x-gnu')
|
nativeBinding = require('lancedb-linux-s390x-gnu')
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
loadError = e
|
loadError = e
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { Schema, tableFromIPC } from "apache-arrow";
|
import { Schema, tableFromIPC } from "apache-arrow";
|
||||||
import { Table as _NativeTable } from "./native";
|
import { AddColumnsSql, ColumnAlteration, Table as _NativeTable } from "./native";
|
||||||
import { toBuffer, Data } from "./arrow";
|
import { toBuffer, Data } from "./arrow";
|
||||||
import { Query } from "./query";
|
import { Query } from "./query";
|
||||||
import { IndexBuilder } from "./indexer";
|
import { IndexBuilder } from "./indexer";
|
||||||
@@ -32,8 +32,8 @@ export class Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Get the schema of the table. */
|
/** Get the schema of the table. */
|
||||||
get schema(): Schema {
|
async schema(): Promise<Schema> {
|
||||||
const schemaBuf = this.inner.schema();
|
const schemaBuf = await this.inner.schema();
|
||||||
const tbl = tableFromIPC(schemaBuf);
|
const tbl = tableFromIPC(schemaBuf);
|
||||||
return tbl.schema;
|
return tbl.schema;
|
||||||
}
|
}
|
||||||
@@ -50,8 +50,8 @@ export class Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Count the total number of rows in the dataset. */
|
/** Count the total number of rows in the dataset. */
|
||||||
async countRows(): Promise<bigint> {
|
async countRows(filter?: string): Promise<bigint> {
|
||||||
return await this.inner.countRows();
|
return await this.inner.countRows(filter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Delete the rows that satisfy the predicate. */
|
/** Delete the rows that satisfy the predicate. */
|
||||||
@@ -150,4 +150,42 @@ export class Table {
|
|||||||
}
|
}
|
||||||
return q;
|
return q;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Support BatchUDF
|
||||||
|
/**
|
||||||
|
* Add new columns with defined values.
|
||||||
|
*
|
||||||
|
* @param newColumnTransforms pairs of column names and the SQL expression to use
|
||||||
|
* to calculate the value of the new column. These
|
||||||
|
* expressions will be evaluated for each row in the
|
||||||
|
* table, and can reference existing columns in the table.
|
||||||
|
*/
|
||||||
|
async addColumns(newColumnTransforms: AddColumnsSql[]): Promise<void> {
|
||||||
|
await this.inner.addColumns(newColumnTransforms);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Alter the name or nullability of columns.
|
||||||
|
*
|
||||||
|
* @param columnAlterations One or more alterations to apply to columns.
|
||||||
|
*/
|
||||||
|
async alterColumns(columnAlterations: ColumnAlteration[]): Promise<void> {
|
||||||
|
await this.inner.alterColumns(columnAlterations);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Drop one or more columns from the dataset
|
||||||
|
*
|
||||||
|
* This is a metadata-only operation and does not remove the data from the
|
||||||
|
* underlying storage. In order to remove the data, you must subsequently
|
||||||
|
* call ``compact_files`` to rewrite the data without the removed columns and
|
||||||
|
* then call ``cleanup_files`` to remove the old files.
|
||||||
|
*
|
||||||
|
* @param columnNames The names of the columns to drop. These can be nested
|
||||||
|
* column references (e.g. "a.b.c") or top-level column
|
||||||
|
* names (e.g. "a").
|
||||||
|
*/
|
||||||
|
async dropColumns(columnNames: string[]): Promise<void> {
|
||||||
|
await this.inner.dropColumns(columnNames);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
# `vectordb-darwin-arm64`
|
# `lancedb-darwin-arm64`
|
||||||
|
|
||||||
This is the **aarch64-apple-darwin** binary for `vectordb`
|
This is the **aarch64-apple-darwin** binary for `lancedb`
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb-darwin-arm64",
|
"name": "lancedb-darwin-arm64",
|
||||||
"version": "0.4.3",
|
"version": "0.4.3",
|
||||||
"os": [
|
"os": [
|
||||||
"darwin"
|
"darwin"
|
||||||
@@ -7,9 +7,9 @@
|
|||||||
"cpu": [
|
"cpu": [
|
||||||
"arm64"
|
"arm64"
|
||||||
],
|
],
|
||||||
"main": "vectordb.darwin-arm64.node",
|
"main": "lancedb.darwin-arm64.node",
|
||||||
"files": [
|
"files": [
|
||||||
"vectordb.darwin-arm64.node"
|
"lancedb.darwin-arm64.node"
|
||||||
],
|
],
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
# `vectordb-darwin-x64`
|
# `lancedb-darwin-x64`
|
||||||
|
|
||||||
This is the **x86_64-apple-darwin** binary for `vectordb`
|
This is the **x86_64-apple-darwin** binary for `lancedb`
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb-darwin-x64",
|
"name": "lancedb-darwin-x64",
|
||||||
"version": "0.4.3",
|
"version": "0.4.3",
|
||||||
"os": [
|
"os": [
|
||||||
"darwin"
|
"darwin"
|
||||||
@@ -7,9 +7,9 @@
|
|||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
"main": "vectordb.darwin-x64.node",
|
"main": "lancedb.darwin-x64.node",
|
||||||
"files": [
|
"files": [
|
||||||
"vectordb.darwin-x64.node"
|
"lancedb.darwin-x64.node"
|
||||||
],
|
],
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
# `vectordb-linux-arm64-gnu`
|
# `lancedb-linux-arm64-gnu`
|
||||||
|
|
||||||
This is the **aarch64-unknown-linux-gnu** binary for `vectordb`
|
This is the **aarch64-unknown-linux-gnu** binary for `lancedb`
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb-linux-arm64-gnu",
|
"name": "lancedb-linux-arm64-gnu",
|
||||||
"version": "0.4.3",
|
"version": "0.4.3",
|
||||||
"os": [
|
"os": [
|
||||||
"linux"
|
"linux"
|
||||||
@@ -7,9 +7,9 @@
|
|||||||
"cpu": [
|
"cpu": [
|
||||||
"arm64"
|
"arm64"
|
||||||
],
|
],
|
||||||
"main": "vectordb.linux-arm64-gnu.node",
|
"main": "lancedb.linux-arm64-gnu.node",
|
||||||
"files": [
|
"files": [
|
||||||
"vectordb.linux-arm64-gnu.node"
|
"lancedb.linux-arm64-gnu.node"
|
||||||
],
|
],
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
# `vectordb-linux-x64-gnu`
|
# `lancedb-linux-x64-gnu`
|
||||||
|
|
||||||
This is the **x86_64-unknown-linux-gnu** binary for `vectordb`
|
This is the **x86_64-unknown-linux-gnu** binary for `lancedb`
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb-linux-x64-gnu",
|
"name": "lancedb-linux-x64-gnu",
|
||||||
"version": "0.4.3",
|
"version": "0.4.3",
|
||||||
"os": [
|
"os": [
|
||||||
"linux"
|
"linux"
|
||||||
@@ -7,9 +7,9 @@
|
|||||||
"cpu": [
|
"cpu": [
|
||||||
"x64"
|
"x64"
|
||||||
],
|
],
|
||||||
"main": "vectordb.linux-x64-gnu.node",
|
"main": "lancedb.linux-x64-gnu.node",
|
||||||
"files": [
|
"files": [
|
||||||
"vectordb.linux-x64-gnu.node"
|
"lancedb.linux-x64-gnu.node"
|
||||||
],
|
],
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"engines": {
|
"engines": {
|
||||||
|
|||||||
1087
nodejs/package-lock.json
generated
1087
nodejs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,10 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "lancedb",
|
||||||
"version": "0.4.3",
|
"version": "0.4.3",
|
||||||
"main": "./dist/index.js",
|
"main": "./dist/index.js",
|
||||||
"types": "./dist/index.d.ts",
|
"types": "./dist/index.d.ts",
|
||||||
"napi": {
|
"napi": {
|
||||||
"name": "vectordb-nodejs",
|
"name": "lancedb-nodejs",
|
||||||
"triples": {
|
"triples": {
|
||||||
"defaults": false,
|
"defaults": false,
|
||||||
"additional": [
|
"additional": [
|
||||||
@@ -18,7 +18,7 @@
|
|||||||
"license": "Apache 2.0",
|
"license": "Apache 2.0",
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@napi-rs/cli": "^2.18.0",
|
"@napi-rs/cli": "^2.18.0",
|
||||||
"@types/jest": "^29.5.11",
|
"@types/jest": "^29.1.2",
|
||||||
"@typescript-eslint/eslint-plugin": "^6.19.0",
|
"@typescript-eslint/eslint-plugin": "^6.19.0",
|
||||||
"@typescript-eslint/parser": "^6.19.0",
|
"@typescript-eslint/parser": "^6.19.0",
|
||||||
"eslint": "^8.56.0",
|
"eslint": "^8.56.0",
|
||||||
@@ -45,23 +45,24 @@
|
|||||||
],
|
],
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"artifacts": "napi artifacts",
|
"artifacts": "napi artifacts",
|
||||||
"build:native": "napi build --platform --release --js vectordb/native.js --dts vectordb/native.d.ts dist/",
|
"build:native": "napi build --platform --release --js lancedb/native.js --dts lancedb/native.d.ts dist/",
|
||||||
"build:debug": "napi build --platform --dts ../vectordb/native.d.ts --js ../vectordb/native.js dist/",
|
"build:debug": "napi build --platform --dts ../lancedb/native.d.ts --js ../lancedb/native.js dist/",
|
||||||
"build": "npm run build:debug && tsc -b",
|
"build": "npm run build:debug && tsc -b",
|
||||||
"docs": "typedoc --plugin typedoc-plugin-markdown vectordb/index.ts",
|
"docs": "typedoc --plugin typedoc-plugin-markdown lancedb/index.ts",
|
||||||
"lint": "eslint vectordb --ext .js,.ts",
|
"lint": "eslint lancedb --ext .js,.ts",
|
||||||
"prepublishOnly": "napi prepublish -t npm",
|
"prepublishOnly": "napi prepublish -t npm",
|
||||||
"test": "npm run build && jest",
|
"//": "maxWorkers=1 is workaround for bigint issue in jest: https://github.com/jestjs/jest/issues/11617#issuecomment-1068732414",
|
||||||
|
"test": "npm run build && jest --maxWorkers=1",
|
||||||
"universal": "napi universal",
|
"universal": "napi universal",
|
||||||
"version": "napi version"
|
"version": "napi version"
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"vectordb-darwin-arm64": "0.4.3",
|
"lancedb-darwin-arm64": "0.4.3",
|
||||||
"vectordb-darwin-x64": "0.4.3",
|
"lancedb-darwin-x64": "0.4.3",
|
||||||
"vectordb-linux-arm64-gnu": "0.4.3",
|
"lancedb-linux-arm64-gnu": "0.4.3",
|
||||||
"vectordb-linux-x64-gnu": "0.4.3"
|
"lancedb-linux-x64-gnu": "0.4.3"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"peerDependencies": {
|
||||||
"apache-arrow": "^15.0.0"
|
"apache-arrow": "^15.0.0"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,29 +12,40 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use napi::bindgen_prelude::*;
|
use napi::bindgen_prelude::*;
|
||||||
use napi_derive::*;
|
use napi_derive::*;
|
||||||
|
|
||||||
use crate::table::Table;
|
use crate::table::Table;
|
||||||
use vectordb::connection::{Connection as LanceDBConnection, Database};
|
use crate::ConnectionOptions;
|
||||||
use vectordb::ipc::ipc_file_to_batches;
|
use lancedb::connection::{ConnectBuilder, Connection as LanceDBConnection};
|
||||||
|
use lancedb::ipc::ipc_file_to_batches;
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub struct Connection {
|
pub struct Connection {
|
||||||
conn: Arc<dyn LanceDBConnection>,
|
conn: LanceDBConnection,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
impl Connection {
|
impl Connection {
|
||||||
/// Create a new Connection instance from the given URI.
|
/// Create a new Connection instance from the given URI.
|
||||||
#[napi(factory)]
|
#[napi(factory)]
|
||||||
pub async fn new(uri: String) -> napi::Result<Self> {
|
pub async fn new(options: ConnectionOptions) -> napi::Result<Self> {
|
||||||
|
let mut builder = ConnectBuilder::new(&options.uri);
|
||||||
|
if let Some(api_key) = options.api_key {
|
||||||
|
builder = builder.api_key(&api_key);
|
||||||
|
}
|
||||||
|
if let Some(host_override) = options.host_override {
|
||||||
|
builder = builder.host_override(&host_override);
|
||||||
|
}
|
||||||
|
if let Some(interval) = options.read_consistency_interval {
|
||||||
|
builder =
|
||||||
|
builder.read_consistency_interval(std::time::Duration::from_secs_f64(interval));
|
||||||
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conn: Arc::new(Database::connect(&uri).await.map_err(|e| {
|
conn: builder
|
||||||
napi::Error::from_reason(format!("Failed to connect to database: {}", e))
|
.execute()
|
||||||
})?),
|
.await
|
||||||
|
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +70,8 @@ impl Connection {
|
|||||||
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
||||||
let tbl = self
|
let tbl = self
|
||||||
.conn
|
.conn
|
||||||
.create_table(&name, Box::new(batches), None)
|
.create_table(&name, Box::new(batches))
|
||||||
|
.execute()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
|
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
|
||||||
Ok(Table::new(tbl))
|
Ok(Table::new(tbl))
|
||||||
@@ -70,6 +82,7 @@ impl Connection {
|
|||||||
let tbl = self
|
let tbl = self
|
||||||
.conn
|
.conn
|
||||||
.open_table(&name)
|
.open_table(&name)
|
||||||
|
.execute()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
|
.map_err(|e| napi::Error::from_reason(format!("{}", e)))?;
|
||||||
Ok(Table::new(tbl))
|
Ok(Table::new(tbl))
|
||||||
|
|||||||
@@ -40,12 +40,12 @@ impl From<MetricType> for LanceMetricType {
|
|||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub struct IndexBuilder {
|
pub struct IndexBuilder {
|
||||||
inner: vectordb::index::IndexBuilder,
|
inner: lancedb::index::IndexBuilder,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
impl IndexBuilder {
|
impl IndexBuilder {
|
||||||
pub fn new(tbl: &dyn vectordb::Table) -> Self {
|
pub fn new(tbl: &dyn lancedb::Table) -> Self {
|
||||||
let inner = tbl.create_index(&[]);
|
let inner = tbl.create_index(&[]);
|
||||||
Self { inner }
|
Self { inner }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,9 +14,9 @@
|
|||||||
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use lance::io::RecordBatchStream;
|
use lance::io::RecordBatchStream;
|
||||||
|
use lancedb::ipc::batches_to_ipc_file;
|
||||||
use napi::bindgen_prelude::*;
|
use napi::bindgen_prelude::*;
|
||||||
use napi_derive::napi;
|
use napi_derive::napi;
|
||||||
use vectordb::ipc::batches_to_ipc_file;
|
|
||||||
|
|
||||||
/** Typescript-style Async Iterator over RecordBatches */
|
/** Typescript-style Async Iterator over RecordBatches */
|
||||||
#[napi]
|
#[napi]
|
||||||
|
|||||||
@@ -22,10 +22,21 @@ mod query;
|
|||||||
mod table;
|
mod table;
|
||||||
|
|
||||||
#[napi(object)]
|
#[napi(object)]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct ConnectionOptions {
|
pub struct ConnectionOptions {
|
||||||
pub uri: String,
|
pub uri: String,
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
pub host_override: Option<String>,
|
pub host_override: Option<String>,
|
||||||
|
/// (For LanceDB OSS only): The interval, in seconds, at which to check for
|
||||||
|
/// updates to the table from other processes. If None, then consistency is not
|
||||||
|
/// checked. For performance reasons, this is the default. For strong
|
||||||
|
/// consistency, set this to zero seconds. Then every read will check for
|
||||||
|
/// updates from other processes. As a compromise, you can set this to a
|
||||||
|
/// non-zero value for eventual consistency. If more than that interval
|
||||||
|
/// has passed since the last check, then the table will be checked for updates.
|
||||||
|
/// Note: this consistency only applies to read operations. Write operations are
|
||||||
|
/// always consistent.
|
||||||
|
pub read_consistency_interval: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write mode for writing a table.
|
/// Write mode for writing a table.
|
||||||
@@ -44,5 +55,5 @@ pub struct WriteOptions {
|
|||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub async fn connect(options: ConnectionOptions) -> napi::Result<Connection> {
|
pub async fn connect(options: ConnectionOptions) -> napi::Result<Connection> {
|
||||||
Connection::new(options.uri.clone()).await
|
Connection::new(options).await
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,9 +12,9 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use lancedb::query::Query as LanceDBQuery;
|
||||||
use napi::bindgen_prelude::*;
|
use napi::bindgen_prelude::*;
|
||||||
use napi_derive::napi;
|
use napi_derive::napi;
|
||||||
use vectordb::query::Query as LanceDBQuery;
|
|
||||||
|
|
||||||
use crate::{iterator::RecordBatchIterator, table::Table};
|
use crate::{iterator::RecordBatchIterator, table::Table};
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,13 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use arrow_ipc::writer::FileWriter;
|
use arrow_ipc::writer::FileWriter;
|
||||||
|
use lance::dataset::ColumnAlteration as LanceColumnAlteration;
|
||||||
|
use lancedb::{
|
||||||
|
ipc::ipc_file_to_batches,
|
||||||
|
table::{AddDataOptions, TableRef},
|
||||||
|
};
|
||||||
use napi::bindgen_prelude::*;
|
use napi::bindgen_prelude::*;
|
||||||
use napi_derive::napi;
|
use napi_derive::napi;
|
||||||
use vectordb::{ipc::ipc_file_to_batches, table::TableRef};
|
|
||||||
|
|
||||||
use crate::index::IndexBuilder;
|
use crate::index::IndexBuilder;
|
||||||
use crate::query::Query;
|
use crate::query::Query;
|
||||||
@@ -33,8 +37,12 @@ impl Table {
|
|||||||
|
|
||||||
/// Return Schema as empty Arrow IPC file.
|
/// Return Schema as empty Arrow IPC file.
|
||||||
#[napi]
|
#[napi]
|
||||||
pub fn schema(&self) -> napi::Result<Buffer> {
|
pub async fn schema(&self) -> napi::Result<Buffer> {
|
||||||
let mut writer = FileWriter::try_new(vec![], &self.table.schema())
|
let schema =
|
||||||
|
self.table.schema().await.map_err(|e| {
|
||||||
|
napi::Error::from_reason(format!("Failed to create IPC file: {}", e))
|
||||||
|
})?;
|
||||||
|
let mut writer = FileWriter::try_new(vec![], &schema)
|
||||||
.map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
|
.map_err(|e| napi::Error::from_reason(format!("Failed to create IPC file: {}", e)))?;
|
||||||
writer
|
writer
|
||||||
.finish()
|
.finish()
|
||||||
@@ -48,17 +56,20 @@ impl Table {
|
|||||||
pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
|
pub async fn add(&self, buf: Buffer) -> napi::Result<()> {
|
||||||
let batches = ipc_file_to_batches(buf.to_vec())
|
let batches = ipc_file_to_batches(buf.to_vec())
|
||||||
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
||||||
self.table.add(Box::new(batches), None).await.map_err(|e| {
|
self.table
|
||||||
napi::Error::from_reason(format!(
|
.add(Box::new(batches), AddDataOptions::default())
|
||||||
"Failed to add batches to table {}: {}",
|
.await
|
||||||
self.table, e
|
.map_err(|e| {
|
||||||
))
|
napi::Error::from_reason(format!(
|
||||||
})
|
"Failed to add batches to table {}: {}",
|
||||||
|
self.table, e
|
||||||
|
))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi]
|
||||||
pub async fn count_rows(&self) -> napi::Result<usize> {
|
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
|
||||||
self.table.count_rows().await.map_err(|e| {
|
self.table.count_rows(filter).await.map_err(|e| {
|
||||||
napi::Error::from_reason(format!(
|
napi::Error::from_reason(format!(
|
||||||
"Failed to count rows in table {}: {}",
|
"Failed to count rows in table {}: {}",
|
||||||
self.table, e
|
self.table, e
|
||||||
@@ -85,4 +96,106 @@ impl Table {
|
|||||||
pub fn query(&self) -> Query {
|
pub fn query(&self) -> Query {
|
||||||
Query::new(self)
|
Query::new(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn add_columns(&self, transforms: Vec<AddColumnsSql>) -> napi::Result<()> {
|
||||||
|
let transforms = transforms
|
||||||
|
.into_iter()
|
||||||
|
.map(|sql| (sql.name, sql.value_sql))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let transforms = lance::dataset::NewColumnTransform::SqlExpressions(transforms);
|
||||||
|
self.table
|
||||||
|
.add_columns(transforms, None)
|
||||||
|
.await
|
||||||
|
.map_err(|err| {
|
||||||
|
napi::Error::from_reason(format!(
|
||||||
|
"Failed to add columns to table {}: {}",
|
||||||
|
self.table, err
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn alter_columns(&self, alterations: Vec<ColumnAlteration>) -> napi::Result<()> {
|
||||||
|
for alteration in &alterations {
|
||||||
|
if alteration.rename.is_none() && alteration.nullable.is_none() {
|
||||||
|
return Err(napi::Error::from_reason(
|
||||||
|
"Alteration must have a 'rename' or 'nullable' field.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let alterations = alterations
|
||||||
|
.into_iter()
|
||||||
|
.map(LanceColumnAlteration::from)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
self.table
|
||||||
|
.alter_columns(&alterations)
|
||||||
|
.await
|
||||||
|
.map_err(|err| {
|
||||||
|
napi::Error::from_reason(format!(
|
||||||
|
"Failed to alter columns in table {}: {}",
|
||||||
|
self.table, err
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[napi]
|
||||||
|
pub async fn drop_columns(&self, columns: Vec<String>) -> napi::Result<()> {
|
||||||
|
let col_refs = columns.iter().map(String::as_str).collect::<Vec<_>>();
|
||||||
|
self.table.drop_columns(&col_refs).await.map_err(|err| {
|
||||||
|
napi::Error::from_reason(format!(
|
||||||
|
"Failed to drop columns from table {}: {}",
|
||||||
|
self.table, err
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A definition of a column alteration. The alteration changes the column at
|
||||||
|
/// `path` to have the new name `name`, to be nullable if `nullable` is true,
|
||||||
|
/// and to have the data type `data_type`. At least one of `rename` or `nullable`
|
||||||
|
/// must be provided.
|
||||||
|
#[napi(object)]
|
||||||
|
pub struct ColumnAlteration {
|
||||||
|
/// The path to the column to alter. This is a dot-separated path to the column.
|
||||||
|
/// If it is a top-level column then it is just the name of the column. If it is
|
||||||
|
/// a nested column then it is the path to the column, e.g. "a.b.c" for a column
|
||||||
|
/// `c` nested inside a column `b` nested inside a column `a`.
|
||||||
|
pub path: String,
|
||||||
|
/// The new name of the column. If not provided then the name will not be changed.
|
||||||
|
/// This must be distinct from the names of all other columns in the table.
|
||||||
|
pub rename: Option<String>,
|
||||||
|
/// Set the new nullability. Note that a nullable column cannot be made non-nullable.
|
||||||
|
pub nullable: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ColumnAlteration> for LanceColumnAlteration {
|
||||||
|
fn from(js: ColumnAlteration) -> Self {
|
||||||
|
let ColumnAlteration {
|
||||||
|
path,
|
||||||
|
rename,
|
||||||
|
nullable,
|
||||||
|
} = js;
|
||||||
|
Self {
|
||||||
|
path,
|
||||||
|
rename,
|
||||||
|
nullable,
|
||||||
|
// TODO: wire up this field
|
||||||
|
data_type: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A definition of a new column to add to a table.
|
||||||
|
#[napi(object)]
|
||||||
|
pub struct AddColumnsSql {
|
||||||
|
/// The name of the new column.
|
||||||
|
pub name: String,
|
||||||
|
/// The values to populate the new column with, as a SQL expression.
|
||||||
|
/// The expression can reference other columns in the table.
|
||||||
|
pub value_sql: String,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
{
|
{
|
||||||
"include": [
|
"include": [
|
||||||
"vectordb/*.ts",
|
"lancedb/*.ts",
|
||||||
"vectordb/**/*.ts",
|
"lancedb/**/*.ts",
|
||||||
"vectordb/*.js",
|
"lancedb/*.js",
|
||||||
],
|
],
|
||||||
"compilerOptions": {
|
"compilerOptions": {
|
||||||
"target": "es2022",
|
"target": "es2022",
|
||||||
@@ -18,7 +18,7 @@
|
|||||||
],
|
],
|
||||||
"typedocOptions": {
|
"typedocOptions": {
|
||||||
"entryPoints": [
|
"entryPoints": [
|
||||||
"vectordb/index.ts"
|
"lancedb/index.ts"
|
||||||
],
|
],
|
||||||
"out": "../docs/src/javascript/",
|
"out": "../docs/src/javascript/",
|
||||||
"visibilityFilters": {
|
"visibilityFilters": {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.5.2
|
current_version = 0.6.0
|
||||||
commit = True
|
commit = True
|
||||||
message = [python] Bump version: {current_version} → {new_version}
|
message = [python] Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
26
python/Cargo.toml
Normal file
26
python/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
[package]
|
||||||
|
name = "lancedb-python"
|
||||||
|
version = "0.4.10"
|
||||||
|
edition.workspace = true
|
||||||
|
description = "Python bindings for LanceDB"
|
||||||
|
license.workspace = true
|
||||||
|
repository.workspace = true
|
||||||
|
keywords.workspace = true
|
||||||
|
categories.workspace = true
|
||||||
|
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "_lancedb"
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
lancedb = { path = "../rust/lancedb" }
|
||||||
|
env_logger = "0.10"
|
||||||
|
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
|
||||||
|
pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
|
||||||
|
|
||||||
|
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||||
|
lzma-sys = { version = "*", features = ["static"] }
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
pyo3-build-config = { version = "0.20.3", features = ["extension-module", "abi3-py38"] }
|
||||||
@@ -20,10 +20,10 @@ results = table.search([0.1, 0.3]).limit(20).to_list()
|
|||||||
print(results)
|
print(results)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
Create a virtual environment and activate it:
|
LanceDb is based on the rust crate `lancedb` and is built with maturin. In order to build with maturin
|
||||||
|
you will either need a conda environment or a virtual environment (venv).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m venv venv
|
python -m venv venv
|
||||||
@@ -33,7 +33,15 @@ python -m venv venv
|
|||||||
Install the necessary packages:
|
Install the necessary packages:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m pip install .
|
python -m pip install .[tests,dev]
|
||||||
|
```
|
||||||
|
|
||||||
|
To build the python package you can use maturin:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# This will build the rust bindings and place them in the appropriate place
|
||||||
|
# in your venv or conda environment
|
||||||
|
matruin develop
|
||||||
```
|
```
|
||||||
|
|
||||||
To run the unit tests:
|
To run the unit tests:
|
||||||
@@ -42,6 +50,12 @@ To run the unit tests:
|
|||||||
pytest
|
pytest
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To run the doc tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest --doctest-modules python/lancedb
|
||||||
|
```
|
||||||
|
|
||||||
To run linter and automatically fix all errors:
|
To run linter and automatically fix all errors:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -55,31 +69,27 @@ If any packages are missing, install them with:
|
|||||||
pip install <PACKAGE_NAME>
|
pip install <PACKAGE_NAME>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
___
|
___
|
||||||
For **Windows** users, there may be errors when installing packages, so these commands may be helpful:
|
For **Windows** users, there may be errors when installing packages, so these commands may be helpful:
|
||||||
|
|
||||||
Activate the virtual environment:
|
Activate the virtual environment:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
. .\venv\Scripts\activate
|
. .\venv\Scripts\activate
|
||||||
```
|
```
|
||||||
|
|
||||||
You may need to run the installs separately:
|
You may need to run the installs separately:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -e .[tests]
|
pip install -e .[tests]
|
||||||
pip install -e .[dev]
|
pip install -e .[dev]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
`tantivy` requires `rust` to be installed, so install it with `conda`, as it doesn't support windows installation:
|
`tantivy` requires `rust` to be installed, so install it with `conda`, as it doesn't support windows installation:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install wheel
|
pip install wheel
|
||||||
pip install cargo
|
pip install cargo
|
||||||
conda install rust
|
conda install rust
|
||||||
pip install tantivy
|
pip install tantivy
|
||||||
```
|
```
|
||||||
|
|
||||||
To run the unit tests:
|
|
||||||
```bash
|
|
||||||
pytest
|
|
||||||
```
|
|
||||||
|
|||||||
3
python/build.rs
Normal file
3
python/build.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
fn main() {
|
||||||
|
pyo3_build_config::add_extension_module_link_args();
|
||||||
|
}
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
# Copyright 2023 LanceDB Developers
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import importlib.metadata
|
|
||||||
import os
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
__version__ = importlib.metadata.version("lancedb")
|
|
||||||
|
|
||||||
from .common import URI
|
|
||||||
from .db import DBConnection, LanceDBConnection
|
|
||||||
from .remote.db import RemoteDBConnection
|
|
||||||
from .schema import vector # noqa: F401
|
|
||||||
from .utils import sentry_log # noqa: F401
|
|
||||||
|
|
||||||
|
|
||||||
def connect(
|
|
||||||
uri: URI,
|
|
||||||
*,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
region: str = "us-east-1",
|
|
||||||
host_override: Optional[str] = None,
|
|
||||||
) -> DBConnection:
|
|
||||||
"""Connect to a LanceDB database.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
uri: str or Path
|
|
||||||
The uri of the database.
|
|
||||||
api_key: str, optional
|
|
||||||
If presented, connect to LanceDB cloud.
|
|
||||||
Otherwise, connect to a database on file system or cloud storage.
|
|
||||||
Can be set via environment variable `LANCEDB_API_KEY`.
|
|
||||||
region: str, default "us-east-1"
|
|
||||||
The region to use for LanceDB Cloud.
|
|
||||||
host_override: str, optional
|
|
||||||
The override url for LanceDB Cloud.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
|
|
||||||
For a local directory, provide a path for the database:
|
|
||||||
|
|
||||||
>>> import lancedb
|
|
||||||
>>> db = lancedb.connect("~/.lancedb")
|
|
||||||
|
|
||||||
For object storage, use a URI prefix:
|
|
||||||
|
|
||||||
>>> db = lancedb.connect("s3://my-bucket/lancedb")
|
|
||||||
|
|
||||||
Connect to LancdDB cloud:
|
|
||||||
|
|
||||||
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
conn : DBConnection
|
|
||||||
A connection to a LanceDB database.
|
|
||||||
"""
|
|
||||||
if isinstance(uri, str) and uri.startswith("db://"):
|
|
||||||
if api_key is None:
|
|
||||||
api_key = os.environ.get("LANCEDB_API_KEY")
|
|
||||||
if api_key is None:
|
|
||||||
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
|
||||||
return RemoteDBConnection(uri, api_key, region, host_override)
|
|
||||||
return LanceDBConnection(uri)
|
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.5.2"
|
version = "0.6.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"deprecation",
|
"deprecation",
|
||||||
"pylance==0.9.12",
|
"pylance==0.10.1",
|
||||||
"ratelimiter~=1.0",
|
"ratelimiter~=1.0",
|
||||||
"retry>=0.9.2",
|
"retry>=0.9.2",
|
||||||
"tqdm>=4.27.0",
|
"tqdm>=4.27.0",
|
||||||
@@ -14,7 +14,7 @@ dependencies = [
|
|||||||
"pyyaml>=6.0",
|
"pyyaml>=6.0",
|
||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
"requests>=2.31.0",
|
"requests>=2.31.0",
|
||||||
"overrides>=0.7"
|
"overrides>=0.7",
|
||||||
]
|
]
|
||||||
description = "lancedb"
|
description = "lancedb"
|
||||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||||
@@ -26,7 +26,7 @@ keywords = [
|
|||||||
"data-science",
|
"data-science",
|
||||||
"machine-learning",
|
"machine-learning",
|
||||||
"arrow",
|
"arrow",
|
||||||
"data-analytics"
|
"data-analytics",
|
||||||
]
|
]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
@@ -48,21 +48,53 @@ classifiers = [
|
|||||||
repository = "https://github.com/lancedb/lancedb"
|
repository = "https://github.com/lancedb/lancedb"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
|
tests = [
|
||||||
|
"aiohttp",
|
||||||
|
"pandas>=1.4",
|
||||||
|
"pytest",
|
||||||
|
"pytest-mock",
|
||||||
|
"pytest-asyncio",
|
||||||
|
"duckdb",
|
||||||
|
"pytz",
|
||||||
|
"polars>=0.19",
|
||||||
|
]
|
||||||
dev = ["ruff", "pre-commit"]
|
dev = ["ruff", "pre-commit"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = [
|
||||||
|
"mkdocs",
|
||||||
|
"mkdocs-jupyter",
|
||||||
|
"mkdocs-material",
|
||||||
|
"mkdocstrings[python]",
|
||||||
|
"mkdocs-ultralytics-plugin==0.0.44",
|
||||||
|
]
|
||||||
clip = ["torch", "pillow", "open-clip"]
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
embeddings = ["openai>=1.6.1", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "huggingface_hub",
|
embeddings = [
|
||||||
"InstructorEmbedding", "google.generativeai", "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57"]
|
"openai>=1.6.1",
|
||||||
|
"sentence-transformers",
|
||||||
|
"torch",
|
||||||
|
"pillow",
|
||||||
|
"open-clip-torch",
|
||||||
|
"cohere",
|
||||||
|
"huggingface_hub",
|
||||||
|
"InstructorEmbedding",
|
||||||
|
"google.generativeai",
|
||||||
|
"boto3>=1.28.57",
|
||||||
|
"awscli>=1.29.57",
|
||||||
|
"botocore>=1.31.57",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.maturin]
|
||||||
|
python-source = "python"
|
||||||
|
module-name = "lancedb._lancedb"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
lancedb = "lancedb.cli.cli:cli"
|
lancedb = "lancedb.cli.cli:cli"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools", "wheel"]
|
requires = ["maturin>=1.4"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "maturin"
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
|
[tool.ruff.lint]
|
||||||
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
@@ -70,5 +102,5 @@ addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
|
|||||||
|
|
||||||
markers = [
|
markers = [
|
||||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||||
"asyncio"
|
"asyncio",
|
||||||
]
|
]
|
||||||
|
|||||||
175
python/python/lancedb/__init__.py
Normal file
175
python/python/lancedb/__init__.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import importlib.metadata
|
||||||
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
__version__ = importlib.metadata.version("lancedb")
|
||||||
|
|
||||||
|
from ._lancedb import connect as lancedb_connect
|
||||||
|
from .common import URI, sanitize_uri
|
||||||
|
from .db import AsyncConnection, AsyncLanceDBConnection, DBConnection, LanceDBConnection
|
||||||
|
from .remote.db import RemoteDBConnection
|
||||||
|
from .schema import vector # noqa: F401
|
||||||
|
from .utils import sentry_log # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
def connect(
|
||||||
|
uri: URI,
|
||||||
|
*,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
region: str = "us-east-1",
|
||||||
|
host_override: Optional[str] = None,
|
||||||
|
read_consistency_interval: Optional[timedelta] = None,
|
||||||
|
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||||
|
) -> DBConnection:
|
||||||
|
"""Connect to a LanceDB database.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
uri: str or Path
|
||||||
|
The uri of the database.
|
||||||
|
api_key: str, optional
|
||||||
|
If presented, connect to LanceDB cloud.
|
||||||
|
Otherwise, connect to a database on file system or cloud storage.
|
||||||
|
Can be set via environment variable `LANCEDB_API_KEY`.
|
||||||
|
region: str, default "us-east-1"
|
||||||
|
The region to use for LanceDB Cloud.
|
||||||
|
host_override: str, optional
|
||||||
|
The override url for LanceDB Cloud.
|
||||||
|
read_consistency_interval: timedelta, default None
|
||||||
|
(For LanceDB OSS only)
|
||||||
|
The interval at which to check for updates to the table from other
|
||||||
|
processes. If None, then consistency is not checked. For performance
|
||||||
|
reasons, this is the default. For strong consistency, set this to
|
||||||
|
zero seconds. Then every read will check for updates from other
|
||||||
|
processes. As a compromise, you can set this to a non-zero timedelta
|
||||||
|
for eventual consistency. If more than that interval has passed since
|
||||||
|
the last check, then the table will be checked for updates. Note: this
|
||||||
|
consistency only applies to read operations. Write operations are
|
||||||
|
always consistent.
|
||||||
|
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||||
|
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||||
|
If an integer, then a ThreadPoolExecutor will be created with that
|
||||||
|
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||||
|
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||||
|
executor will be used for making requests. This is for LanceDB Cloud
|
||||||
|
only and is only used when making batch requests (i.e., passing in
|
||||||
|
multiple queries to the search method at once).
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
For a local directory, provide a path for the database:
|
||||||
|
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("~/.lancedb")
|
||||||
|
|
||||||
|
For object storage, use a URI prefix:
|
||||||
|
|
||||||
|
>>> db = lancedb.connect("s3://my-bucket/lancedb")
|
||||||
|
|
||||||
|
Connect to LancdDB cloud:
|
||||||
|
|
||||||
|
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
conn : DBConnection
|
||||||
|
A connection to a LanceDB database.
|
||||||
|
"""
|
||||||
|
if isinstance(uri, str) and uri.startswith("db://"):
|
||||||
|
if api_key is None:
|
||||||
|
api_key = os.environ.get("LANCEDB_API_KEY")
|
||||||
|
if api_key is None:
|
||||||
|
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
||||||
|
if isinstance(request_thread_pool, int):
|
||||||
|
request_thread_pool = ThreadPoolExecutor(request_thread_pool)
|
||||||
|
return RemoteDBConnection(
|
||||||
|
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
|
||||||
|
)
|
||||||
|
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||||
|
|
||||||
|
|
||||||
|
async def connect_async(
|
||||||
|
uri: URI,
|
||||||
|
*,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
region: str = "us-east-1",
|
||||||
|
host_override: Optional[str] = None,
|
||||||
|
read_consistency_interval: Optional[timedelta] = None,
|
||||||
|
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||||
|
) -> AsyncConnection:
|
||||||
|
"""Connect to a LanceDB database.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
uri: str or Path
|
||||||
|
The uri of the database.
|
||||||
|
api_key: str, optional
|
||||||
|
If present, connect to LanceDB cloud.
|
||||||
|
Otherwise, connect to a database on file system or cloud storage.
|
||||||
|
Can be set via environment variable `LANCEDB_API_KEY`.
|
||||||
|
region: str, default "us-east-1"
|
||||||
|
The region to use for LanceDB Cloud.
|
||||||
|
host_override: str, optional
|
||||||
|
The override url for LanceDB Cloud.
|
||||||
|
read_consistency_interval: timedelta, default None
|
||||||
|
(For LanceDB OSS only)
|
||||||
|
The interval at which to check for updates to the table from other
|
||||||
|
processes. If None, then consistency is not checked. For performance
|
||||||
|
reasons, this is the default. For strong consistency, set this to
|
||||||
|
zero seconds. Then every read will check for updates from other
|
||||||
|
processes. As a compromise, you can set this to a non-zero timedelta
|
||||||
|
for eventual consistency. If more than that interval has passed since
|
||||||
|
the last check, then the table will be checked for updates. Note: this
|
||||||
|
consistency only applies to read operations. Write operations are
|
||||||
|
always consistent.
|
||||||
|
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||||
|
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||||
|
If an integer, then a ThreadPoolExecutor will be created with that
|
||||||
|
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||||
|
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||||
|
executor will be used for making requests. This is for LanceDB Cloud
|
||||||
|
only and is only used when making batch requests (i.e., passing in
|
||||||
|
multiple queries to the search method at once).
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
For a local directory, provide a path for the database:
|
||||||
|
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("~/.lancedb")
|
||||||
|
|
||||||
|
For object storage, use a URI prefix:
|
||||||
|
|
||||||
|
>>> db = lancedb.connect("s3://my-bucket/lancedb")
|
||||||
|
|
||||||
|
Connect to LancdDB cloud:
|
||||||
|
|
||||||
|
>>> db = lancedb.connect("db://my_database", api_key="ldb_...")
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
conn : DBConnection
|
||||||
|
A connection to a LanceDB database.
|
||||||
|
"""
|
||||||
|
return AsyncLanceDBConnection(
|
||||||
|
await lancedb_connect(
|
||||||
|
sanitize_uri(uri), api_key, region, host_override, read_consistency_interval
|
||||||
|
)
|
||||||
|
)
|
||||||
12
python/python/lancedb/_lancedb.pyi
Normal file
12
python/python/lancedb/_lancedb.pyi
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class Connection(object):
|
||||||
|
async def table_names(self) -> list[str]: ...
|
||||||
|
|
||||||
|
async def connect(
|
||||||
|
uri: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
region: Optional[str],
|
||||||
|
host_override: Optional[str],
|
||||||
|
read_consistency_interval: Optional[float],
|
||||||
|
) -> Connection: ...
|
||||||
@@ -16,9 +16,9 @@ from typing import Iterable, List, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
|
||||||
from .util import safe_import
|
from .util import safe_import_pandas
|
||||||
|
|
||||||
pd = safe_import("pandas")
|
pd = safe_import_pandas()
|
||||||
|
|
||||||
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
|
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
|
||||||
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
|
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
|
||||||
@@ -34,3 +34,7 @@ class Credential(str):
|
|||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "********"
|
return "********"
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_uri(uri: URI) -> str:
|
||||||
|
return str(uri)
|
||||||
@@ -16,9 +16,9 @@ import deprecation
|
|||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
from .exceptions import MissingColumnError, MissingValueError
|
from .exceptions import MissingColumnError, MissingValueError
|
||||||
from .util import safe_import
|
from .util import safe_import_pandas
|
||||||
|
|
||||||
pd = safe_import("pandas")
|
pd = safe_import_pandas()
|
||||||
|
|
||||||
|
|
||||||
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
||||||
@@ -26,6 +26,9 @@ from .table import LanceTable, Table
|
|||||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from ._lancedb import Connection as LanceDbConnection
|
||||||
from .common import DATA, URI
|
from .common import DATA, URI
|
||||||
from .embeddings import EmbeddingFunctionConfig
|
from .embeddings import EmbeddingFunctionConfig
|
||||||
from .pydantic import LanceModel
|
from .pydantic import LanceModel
|
||||||
@@ -38,14 +41,21 @@ class DBConnection(EnforceOverrides):
|
|||||||
def table_names(
|
def table_names(
|
||||||
self, page_token: Optional[str] = None, limit: int = 10
|
self, page_token: Optional[str] = None, limit: int = 10
|
||||||
) -> Iterable[str]:
|
) -> Iterable[str]:
|
||||||
"""List all table in this database
|
"""List all tables in this database, in sorted order
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
page_token: str, optional
|
page_token: str, optional
|
||||||
The token to use for pagination. If not present, start from the beginning.
|
The token to use for pagination. If not present, start from the beginning.
|
||||||
|
Typically, this token is last table name from the previous page.
|
||||||
|
Only supported by LanceDb Cloud.
|
||||||
limit: int, default 10
|
limit: int, default 10
|
||||||
The size of the page to return.
|
The size of the page to return.
|
||||||
|
Only supported by LanceDb Cloud.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Iterable of str
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -118,7 +128,7 @@ class DBConnection(EnforceOverrides):
|
|||||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||||
>>> db.create_table("my_table", data)
|
>>> db.create_table("my_table", data)
|
||||||
LanceTable(my_table)
|
LanceTable(connection=..., name="my_table")
|
||||||
>>> db["my_table"].head()
|
>>> db["my_table"].head()
|
||||||
pyarrow.Table
|
pyarrow.Table
|
||||||
vector: fixed_size_list<item: float>[2]
|
vector: fixed_size_list<item: float>[2]
|
||||||
@@ -139,7 +149,7 @@ class DBConnection(EnforceOverrides):
|
|||||||
... "long": [-122.7, -74.1]
|
... "long": [-122.7, -74.1]
|
||||||
... })
|
... })
|
||||||
>>> db.create_table("table2", data)
|
>>> db.create_table("table2", data)
|
||||||
LanceTable(table2)
|
LanceTable(connection=..., name="table2")
|
||||||
>>> db["table2"].head()
|
>>> db["table2"].head()
|
||||||
pyarrow.Table
|
pyarrow.Table
|
||||||
vector: fixed_size_list<item: float>[2]
|
vector: fixed_size_list<item: float>[2]
|
||||||
@@ -161,7 +171,7 @@ class DBConnection(EnforceOverrides):
|
|||||||
... pa.field("long", pa.float32())
|
... pa.field("long", pa.float32())
|
||||||
... ])
|
... ])
|
||||||
>>> db.create_table("table3", data, schema = custom_schema)
|
>>> db.create_table("table3", data, schema = custom_schema)
|
||||||
LanceTable(table3)
|
LanceTable(connection=..., name="table3")
|
||||||
>>> db["table3"].head()
|
>>> db["table3"].head()
|
||||||
pyarrow.Table
|
pyarrow.Table
|
||||||
vector: fixed_size_list<item: float>[2]
|
vector: fixed_size_list<item: float>[2]
|
||||||
@@ -195,7 +205,7 @@ class DBConnection(EnforceOverrides):
|
|||||||
... pa.field("price", pa.float32()),
|
... pa.field("price", pa.float32()),
|
||||||
... ])
|
... ])
|
||||||
>>> db.create_table("table4", make_batches(), schema=schema)
|
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||||
LanceTable(table4)
|
LanceTable(connection=..., name="table4")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -243,6 +253,16 @@ class LanceDBConnection(DBConnection):
|
|||||||
----------
|
----------
|
||||||
uri: str or Path
|
uri: str or Path
|
||||||
The root uri of the database.
|
The root uri of the database.
|
||||||
|
read_consistency_interval: timedelta, default None
|
||||||
|
The interval at which to check for updates to the table from other
|
||||||
|
processes. If None, then consistency is not checked. For performance
|
||||||
|
reasons, this is the default. For strong consistency, set this to
|
||||||
|
zero seconds. Then every read will check for updates from other
|
||||||
|
processes. As a compromise, you can set this to a non-zero timedelta
|
||||||
|
for eventual consistency. If more than that interval has passed since
|
||||||
|
the last check, then the table will be checked for updates. Note: this
|
||||||
|
consistency only applies to read operations. Write operations are
|
||||||
|
always consistent.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
@@ -250,22 +270,24 @@ class LanceDBConnection(DBConnection):
|
|||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
|
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
|
||||||
... {"vector": [0.5, 1.3], "b": 4}])
|
... {"vector": [0.5, 1.3], "b": 4}])
|
||||||
LanceTable(my_table)
|
LanceTable(connection=..., name="my_table")
|
||||||
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
|
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
|
||||||
LanceTable(another_table)
|
LanceTable(connection=..., name="another_table")
|
||||||
>>> sorted(db.table_names())
|
>>> sorted(db.table_names())
|
||||||
['another_table', 'my_table']
|
['another_table', 'my_table']
|
||||||
>>> len(db)
|
>>> len(db)
|
||||||
2
|
2
|
||||||
>>> db["my_table"]
|
>>> db["my_table"]
|
||||||
LanceTable(my_table)
|
LanceTable(connection=..., name="my_table")
|
||||||
>>> "my_table" in db
|
>>> "my_table" in db
|
||||||
True
|
True
|
||||||
>>> db.drop_table("my_table")
|
>>> db.drop_table("my_table")
|
||||||
>>> db.drop_table("another_table")
|
>>> db.drop_table("another_table")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, uri: URI):
|
def __init__(
|
||||||
|
self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
|
||||||
|
):
|
||||||
if not isinstance(uri, Path):
|
if not isinstance(uri, Path):
|
||||||
scheme = get_uri_scheme(uri)
|
scheme = get_uri_scheme(uri)
|
||||||
is_local = isinstance(uri, Path) or scheme == "file"
|
is_local = isinstance(uri, Path) or scheme == "file"
|
||||||
@@ -277,6 +299,14 @@ class LanceDBConnection(DBConnection):
|
|||||||
self._uri = str(uri)
|
self._uri = str(uri)
|
||||||
|
|
||||||
self._entered = False
|
self._entered = False
|
||||||
|
self.read_consistency_interval = read_consistency_interval
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
val = f"{self.__class__.__name__}({self._uri}"
|
||||||
|
if self.read_consistency_interval is not None:
|
||||||
|
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
|
||||||
|
val += ")"
|
||||||
|
return val
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def uri(self) -> str:
|
def uri(self) -> str:
|
||||||
@@ -390,3 +420,254 @@ class LanceDBConnection(DBConnection):
|
|||||||
def drop_database(self):
|
def drop_database(self):
|
||||||
filesystem, path = fs_from_uri(self.uri)
|
filesystem, path = fs_from_uri(self.uri)
|
||||||
filesystem.delete_dir(path)
|
filesystem.delete_dir(path)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncConnection(EnforceOverrides):
|
||||||
|
"""An active LanceDB connection interface."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def table_names(
|
||||||
|
self, *, page_token: Optional[str] = None, limit: int = 10
|
||||||
|
) -> Iterable[str]:
|
||||||
|
"""List all tables in this database, in sorted order
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
page_token: str, optional
|
||||||
|
The token to use for pagination. If not present, start from the beginning.
|
||||||
|
Typically, this token is last table name from the previous page.
|
||||||
|
Only supported by LanceDb Cloud.
|
||||||
|
limit: int, default 10
|
||||||
|
The size of the page to return.
|
||||||
|
Only supported by LanceDb Cloud.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Iterable of str
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def create_table(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
data: Optional[DATA] = None,
|
||||||
|
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||||
|
mode: str = "create",
|
||||||
|
exist_ok: bool = False,
|
||||||
|
on_bad_vectors: str = "error",
|
||||||
|
fill_value: float = 0.0,
|
||||||
|
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||||
|
) -> Table:
|
||||||
|
"""Create a [Table][lancedb.table.Table] in the database.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name: str
|
||||||
|
The name of the table.
|
||||||
|
data: The data to initialize the table, *optional*
|
||||||
|
User must provide at least one of `data` or `schema`.
|
||||||
|
Acceptable types are:
|
||||||
|
|
||||||
|
- dict or list-of-dict
|
||||||
|
|
||||||
|
- pandas.DataFrame
|
||||||
|
|
||||||
|
- pyarrow.Table or pyarrow.RecordBatch
|
||||||
|
schema: The schema of the table, *optional*
|
||||||
|
Acceptable types are:
|
||||||
|
|
||||||
|
- pyarrow.Schema
|
||||||
|
|
||||||
|
- [LanceModel][lancedb.pydantic.LanceModel]
|
||||||
|
mode: str; default "create"
|
||||||
|
The mode to use when creating the table.
|
||||||
|
Can be either "create" or "overwrite".
|
||||||
|
By default, if the table already exists, an exception is raised.
|
||||||
|
If you want to overwrite the table, use mode="overwrite".
|
||||||
|
exist_ok: bool, default False
|
||||||
|
If a table by the same name already exists, then raise an exception
|
||||||
|
if exist_ok=False. If exist_ok=True, then open the existing table;
|
||||||
|
it will not add the provided data but will validate against any
|
||||||
|
schema that's specified.
|
||||||
|
on_bad_vectors: str, default "error"
|
||||||
|
What to do if any of the vectors are not the same size or contains NaNs.
|
||||||
|
One of "error", "drop", "fill".
|
||||||
|
fill_value: float
|
||||||
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceTable
|
||||||
|
A reference to the newly created table.
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
|
||||||
|
The vector index won't be created by default.
|
||||||
|
To create the index, call the `create_index` method on the table.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
Can create with list of tuples or dictionaries:
|
||||||
|
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||||
|
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||||
|
>>> db.create_table("my_table", data)
|
||||||
|
LanceTable(connection=..., name="my_table")
|
||||||
|
>>> db["my_table"].head()
|
||||||
|
pyarrow.Table
|
||||||
|
vector: fixed_size_list<item: float>[2]
|
||||||
|
child 0, item: float
|
||||||
|
lat: double
|
||||||
|
long: double
|
||||||
|
----
|
||||||
|
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||||
|
lat: [[45.5,40.1]]
|
||||||
|
long: [[-122.7,-74.1]]
|
||||||
|
|
||||||
|
You can also pass a pandas DataFrame:
|
||||||
|
|
||||||
|
>>> import pandas as pd
|
||||||
|
>>> data = pd.DataFrame({
|
||||||
|
... "vector": [[1.1, 1.2], [0.2, 1.8]],
|
||||||
|
... "lat": [45.5, 40.1],
|
||||||
|
... "long": [-122.7, -74.1]
|
||||||
|
... })
|
||||||
|
>>> db.create_table("table2", data)
|
||||||
|
LanceTable(connection=..., name="table2")
|
||||||
|
>>> db["table2"].head()
|
||||||
|
pyarrow.Table
|
||||||
|
vector: fixed_size_list<item: float>[2]
|
||||||
|
child 0, item: float
|
||||||
|
lat: double
|
||||||
|
long: double
|
||||||
|
----
|
||||||
|
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||||
|
lat: [[45.5,40.1]]
|
||||||
|
long: [[-122.7,-74.1]]
|
||||||
|
|
||||||
|
Data is converted to Arrow before being written to disk. For maximum
|
||||||
|
control over how data is saved, either provide the PyArrow schema to
|
||||||
|
convert to or else provide a [PyArrow Table](pyarrow.Table) directly.
|
||||||
|
|
||||||
|
>>> custom_schema = pa.schema([
|
||||||
|
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||||
|
... pa.field("lat", pa.float32()),
|
||||||
|
... pa.field("long", pa.float32())
|
||||||
|
... ])
|
||||||
|
>>> db.create_table("table3", data, schema = custom_schema)
|
||||||
|
LanceTable(connection=..., name="table3")
|
||||||
|
>>> db["table3"].head()
|
||||||
|
pyarrow.Table
|
||||||
|
vector: fixed_size_list<item: float>[2]
|
||||||
|
child 0, item: float
|
||||||
|
lat: float
|
||||||
|
long: float
|
||||||
|
----
|
||||||
|
vector: [[[1.1,1.2],[0.2,1.8]]]
|
||||||
|
lat: [[45.5,40.1]]
|
||||||
|
long: [[-122.7,-74.1]]
|
||||||
|
|
||||||
|
|
||||||
|
It is also possible to create an table from `[Iterable[pa.RecordBatch]]`:
|
||||||
|
|
||||||
|
|
||||||
|
>>> import pyarrow as pa
|
||||||
|
>>> def make_batches():
|
||||||
|
... for i in range(5):
|
||||||
|
... yield pa.RecordBatch.from_arrays(
|
||||||
|
... [
|
||||||
|
... pa.array([[3.1, 4.1], [5.9, 26.5]],
|
||||||
|
... pa.list_(pa.float32(), 2)),
|
||||||
|
... pa.array(["foo", "bar"]),
|
||||||
|
... pa.array([10.0, 20.0]),
|
||||||
|
... ],
|
||||||
|
... ["vector", "item", "price"],
|
||||||
|
... )
|
||||||
|
>>> schema=pa.schema([
|
||||||
|
... pa.field("vector", pa.list_(pa.float32(), 2)),
|
||||||
|
... pa.field("item", pa.utf8()),
|
||||||
|
... pa.field("price", pa.float32()),
|
||||||
|
... ])
|
||||||
|
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||||
|
LanceTable(connection=..., name="table4")
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def open_table(self, name: str) -> Table:
|
||||||
|
"""Open a Lance Table in the database.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name: str
|
||||||
|
The name of the table.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A LanceTable object representing the table.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def drop_table(self, name: str):
|
||||||
|
"""Drop a table from the database.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name: str
|
||||||
|
The name of the table.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def drop_database(self):
|
||||||
|
"""
|
||||||
|
Drop database
|
||||||
|
This is the same thing as dropping all the tables
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncLanceDBConnection(AsyncConnection):
|
||||||
|
def __init__(self, connection: LanceDbConnection):
|
||||||
|
self._inner = connection
|
||||||
|
|
||||||
|
async def __repr__(self) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def table_names(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
page_token=None,
|
||||||
|
limit=None,
|
||||||
|
) -> Iterable[str]:
|
||||||
|
return await self._inner.table_names()
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def create_table(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
data: Optional[DATA] = None,
|
||||||
|
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||||
|
mode: str = "create",
|
||||||
|
exist_ok: bool = False,
|
||||||
|
on_bad_vectors: str = "error",
|
||||||
|
fill_value: float = 0.0,
|
||||||
|
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||||
|
) -> LanceTable:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def open_table(self, name: str) -> LanceTable:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def drop_table(self, name: str, ignore_missing: bool = False):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def drop_database(self):
|
||||||
|
raise NotImplementedError
|
||||||
@@ -10,7 +10,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
@@ -91,25 +90,6 @@ class EmbeddingFunction(BaseModel, ABC):
|
|||||||
texts = texts.combine_chunks().to_pylist()
|
texts = texts.combine_chunks().to_pylist()
|
||||||
return texts
|
return texts
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def safe_import(cls, module: str, mitigation=None):
|
|
||||||
"""
|
|
||||||
Import the specified module. If the module is not installed,
|
|
||||||
raise an ImportError with a helpful message.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
module : str
|
|
||||||
The name of the module to import
|
|
||||||
mitigation : Optional[str]
|
|
||||||
The package(s) to install to mitigate the error.
|
|
||||||
If not provided then the module name will be used.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return importlib.import_module(module)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(f"Please install {mitigation or module}")
|
|
||||||
|
|
||||||
def safe_model_dump(self):
|
def safe_model_dump(self):
|
||||||
from ..pydantic import PYDANTIC_VERSION
|
from ..pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
@@ -19,6 +19,7 @@ import numpy as np
|
|||||||
|
|
||||||
from lancedb.pydantic import PYDANTIC_VERSION
|
from lancedb.pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
from .base import TextEmbeddingFunction
|
from .base import TextEmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
from .utils import TEXT
|
from .utils import TEXT
|
||||||
@@ -183,8 +184,8 @@ class BedRockText(TextEmbeddingFunction):
|
|||||||
boto3.client
|
boto3.client
|
||||||
The boto3 client for Amazon Bedrock service
|
The boto3 client for Amazon Bedrock service
|
||||||
"""
|
"""
|
||||||
botocore = self.safe_import("botocore")
|
botocore = attempt_import_or_raise("botocore")
|
||||||
boto3 = self.safe_import("boto3")
|
boto3 = attempt_import_or_raise("boto3")
|
||||||
|
|
||||||
session_kwargs = {"region_name": self.region}
|
session_kwargs = {"region_name": self.region}
|
||||||
client_kwargs = {**session_kwargs}
|
client_kwargs = {**session_kwargs}
|
||||||
@@ -16,6 +16,7 @@ from typing import ClassVar, List, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
from .base import TextEmbeddingFunction
|
from .base import TextEmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
from .utils import api_key_not_found_help
|
from .utils import api_key_not_found_help
|
||||||
@@ -84,7 +85,7 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
|||||||
return [emb for emb in rs.embeddings]
|
return [emb for emb in rs.embeddings]
|
||||||
|
|
||||||
def _init_client(self):
|
def _init_client(self):
|
||||||
cohere = self.safe_import("cohere")
|
cohere = attempt_import_or_raise("cohere")
|
||||||
if CohereEmbeddingFunction.client is None:
|
if CohereEmbeddingFunction.client is None:
|
||||||
if os.environ.get("COHERE_API_KEY") is None:
|
if os.environ.get("COHERE_API_KEY") is None:
|
||||||
api_key_not_found_help("cohere")
|
api_key_not_found_help("cohere")
|
||||||
@@ -19,6 +19,7 @@ import numpy as np
|
|||||||
|
|
||||||
from lancedb.pydantic import PYDANTIC_VERSION
|
from lancedb.pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
from .base import TextEmbeddingFunction
|
from .base import TextEmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
from .utils import TEXT, api_key_not_found_help
|
from .utils import TEXT, api_key_not_found_help
|
||||||
@@ -134,7 +135,7 @@ class GeminiText(TextEmbeddingFunction):
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def client(self):
|
def client(self):
|
||||||
genai = self.safe_import("google.generativeai", "google.generativeai")
|
genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
|
||||||
|
|
||||||
if not os.environ.get("GOOGLE_API_KEY"):
|
if not os.environ.get("GOOGLE_API_KEY"):
|
||||||
api_key_not_found_help("google")
|
api_key_not_found_help("google")
|
||||||
@@ -14,6 +14,7 @@ from typing import List, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
from .base import TextEmbeddingFunction
|
from .base import TextEmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
from .utils import weak_lru
|
from .utils import weak_lru
|
||||||
@@ -122,7 +123,7 @@ class GteEmbeddings(TextEmbeddingFunction):
|
|||||||
|
|
||||||
return Model()
|
return Model()
|
||||||
else:
|
else:
|
||||||
sentence_transformers = self.safe_import(
|
sentence_transformers = attempt_import_or_raise(
|
||||||
"sentence_transformers", "sentence-transformers"
|
"sentence_transformers", "sentence-transformers"
|
||||||
)
|
)
|
||||||
return sentence_transformers.SentenceTransformer(
|
return sentence_transformers.SentenceTransformer(
|
||||||
172
python/python/lancedb/embeddings/imagebind.py
Normal file
172
python/python/lancedb/embeddings/imagebind.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
|
from .base import EmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import AUDIO, IMAGES, TEXT
|
||||||
|
|
||||||
|
|
||||||
|
@register("imagebind")
|
||||||
|
class ImageBindEmbeddings(EmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the ImageBind API
|
||||||
|
For generating multi-modal embeddings across
|
||||||
|
six different modalities: images, text, audio, depth, thermal, and IMU data
|
||||||
|
|
||||||
|
to download package, run :
|
||||||
|
`pip install imagebind@git+https://github.com/raghavdixit99/ImageBind`
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "imagebind_huge"
|
||||||
|
device: str = "cpu"
|
||||||
|
normalize: bool = False
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._ndims = 1024
|
||||||
|
self._audio_extensions = (".mp3", ".wav", ".flac", ".ogg", ".aac")
|
||||||
|
self._image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp")
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def embedding_model(self):
|
||||||
|
"""
|
||||||
|
Get the embedding model. This is cached so that the model is only loaded
|
||||||
|
once per process.
|
||||||
|
"""
|
||||||
|
return self.get_embedding_model()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _data(self):
|
||||||
|
"""
|
||||||
|
Get the data module from imagebind
|
||||||
|
"""
|
||||||
|
data = attempt_import_or_raise("imagebind.data", "imagebind")
|
||||||
|
return data
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def _ModalityType(self):
|
||||||
|
"""
|
||||||
|
Get the ModalityType from imagebind
|
||||||
|
"""
|
||||||
|
imagebind = attempt_import_or_raise("imagebind", "imagebind")
|
||||||
|
return imagebind.imagebind_model.ModalityType
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
|
def compute_query_embeddings(
|
||||||
|
self, query: Union[str], *args, **kwargs
|
||||||
|
) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : Union[str]
|
||||||
|
The query to embed. A query can be either text, image paths or audio paths.
|
||||||
|
"""
|
||||||
|
query = self.sanitize_input(query)
|
||||||
|
if query[0].endswith(self._audio_extensions):
|
||||||
|
return [self.generate_audio_embeddings(query)]
|
||||||
|
elif query[0].endswith(self._image_extensions):
|
||||||
|
return [self.generate_image_embeddings(query)]
|
||||||
|
else:
|
||||||
|
return [self.generate_text_embeddings(query)]
|
||||||
|
|
||||||
|
def generate_image_embeddings(self, image: IMAGES) -> np.ndarray:
|
||||||
|
torch = attempt_import_or_raise("torch")
|
||||||
|
inputs = {
|
||||||
|
self._ModalityType.VISION: self._data.load_and_transform_vision_data(
|
||||||
|
image, self.device
|
||||||
|
)
|
||||||
|
}
|
||||||
|
with torch.no_grad():
|
||||||
|
image_features = self.embedding_model(inputs)[self._ModalityType.VISION]
|
||||||
|
if self.normalize:
|
||||||
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
return image_features.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
def generate_audio_embeddings(self, audio: AUDIO) -> np.ndarray:
|
||||||
|
torch = attempt_import_or_raise("torch")
|
||||||
|
inputs = {
|
||||||
|
self._ModalityType.AUDIO: self._data.load_and_transform_audio_data(
|
||||||
|
audio, self.device
|
||||||
|
)
|
||||||
|
}
|
||||||
|
with torch.no_grad():
|
||||||
|
audio_features = self.embedding_model(inputs)[self._ModalityType.AUDIO]
|
||||||
|
if self.normalize:
|
||||||
|
audio_features /= audio_features.norm(dim=-1, keepdim=True)
|
||||||
|
return audio_features.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
def generate_text_embeddings(self, text: TEXT) -> np.ndarray:
|
||||||
|
torch = attempt_import_or_raise("torch")
|
||||||
|
inputs = {
|
||||||
|
self._ModalityType.TEXT: self._data.load_and_transform_text(
|
||||||
|
text, self.device
|
||||||
|
)
|
||||||
|
}
|
||||||
|
with torch.no_grad():
|
||||||
|
text_features = self.embedding_model(inputs)[self._ModalityType.TEXT]
|
||||||
|
if self.normalize:
|
||||||
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
return text_features.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
def compute_source_embeddings(
|
||||||
|
self, source: Union[IMAGES, AUDIO], *args, **kwargs
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given sourcefield column in the pydantic model.
|
||||||
|
"""
|
||||||
|
source = self.sanitize_input(source)
|
||||||
|
embeddings = []
|
||||||
|
if source[0].endswith(self._audio_extensions):
|
||||||
|
embeddings.extend(self.generate_audio_embeddings(source))
|
||||||
|
return embeddings
|
||||||
|
elif source[0].endswith(self._image_extensions):
|
||||||
|
embeddings.extend(self.generate_image_embeddings(source))
|
||||||
|
return embeddings
|
||||||
|
else:
|
||||||
|
embeddings.extend(self.generate_text_embeddings(source))
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def sanitize_input(
|
||||||
|
self, input: Union[IMAGES, AUDIO]
|
||||||
|
) -> Union[List[bytes], np.ndarray]:
|
||||||
|
"""
|
||||||
|
Sanitize the input to the embedding function.
|
||||||
|
"""
|
||||||
|
if isinstance(input, (str, bytes)):
|
||||||
|
input = [input]
|
||||||
|
elif isinstance(input, pa.Array):
|
||||||
|
input = input.to_pylist()
|
||||||
|
elif isinstance(input, pa.ChunkedArray):
|
||||||
|
input = input.combine_chunks().to_pylist()
|
||||||
|
return input
|
||||||
|
|
||||||
|
def get_embedding_model(self):
|
||||||
|
"""
|
||||||
|
fetches the imagebind embedding model
|
||||||
|
"""
|
||||||
|
imagebind = attempt_import_or_raise("imagebind", "imagebind")
|
||||||
|
model = imagebind.imagebind_model.imagebind_huge(pretrained=True)
|
||||||
|
model.eval()
|
||||||
|
model.to(self.device)
|
||||||
|
return model
|
||||||
@@ -14,6 +14,7 @@ from typing import List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
from .base import TextEmbeddingFunction
|
from .base import TextEmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
from .utils import TEXT, weak_lru
|
from .utils import TEXT, weak_lru
|
||||||
@@ -102,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
|||||||
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||||
|
|
||||||
source_instruction: str = "represent the document for retrieval"
|
source_instruction: str = "represent the document for retrieval"
|
||||||
query_instruction: str = (
|
query_instruction: (
|
||||||
"represent the document for retrieving the most similar documents"
|
str
|
||||||
)
|
) = "represent the document for retrieving the most similar documents"
|
||||||
|
|
||||||
@weak_lru(maxsize=1)
|
@weak_lru(maxsize=1)
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
@@ -131,10 +132,10 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
|||||||
|
|
||||||
@weak_lru(maxsize=1)
|
@weak_lru(maxsize=1)
|
||||||
def get_model(self):
|
def get_model(self):
|
||||||
instructor_embedding = self.safe_import(
|
instructor_embedding = attempt_import_or_raise(
|
||||||
"InstructorEmbedding", "InstructorEmbedding"
|
"InstructorEmbedding", "InstructorEmbedding"
|
||||||
)
|
)
|
||||||
torch = self.safe_import("torch", "torch")
|
torch = attempt_import_or_raise("torch", "torch")
|
||||||
|
|
||||||
model = instructor_embedding.INSTRUCTOR(self.name)
|
model = instructor_embedding.INSTRUCTOR(self.name)
|
||||||
if self.quantize:
|
if self.quantize:
|
||||||
@@ -21,6 +21,7 @@ import pyarrow as pa
|
|||||||
from pydantic import PrivateAttr
|
from pydantic import PrivateAttr
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
from .base import EmbeddingFunction
|
from .base import EmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
from .utils import IMAGES, url_retrieve
|
from .utils import IMAGES, url_retrieve
|
||||||
@@ -50,7 +51,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
open_clip = self.safe_import("open_clip", "open-clip")
|
open_clip = attempt_import_or_raise("open_clip", "open-clip")
|
||||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||||
self.name, pretrained=self.pretrained
|
self.name, pretrained=self.pretrained
|
||||||
)
|
)
|
||||||
@@ -78,14 +79,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
if isinstance(query, str):
|
if isinstance(query, str):
|
||||||
return [self.generate_text_embeddings(query)]
|
return [self.generate_text_embeddings(query)]
|
||||||
else:
|
else:
|
||||||
PIL = self.safe_import("PIL", "pillow")
|
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||||
if isinstance(query, PIL.Image.Image):
|
if isinstance(query, PIL.Image.Image):
|
||||||
return [self.generate_image_embedding(query)]
|
return [self.generate_image_embedding(query)]
|
||||||
else:
|
else:
|
||||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||||
|
|
||||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||||
torch = self.safe_import("torch")
|
torch = attempt_import_or_raise("torch")
|
||||||
text = self.sanitize_input(text)
|
text = self.sanitize_input(text)
|
||||||
text = self._tokenizer(text)
|
text = self._tokenizer(text)
|
||||||
text.to(self.device)
|
text.to(self.device)
|
||||||
@@ -144,7 +145,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
The image to embed. If the image is a str, it is treated as a uri.
|
The image to embed. If the image is a str, it is treated as a uri.
|
||||||
If the image is bytes, it is treated as the raw image bytes.
|
If the image is bytes, it is treated as the raw image bytes.
|
||||||
"""
|
"""
|
||||||
torch = self.safe_import("torch")
|
torch = attempt_import_or_raise("torch")
|
||||||
# TODO handle retry and errors for https
|
# TODO handle retry and errors for https
|
||||||
image = self._to_pil(image)
|
image = self._to_pil(image)
|
||||||
image = self._preprocess(image).unsqueeze(0)
|
image = self._preprocess(image).unsqueeze(0)
|
||||||
@@ -152,7 +153,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
|||||||
return self._encode_and_normalize_image(image)
|
return self._encode_and_normalize_image(image)
|
||||||
|
|
||||||
def _to_pil(self, image: Union[str, bytes]):
|
def _to_pil(self, image: Union[str, bytes]):
|
||||||
PIL = self.safe_import("PIL", "pillow")
|
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||||
if isinstance(image, bytes):
|
if isinstance(image, bytes):
|
||||||
return PIL.Image.open(io.BytesIO(image))
|
return PIL.Image.open(io.BytesIO(image))
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL.Image.Image):
|
||||||
@@ -12,10 +12,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
from .base import TextEmbeddingFunction
|
from .base import TextEmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
from .utils import api_key_not_found_help
|
from .utils import api_key_not_found_help
|
||||||
@@ -30,10 +31,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "text-embedding-ada-002"
|
name: str = "text-embedding-ada-002"
|
||||||
|
dim: Optional[int] = None
|
||||||
|
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
# TODO don't hardcode this
|
return self._ndims
|
||||||
return 1536
|
|
||||||
|
@cached_property
|
||||||
|
def _ndims(self):
|
||||||
|
if self.name == "text-embedding-ada-002":
|
||||||
|
return 1536
|
||||||
|
elif self.name == "text-embedding-3-large":
|
||||||
|
return self.dim or 3072
|
||||||
|
elif self.name == "text-embedding-3-small":
|
||||||
|
return self.dim or 1536
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model name {self.name}")
|
||||||
|
|
||||||
def generate_embeddings(
|
def generate_embeddings(
|
||||||
self, texts: Union[List[str], np.ndarray]
|
self, texts: Union[List[str], np.ndarray]
|
||||||
@@ -47,12 +59,17 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
|||||||
The texts to embed
|
The texts to embed
|
||||||
"""
|
"""
|
||||||
# TODO retry, rate limit, token limit
|
# TODO retry, rate limit, token limit
|
||||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
if self.name == "text-embedding-ada-002":
|
||||||
|
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||||
|
else:
|
||||||
|
rs = self._openai_client.embeddings.create(
|
||||||
|
input=texts, model=self.name, dimensions=self.ndims()
|
||||||
|
)
|
||||||
return [v.embedding for v in rs.data]
|
return [v.embedding for v in rs.data]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def _openai_client(self):
|
def _openai_client(self):
|
||||||
openai = self.safe_import("openai")
|
openai = attempt_import_or_raise("openai")
|
||||||
|
|
||||||
if not os.environ.get("OPENAI_API_KEY"):
|
if not os.environ.get("OPENAI_API_KEY"):
|
||||||
api_key_not_found_help("openai")
|
api_key_not_found_help("openai")
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user