Compare commits

...

73 Commits

Author SHA1 Message Date
Lance Release
1884fe8a3e Bump version: 0.9.0-beta.7 → 0.9.0-beta.8 2025-02-26 15:03:57 +00:00
Ryan Green
d8111b259c Merge remote-tracking branch 'origin/python-v0.9.4-patch' into python-v0.9.4-patch 2025-02-26 11:31:34 -03:30
Ryan Green
3c74bf5c7a Pin chrono version 2025-02-26 11:31:29 -03:30
Lance Release
b64bb75a82 Bump version: 0.9.0-beta.6 → 0.9.0-beta.7 2025-02-26 13:29:54 +00:00
Ryan Green
93e03ec702 revert worfklow 2025-02-26 09:56:08 -03:30
Ryan Green
7a94a7e171 Merge remote-tracking branch 'origin/python-v0.9.4-patch' into python-v0.9.4-patch 2025-02-26 09:52:55 -03:30
Ryan Green
acae6522fb workaround "edition2024" issue 2025-02-26 09:52:48 -03:30
Lance Release
005d5b64ac Bump version: 0.5.2 → 0.5.2-final.1 2025-02-26 13:05:01 +00:00
Lance Release
1e89d07fe2 Bump version: 0.9.0-beta.5 → 0.9.0-beta.6 2025-02-26 13:04:48 +00:00
Ryan Green
1da55719e7 fix windows workflow 2025-02-26 09:33:42 -03:30
Ryan Green
9d0ca5a823 merge PyPi Publish workflow from main 2025-02-26 09:31:18 -03:30
Lance Release
1e0cc69401 Bump version: 0.9.0-beta.4 → 0.9.0-beta.5 2025-02-26 12:46:00 +00:00
Ryan Green
f31e0c749d hotfix: add support for scalar index type in remote table 2025-02-26 09:13:30 -03:30
Lance Release
7a3ef68306 Bump version: 0.9.0-beta.3 → 0.9.0-beta.4 2024-12-20 16:02:53 +00:00
Ryan Green
43952e01d7 bump version 2024-12-20 09:44:46 -06:00
Ryan Green
495c335831 Fix fast_search 2024-12-20 09:43:39 -06:00
Ryan Green
77707db543 Backport fast_search and empty query builder for remote table 2024-12-20 09:21:05 -06:00
Ryan Green
d6d7ad3b06 bump version 2024-12-18 10:21:04 -06:00
Ryan Green
e58d64c286 Remove unsupported Retry params 2024-12-18 10:08:38 -06:00
Ryan Green
76cbd18c46 bump version 2024-12-18 09:38:36 -06:00
Ryan Green
4abb38ac70 bump version 2024-12-18 09:37:58 -06:00
Ryan Green
cc7bc5011d Merge remote-tracking branch 'origin/python-v0.9.0-patch' into python-v0.9.0-patch
# Conflicts:
#	python/pyproject.toml
2024-12-18 08:59:35 -06:00
Ryan Green
8193183304 override urllib3 version 2024-12-18 08:59:24 -06:00
Ryan Green
cf28b58b7d override urllib3 version 2024-12-18 08:58:41 -06:00
Lance Release
e3b7ee47b9 Bump version: 0.9.0 → 0.9.0-final.1 2024-12-13 01:16:24 +00:00
Lu Qiu
97c9c906e4 Fix version test 2024-12-12 17:10:07 -08:00
Lu Qiu
358f86b9c6 fix 2024-12-12 16:44:24 -08:00
Lu Qiu
5489e215a3 Support storage options and folder prefix 2024-12-12 16:17:34 -08:00
Lance Release
bc0814767b Bump version: 0.9.0-beta.0 → 0.9.0 2024-06-25 00:25:27 +00:00
Lance Release
8960a8e535 Bump version: 0.8.2 → 0.9.0-beta.0 2024-06-25 00:25:27 +00:00
Weston Pace
a8568ddc72 feat: upgrade to lance 0.13.0 (#1404) 2024-06-24 17:22:57 -07:00
Cory Grinstead
55f88346d0 feat(nodejs): table.indexStats (#1361)
closes https://github.com/lancedb/lancedb/issues/1359
2024-06-21 17:06:52 -05:00
Will Jones
dfb9a28795 ci(node): add description and keywords for lancedb package (#1398) 2024-06-21 14:43:35 -07:00
Cory Grinstead
a797f5fe59 feat(nodejs): feature parity [5/N] - add query.filter() alias (#1391)
to make the transition from `vectordb` to `@lancedb/lancedb` as seamless
as possible, this adds `query.filter` with a deprecated tag.


depends on https://github.com/lancedb/lancedb/pull/1390
see actual diff here
https://github.com/universalmind303/lancedb/compare/list-indices-name...universalmind303:query-filter
2024-06-21 16:03:58 -05:00
Cory Grinstead
3cd84c9375 feat(nodejs): feature parity [4/N] - add 'name' to 'IndexConfig' for 'listIndices' (#1390)
depends on https://github.com/lancedb/lancedb/pull/1386

see actual diff here
https://github.com/universalmind303/lancedb/compare/create-table-args...universalmind303:list-indices-name
2024-06-21 15:45:02 -05:00
Cory Grinstead
5ca83fdc99 fix(node): node build (#1396)
i have no idea why this fixes the build.
2024-06-21 15:42:22 -05:00
Cory Grinstead
33cc9b682f feat(nodejs): feature parity [3/N] - createTable({name, data, ...options}) (#1386)
adds support for the `vectordb` syntax of `createTable({name, data,
...options})`.


depends on https://github.com/lancedb/lancedb/pull/1380
see actual diff here
https://github.com/universalmind303/lancedb/compare/table-name...universalmind303:create-table-args
2024-06-21 12:17:39 -05:00
Cory Grinstead
b3e5ac6d2a feat(nodejs): feature parity [2/N] - add table.name and lancedb.connect({args}) (#1380)
depends on https://github.com/lancedb/lancedb/pull/1378

see proper diff here
https://github.com/universalmind303/lancedb/compare/remote-table-node...universalmind303:lancedb:table-name
2024-06-21 11:38:26 -05:00
josca42
0fe844034d feat: enable stemming (#1356)
Added the ability to specify tokenizer_name, when creating a full text
search index using tantivy. This enables the use of language specific
stemming.

Also updated the [guide on full text
search](https://lancedb.github.io/lancedb/fts/) with a short section on
choosing tokenizer.

Fixes #1315
2024-06-20 14:23:55 -07:00
Cory Grinstead
f41eb899dc chore(rust): lock toolchain & fix clippy (#1389)
- fix some clippy errors from ci running a different toolchain. 
- add some saftey notes about some unsafe blocks. 

- locks the toolchain so that it is consistent across dev and CI.
2024-06-20 12:13:03 -05:00
Cory Grinstead
e7022b990e feat(nodejs): feature parity [1/N] - remote table (#1378)
closes https://github.com/lancedb/lancedb/issues/1362
2024-06-17 15:23:27 -05:00
Weston Pace
ea86dad4b7 feat: upgrade lance to 0.12.2-beta.2 (#1381) 2024-06-14 05:43:26 -07:00
harsha-mangena
a45656b8b6 docs: remove code-block:: python from docs (#1366)
- refer #1264
- fixed minor documentation issue
2024-06-11 13:13:02 -07:00
Cory Grinstead
bc19a75f65 feat(nodejs): merge insert (#1351)
closes https://github.com/lancedb/lancedb/issues/1349
2024-06-11 15:05:15 -05:00
Ryan Green
8e348ab4bd fix: use JS naming convention in new index stats fields (#1377)
Changes new index stats fields in node client from snake case to camel
case.
2024-06-10 16:41:31 -02:30
Raghav Dixit
96914a619b docs: llama-index integration (#1347)
Updated api refrence and usage for llama index integration.
2024-06-09 23:52:18 +05:30
Beinan
3c62806b6a fix(java): the JVM crash when using jdk 8 (#1372)
The Optional::isEmpty does not exist in java 8, so we should use
isPresent instead
2024-06-08 22:43:41 -07:00
Ayush Chaurasia
72f339a0b3 docs: add note about embedding api not being available on cloud (#1371) 2024-06-09 03:57:23 +05:30
QianZhu
b9e3cfbdca fix: add status to remote listIndices return (#1364)
expose `status` returned by remote listIndices
2024-06-08 09:52:35 -07:00
Ayush Chaurasia
5e30648f45 docs: fix example path (#1367) 2024-06-07 19:40:50 -07:00
Ayush Chaurasia
76fc16c7a1 docs: add retriever guide, address minor onboarding feedbacks & enhancement (#1326)
- Tried to address some onboarding feedbacks listed in
https://github.com/lancedb/lancedb/issues/1224
- Improve visibility of pydantic integration and embedding API. (Based
on onboarding feedback - Many ways of ingesting data, defining schema
but not sure what to use in a specific use-case)
- Add a guide that takes users through testing and improving retriever
performance using built-in utilities like hybrid-search and reranking
- Add some benchmarks for the above
- Add missing cohere docs

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
2024-06-08 06:25:31 +05:30
Weston Pace
007f9c1af8 chore: change build machine for linux arm (#1360) 2024-06-06 13:22:58 -07:00
Lance Release
27e4ad3f11 Updating package-lock.json 2024-06-05 13:47:44 +00:00
Lance Release
df42943ccf Bump version: 0.5.2-beta.0 → 0.5.2 2024-06-05 13:47:28 +00:00
Lance Release
3eec9ea740 Bump version: 0.5.1 → 0.5.2-beta.0 2024-06-05 13:47:27 +00:00
Lance Release
11fcdb1194 Bump version: 0.8.2-beta.0 → 0.8.2 2024-06-05 13:47:16 +00:00
Lance Release
95a5a0d713 Bump version: 0.8.1 → 0.8.2-beta.0 2024-06-05 13:47:16 +00:00
Weston Pace
c3043a54c6 feat: bump lance dependency to 0.12.1 (#1357) 2024-06-05 06:07:11 -07:00
Weston Pace
d5586c9c32 feat: make it possible to opt in to using the v2 format (#1352)
This also exposed the max_batch_length configuration option in
python/node (it was needed to verify if we are actually in v2 mode or
not)
2024-06-04 21:52:14 -07:00
Rob Meng
d39e7d23f4 feat: fast path for checkout_latest (#1355)
similar to https://github.com/lancedb/lancedb/pull/1354
do locked IO less frequently
2024-06-04 23:01:28 -04:00
Rob Meng
ddceda4ff7 feat: add fast path to dataset reload (#1354)
most of the time we don't need to reload. Locking the write lock and
performing IO is not an ideal pattern.

This PR tries to make the critical section of `.write()` happen less
frequently.

This isn't the most ideal solution. The most ideal solution should not
lock until the new dataset has been loaded. But that would require too
much refactoring.
2024-06-04 19:03:53 -04:00
Cory Grinstead
70f92f19a6 feat(nodejs): table.search functionality (#1341)
closes https://github.com/lancedb/lancedb/issues/1256
2024-06-04 14:04:03 -05:00
Cory Grinstead
d9fb6457e1 fix(nodejs): better support for f16 and f64 (#1343)
closes https://github.com/lancedb/lancedb/issues/1292
closes https://github.com/lancedb/lancedb/issues/1293
2024-06-04 13:41:21 -05:00
Lei Xu
56b4fd2bd9 feat(rust): allow to create execution plan on queries (#1350) 2024-05-31 17:33:58 -07:00
paul n walsh
7c133ec416 feat(nodejs): table.toArrow function (#1282)
Addresses https://github.com/lancedb/lancedb/issues/1254.

---------

Co-authored-by: universalmind303 <cory.grinstead@gmail.com>
2024-05-31 13:24:21 -05:00
QianZhu
1dbb4cd1e2 fix: error msg when query vector dim is wrong (#1339)
- changed the error msg for table.search with wrong query vector dim 
- added missing fields for listIndices and indexStats to be consistent
with Python API - will make changes in node integ test
2024-05-31 10:18:06 -07:00
Paul Rinaldi
af65417d19 fix: update broken blog link on readme (#1310) 2024-05-31 10:04:56 -07:00
Cory Grinstead
01dd6c5e75 feat(rust): openai embedding function (#1275)
part of https://github.com/lancedb/lancedb/issues/994. 

Adds the ability to use the openai embedding functions.


the example can be run by the following

```sh
> EXPORT OPENAI_API_KEY="sk-..."
> cargo run --example openai --features=openai
```

which should output
```
Closest match: Winter Parka
```
2024-05-30 15:55:55 -05:00
Weston Pace
1e85b57c82 ci: don't update package locks if we are not releasing node (#1323)
This doesn't actually block a python-only release since this step runs
after the version bump has been pushed but it still would be nice for
the git job to finish successfully.
2024-05-30 04:42:06 -07:00
Ayush Chaurasia
16eff254ea feat: add support for new cohere models in cohere and bedrock embedding functions (#1335)
Fixes #1329

Will update docs on https://github.com/lancedb/lancedb/pull/1326
2024-05-30 10:20:03 +05:30
Lance Release
1b2463c5dd Updating package-lock.json 2024-05-30 01:00:43 +00:00
Lance Release
92f74f955f Bump version: 0.5.1-beta.0 → 0.5.1 2024-05-30 01:00:28 +00:00
Lance Release
53b5ea3f92 Bump version: 0.5.0 → 0.5.1-beta.0 2024-05-30 01:00:28 +00:00
95 changed files with 3950 additions and 674 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.5.0" current_version = "0.5.2-final.1"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -46,6 +46,7 @@ runs:
with: with:
command: build command: build
working-directory: python working-directory: python
docker-options: "-e PIP_EXTRA_INDEX_URL=https://pypi.fury.io/lancedb/"
target: aarch64-unknown-linux-gnu target: aarch64-unknown-linux-gnu
manylinux: ${{ inputs.manylinux }} manylinux: ${{ inputs.manylinux }}
args: ${{ inputs.args }} args: ${{ inputs.args }}

View File

@@ -21,5 +21,6 @@ runs:
with: with:
command: build command: build
args: ${{ inputs.args }} args: ${{ inputs.args }}
docker-options: "-e PIP_EXTRA_INDEX_URL=https://pypi.fury.io/lancedb/"
working-directory: python working-directory: python
interpreter: 3.${{ inputs.python-minor-version }} interpreter: 3.${{ inputs.python-minor-version }}

View File

@@ -26,8 +26,9 @@ runs:
with: with:
command: build command: build
args: ${{ inputs.args }} args: ${{ inputs.args }}
docker-options: "-e PIP_EXTRA_INDEX_URL=https://pypi.fury.io/lancedb/"
working-directory: python working-directory: python
- uses: actions/upload-artifact@v3 - uses: actions/upload-artifact@v4
with: with:
name: windows-wheels name: windows-wheels
path: python\target\wheels path: python\target\wheels

View File

@@ -94,6 +94,6 @@ jobs:
branch: ${{ github.ref }} branch: ${{ github.ref }}
tags: true tags: true
- uses: ./.github/workflows/update_package_lock - uses: ./.github/workflows/update_package_lock
if: ${{ inputs.dry_run }} == "false" if: ${{ !inputs.dry_run && inputs.other }}
with: with:
github_token: ${{ secrets.GITHUB_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -3,7 +3,7 @@ name: NPM Publish
on: on:
push: push:
tags: tags:
- 'v*' - "v*"
jobs: jobs:
node: node:
@@ -111,12 +111,11 @@ jobs:
runner: ubuntu-latest runner: ubuntu-latest
- arch: aarch64 - arch: aarch64
# For successful fat LTO builds, we need a large runner to avoid OOM errors. # For successful fat LTO builds, we need a large runner to avoid OOM errors.
runner: buildjet-16vcpu-ubuntu-2204-arm runner: warp-ubuntu-latest-arm64-4x
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
# Buildjet aarch64 runners have only 1.5 GB RAM per core, vs 3.5 GB per core for # To avoid OOM errors on ARM, we create a swap file.
# x86_64 runners. To avoid OOM errors on ARM, we create a swap file.
- name: Configure aarch64 build - name: Configure aarch64 build
if: ${{ matrix.config.arch == 'aarch64' }} if: ${{ matrix.config.arch == 'aarch64' }}
run: | run: |
@@ -323,7 +322,7 @@ jobs:
- name: Publish to NPM - name: Publish to NPM
env: env:
NODE_AUTH_TOKEN: ${{ secrets.LANCEDB_NPM_REGISTRY_TOKEN }} NODE_AUTH_TOKEN: ${{ secrets.LANCEDB_NPM_REGISTRY_TOKEN }}
# By default, things are published to the latest tag. This is what is # By default, things are published to the latest tag. This is what is
# installed by default if the user does not specify a version. This is # installed by default if the user does not specify a version. This is
# good for stable releases, but for pre-releases, we want to publish to # good for stable releases, but for pre-releases, we want to publish to
# the "preview" tag so they can install with `npm install lancedb@preview`. # the "preview" tag so they can install with `npm install lancedb@preview`.
@@ -368,7 +367,7 @@ jobs:
- uses: ./.github/workflows/update_package_lock_nodejs - uses: ./.github/workflows/update_package_lock_nodejs
with: with:
github_token: ${{ secrets.GITHUB_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }}
gh-release: gh-release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:

View File

@@ -65,7 +65,7 @@ jobs:
workspaces: python workspaces: python
- name: Install - name: Install
run: | run: |
pip install -e .[tests,dev,embeddings] pip install --extra-index-url https://pypi.fury.io/lancedb/ -e .[tests,dev,embeddings]
pip install tantivy pip install tantivy
pip install mlx pip install mlx
- name: Doctest - name: Doctest
@@ -189,7 +189,7 @@ jobs:
- name: Install lancedb - name: Install lancedb
run: | run: |
pip install "pydantic<2" pip install "pydantic<2"
pip install -e .[tests] pip install --extra-index-url https://pypi.fury.io/lancedb/ -e .[tests]
pip install tantivy pip install tantivy
- name: Run tests - name: Run tests
run: pytest -m "not slow and not s3_test" -x -v --durations=30 python/tests run: pytest -m "not slow and not s3_test" -x -v --durations=30 python/tests

View File

@@ -15,7 +15,7 @@ runs:
- name: Install lancedb - name: Install lancedb
shell: bash shell: bash
run: | run: |
pip3 install $(ls target/wheels/lancedb-*.whl)[tests,dev] pip3 install --extra-index-url https://pypi.fury.io/lancedb/ $(ls target/wheels/lancedb-*.whl)[tests,dev]
- name: Setup localstack for integration tests - name: Setup localstack for integration tests
if: ${{ inputs.integration == 'true' }} if: ${{ inputs.integration == 'true' }}
shell: bash shell: bash

View File

@@ -14,7 +14,7 @@ repos:
hooks: hooks:
- id: local-biome-check - id: local-biome-check
name: biome check name: biome check
entry: npx @biomejs/biome check --config-path nodejs/biome.json nodejs/ entry: npx @biomejs/biome@1.7.3 check --config-path nodejs/biome.json nodejs/
language: system language: system
types: [text] types: [text]
files: "nodejs/.*" files: "nodejs/.*"

View File

@@ -1,5 +1,11 @@
[workspace] [workspace]
members = ["rust/ffi/node", "rust/lancedb", "nodejs", "python", "java/core/lancedb-jni"] members = [
"rust/ffi/node",
"rust/lancedb",
"nodejs",
"python",
"java/core/lancedb-jni",
]
# Python package needs to be built by maturin. # Python package needs to be built by maturin.
exclude = ["python"] exclude = ["python"]
resolver = "2" resolver = "2"
@@ -14,10 +20,11 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"] categories = ["database-implementations"]
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=0.11.1", "features" = ["dynamodb"] } lance = { "version" = "=0.13.0", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.11.1" } lance-index = { "version" = "=0.13.0" }
lance-linalg = { "version" = "=0.11.1" } lance-linalg = { "version" = "=0.13.0" }
lance-testing = { "version" = "=0.11.1" } lance-testing = { "version" = "=0.13.0" }
lance-datafusion = { "version" = "=0.13.0" }
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "51.0", optional = false } arrow = { version = "51.0", optional = false }
arrow-array = "51.0" arrow-array = "51.0"
@@ -28,7 +35,8 @@ arrow-schema = "51.0"
arrow-arith = "51.0" arrow-arith = "51.0"
arrow-cast = "51.0" arrow-cast = "51.0"
async-trait = "0" async-trait = "0"
chrono = "0.4.35" chrono = "=0.4.39"
datafusion-physical-plan = "37.1"
half = { "version" = "=2.4.1", default-features = false, features = [ half = { "version" = "=2.4.1", default-features = false, features = [
"num-traits", "num-traits",
] } ] }

View File

@@ -83,5 +83,5 @@ result = table.search([100, 100]).limit(2).to_pandas()
``` ```
## Blogs, Tutorials & Videos ## Blogs, Tutorials & Videos
* 📈 <a href="https://blog.eto.ai/benchmarking-random-access-in-lance-ed690757a826">2000x better performance with Lance over Parquet</a> * 📈 <a href="https://blog.lancedb.com/benchmarking-random-access-in-lance/">2000x better performance with Lance over Parquet</a>
* 🤖 <a href="https://github.com/lancedb/lancedb/blob/main/docs/src/notebooks/youtube_transcript_search.ipynb">Build a question and answer bot with LanceDB</a> * 🤖 <a href="https://github.com/lancedb/lancedb/blob/main/docs/src/notebooks/youtube_transcript_search.ipynb">Build a question and answer bot with LanceDB</a>

View File

@@ -106,6 +106,9 @@ nav:
- Versioning & Reproducibility: notebooks/reproducibility.ipynb - Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Configuring Storage: guides/storage.md - Configuring Storage: guides/storage.md
- Sync -> Async Migration Guide: migration.md - Sync -> Async Migration Guide: migration.md
- Tuning retrieval performance:
- Choosing right query type: guides/tuning_retrievers/1_query_types.md
- Reranking: guides/tuning_retrievers/2_reranking.md
- 🧬 Managing embeddings: - 🧬 Managing embeddings:
- Overview: embeddings/index.md - Overview: embeddings/index.md
- Embedding functions: embeddings/embedding_functions.md - Embedding functions: embeddings/embedding_functions.md
@@ -121,7 +124,9 @@ nav:
- LangChain: - LangChain:
- LangChain 🔗: integrations/langchain.md - LangChain 🔗: integrations/langchain.md
- LangChain JS/TS 🔗: https://js.langchain.com/docs/integrations/vectorstores/lancedb - LangChain JS/TS 🔗: https://js.langchain.com/docs/integrations/vectorstores/lancedb
- LlamaIndex 🦙: https://docs.llamaindex.ai/en/stable/examples/vector_stores/LanceDBIndexDemo/ - LlamaIndex 🦙:
- LlamaIndex docs: integrations/llamaIndex.md
- LlamaIndex demo: https://docs.llamaindex.ai/en/stable/examples/vector_stores/LanceDBIndexDemo/
- Pydantic: python/pydantic.md - Pydantic: python/pydantic.md
- Voxel51: integrations/voxel51.md - Voxel51: integrations/voxel51.md
- PromptTools: integrations/prompttools.md - PromptTools: integrations/prompttools.md
@@ -152,7 +157,7 @@ nav:
- Overview: cloud/index.md - Overview: cloud/index.md
- API reference: - API reference:
- 🐍 Python: python/saas-python.md - 🐍 Python: python/saas-python.md
- 👾 JavaScript: javascript/saas-modules.md - 👾 JavaScript: javascript/modules.md
- Quick start: basic.md - Quick start: basic.md
- Concepts: - Concepts:
@@ -181,6 +186,9 @@ nav:
- Versioning & Reproducibility: notebooks/reproducibility.ipynb - Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Configuring Storage: guides/storage.md - Configuring Storage: guides/storage.md
- Sync -> Async Migration Guide: migration.md - Sync -> Async Migration Guide: migration.md
- Tuning retrieval performance:
- Choosing right query type: guides/tuning_retrievers/1_query_types.md
- Reranking: guides/tuning_retrievers/2_reranking.md
- Managing Embeddings: - Managing Embeddings:
- Overview: embeddings/index.md - Overview: embeddings/index.md
- Embedding functions: embeddings/embedding_functions.md - Embedding functions: embeddings/embedding_functions.md
@@ -219,7 +227,7 @@ nav:
- Overview: cloud/index.md - Overview: cloud/index.md
- API reference: - API reference:
- 🐍 Python: python/saas-python.md - 🐍 Python: python/saas-python.md
- 👾 JavaScript: javascript/saas-modules.md - 👾 JavaScript: javascript/modules.md
extra_css: extra_css:
- styles/global.css - styles/global.css

View File

@@ -180,6 +180,9 @@ table.
!!! info "Under the hood, LanceDB reads in the Apache Arrow data and persists it to disk using the [Lance format](https://www.github.com/lancedb/lance)." !!! info "Under the hood, LanceDB reads in the Apache Arrow data and persists it to disk using the [Lance format](https://www.github.com/lancedb/lance)."
!!! info "Automatic embedding generation with Embedding API"
When working with embedding models, it is recommended to use the LanceDB embedding API to automatically create vector representation of the data and queries in the background. See the [quickstart example](#using-the-embedding-api) or the embedding API [guide](./embeddings/)
### Create an empty table ### Create an empty table
Sometimes you may not have the data to insert into the table at creation time. Sometimes you may not have the data to insert into the table at creation time.
@@ -194,6 +197,9 @@ similar to a `CREATE TABLE` statement in SQL.
--8<-- "python/python/tests/docs/test_basic.py:create_empty_table_async" --8<-- "python/python/tests/docs/test_basic.py:create_empty_table_async"
``` ```
!!! note "You can define schema in Pydantic"
LanceDB comes with Pydantic support, which allows you to define the schema of your data using Pydantic models. This makes it easy to work with LanceDB tables and data. Learn more about all supported types in [tables guide](./guides/tables.md).
=== "Typescript" === "Typescript"
```typescript ```typescript
@@ -424,6 +430,19 @@ Use the `drop_table()` method on the database to remove a table.
}) })
``` ```
## Using the Embedding API
You can use the embedding API when working with embedding models. It automatically vectorizes the data at ingestion and query time and comes with built-in integrations with popular embedding models like Openai, Hugging Face, Sentence Transformers, CLIP and more.
=== "Python"
```python
--8<-- "python/python/tests/docs/test_embeddings_optional.py:imports"
--8<-- "python/python/tests/docs/test_embeddings_optional.py:openai_embeddings"
```
Learn about using the existing integrations and creating custom embedding functions in the [embedding API guide](./embeddings/).
## What's next ## What's next
This section covered the very basics of using LanceDB. If you're learning about vector databases for the first time, you may want to read the page on [indexing](concepts/index_ivfpq.md) to get familiar with the concepts. This section covered the very basics of using LanceDB. If you're learning about vector databases for the first time, you may want to read the page on [indexing](concepts/index_ivfpq.md) to get familiar with the concepts.

View File

@@ -216,7 +216,7 @@ Generate embeddings via the [ollama](https://github.com/ollama/ollama-python) py
|------------------------|----------------------------|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------| |------------------------|----------------------------|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
| `name` | `str` | `nomic-embed-text` | The name of the model. | | `name` | `str` | `nomic-embed-text` | The name of the model. |
| `host` | `str` | `http://localhost:11434` | The Ollama host to connect to. | | `host` | `str` | `http://localhost:11434` | The Ollama host to connect to. |
| `options` | `ollama.Options` or `dict` | `None` | Additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`. | | `options` | `ollama.Options` or `dict` | `None` | Additional model parameters listed in the documentation for the Modelfile such as `temperature`. |
| `keep_alive` | `float` or `str` | `"5m"` | Controls how long the model will stay loaded into memory following the request. | | `keep_alive` | `float` or `str` | `"5m"` | Controls how long the model will stay loaded into memory following the request. |
| `ollama_client_kwargs` | `dict` | `{}` | kwargs that can be past to the `ollama.Client`. | | `ollama_client_kwargs` | `dict` | `{}` | kwargs that can be past to the `ollama.Client`. |
@@ -365,6 +365,68 @@ tbl.add(df)
rs = tbl.search("hello").limit(1).to_pandas() rs = tbl.search("hello").limit(1).to_pandas()
``` ```
### Cohere Embeddings
Using cohere API requires cohere package, which can be installed using `pip install cohere`. Cohere embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification.
You also need to set the `COHERE_API_KEY` environment variable to use the Cohere API.
Supported models are:
```
* embed-english-v3.0
* embed-multilingual-v3.0
* embed-english-light-v3.0
* embed-multilingual-light-v3.0
* embed-english-v2.0
* embed-english-light-v2.0
* embed-multilingual-v2.0
```
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"embed-english-v2.0"` | The model ID of the cohere model to use. Supported base models for Text Embeddings: embed-english-v3.0, embed-multilingual-v3.0, embed-english-light-v3.0, embed-multilingual-light-v3.0, embed-english-v2.0, embed-english-light-v2.0, embed-multilingual-v2.0 |
| `source_input_type` | `str` | `"search_document"` | The type of input data to be used for the source column. |
| `query_input_type` | `str` | `"search_query"` | The type of input data to be used for the query. |
Cohere supports following input types:
| Input Type | Description |
|-------------------------|---------------------------------------|
| "`search_document`" | Used for embeddings stored in a vector|
| | database for search use-cases. |
| "`search_query`" | Used for embeddings of search queries |
| | run against a vector DB |
| "`semantic_similarity`" | Specifies the given text will be used |
| | for Semantic Textual Similarity (STS) |
| "`classification`" | Used for embeddings passed through a |
| | text classifier. |
| "`clustering`" | Used for the embeddings run through a |
| | clustering algorithm |
Usage Example:
```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
cohere = EmbeddingFunctionRegistry
.get_instance()
.get("cohere")
.create(name="embed-multilingual-v2.0")
class TextModel(LanceModel):
text: str = cohere.SourceField()
vector: Vector(cohere.ndims()) = cohere.VectorField()
data = [ { "text": "hello world" },
{ "text": "goodbye world" }]
db = lancedb.connect("~/.lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
```
### AWS Bedrock Text Embedding Functions ### AWS Bedrock Text Embedding Functions
AWS Bedrock supports multiple base models for generating text embeddings. You need to setup the AWS credentials to use this embedding function. AWS Bedrock supports multiple base models for generating text embeddings. You need to setup the AWS credentials to use this embedding function.
You can do so by using `awscli` and also add your session_token: You can do so by using `awscli` and also add your session_token:

View File

@@ -2,6 +2,9 @@ Representing multi-modal data as vector embeddings is becoming a standard practi
For this purpose, LanceDB introduces an **embedding functions API**, that allow you simply set up once, during the configuration stage of your project. After this, the table remembers it, effectively making the embedding functions *disappear in the background* so you don't have to worry about manually passing callables, and instead, simply focus on the rest of your data engineering pipeline. For this purpose, LanceDB introduces an **embedding functions API**, that allow you simply set up once, during the configuration stage of your project. After this, the table remembers it, effectively making the embedding functions *disappear in the background* so you don't have to worry about manually passing callables, and instead, simply focus on the rest of your data engineering pipeline.
!!! Note "LanceDB cloud doesn't support embedding functions yet"
LanceDB Cloud does not support embedding functions yet. You need to generate embeddings before ingesting into the table or querying.
!!! warning !!! warning
Using the embedding function registry means that you don't have to explicitly generate the embeddings yourself. Using the embedding function registry means that you don't have to explicitly generate the embeddings yourself.
However, if your embedding function changes, you'll have to re-configure your table with the new embedding function However, if your embedding function changes, you'll have to re-configure your table with the new embedding function

View File

@@ -2,7 +2,6 @@
LanceDB provides support for full-text search via [Tantivy](https://github.com/quickwit-oss/tantivy) (currently Python only), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions. Our goal is to push the FTS integration down to the Rust level in the future, so that it's available for Rust and JavaScript users as well. Follow along at [this Github issue](https://github.com/lancedb/lance/issues/1195) LanceDB provides support for full-text search via [Tantivy](https://github.com/quickwit-oss/tantivy) (currently Python only), allowing you to incorporate keyword-based search (based on BM25) in your retrieval solutions. Our goal is to push the FTS integration down to the Rust level in the future, so that it's available for Rust and JavaScript users as well. Follow along at [this Github issue](https://github.com/lancedb/lance/issues/1195)
A hybrid search solution combining vector and full-text search is also on the way.
## Installation ## Installation
@@ -55,6 +54,16 @@ This returns the result as a list of dictionaries as follows.
!!! note !!! note
LanceDB automatically searches on the existing FTS index if the input to the search is of type `str`. If you provide a vector as input, LanceDB will search the ANN index instead. LanceDB automatically searches on the existing FTS index if the input to the search is of type `str`. If you provide a vector as input, LanceDB will search the ANN index instead.
## Tokenization
By default the text is tokenized by splitting on punctuation and whitespaces and then removing tokens that are longer than 40 chars. For more language specific tokenization then provide the argument tokenizer_name with the 2 letter language code followed by "_stem". So for english it would be "en_stem".
```python
table.create_fts_index("text", tokenizer_name="en_stem")
```
The following [languages](https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html) are currently supported.
## Index multiple columns ## Index multiple columns
If you have multiple string columns to index, there's no need to combine them manually -- simply pass them all as a list to `create_fts_index`: If you have multiple string columns to index, there's no need to combine them manually -- simply pass them all as a list to `create_fts_index`:
@@ -140,6 +149,7 @@ is treated as a phrase query.
In general, a query that's declared as a phrase query will be wrapped in double quotes during parsing, with nested In general, a query that's declared as a phrase query will be wrapped in double quotes during parsing, with nested
double quotes replaced by single quotes. double quotes replaced by single quotes.
## Configurations ## Configurations
By default, LanceDB configures a 1GB heap size limit for creating the index. You can By default, LanceDB configures a 1GB heap size limit for creating the index. You can

View File

@@ -452,6 +452,27 @@ After a table has been created, you can always add more data to it using the var
tbl.add(pydantic_model_items) tbl.add(pydantic_model_items)
``` ```
??? "Ingesting Pydantic models with LanceDB embedding API"
When using LanceDB's embedding API, you can add Pydantic models directly to the table. LanceDB will automatically convert the `vector` field to a vector before adding it to the table. You need to specify the default value of `vector` feild as None to allow LanceDB to automatically vectorize the data.
```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
db = lancedb.connect("~/tmp")
embed_fcn = get_registry().get("huggingface").create(name="BAAI/bge-small-en-v1.5")
class Schema(LanceModel):
text: str = embed_fcn.SourceField()
vector: Vector(embed_fcn.ndims()) = embed_fcn.VectorField(default=None)
tbl = db.create_table("my_table", schema=Schema, mode="overwrite")
models = [Schema(text="hello"), Schema(text="world")]
tbl.add(models)
```
=== "JavaScript" === "JavaScript"
@@ -636,6 +657,31 @@ The `values` parameter is used to provide the new values for the columns as lite
When rows are updated, they are moved out of the index. The row will still show up in ANN queries, but the query will not be as fast as it would be if the row was in the index. If you update a large proportion of rows, consider rebuilding the index afterwards. When rows are updated, they are moved out of the index. The row will still show up in ANN queries, but the query will not be as fast as it would be if the row was in the index. If you update a large proportion of rows, consider rebuilding the index afterwards.
## Drop a table
Use the `drop_table()` method on the database to remove a table.
=== "Python"
```python
--8<-- "python/python/tests/docs/test_basic.py:drop_table"
--8<-- "python/python/tests/docs/test_basic.py:drop_table_async"
```
This permanently removes the table and is not recoverable, unlike deleting rows.
By default, if the table does not exist an exception is raised. To suppress this,
you can pass in `ignore_missing=True`.
=== "Javascript/Typescript"
```typescript
--8<-- "docs/src/basic_legacy.ts:drop_table"
```
This permanently removes the table and is not recoverable, unlike deleting rows.
If the table does not exist an exception is raised.
## Consistency ## Consistency
In LanceDB OSS, users can set the `read_consistency_interval` parameter on connections to achieve different levels of read consistency. This parameter determines how frequently the database synchronizes with the underlying storage system to check for updates made by other processes. If another process updates a table, the database will not see the changes until the next synchronization. In LanceDB OSS, users can set the `read_consistency_interval` parameter on connections to achieve different levels of read consistency. This parameter determines how frequently the database synchronizes with the underlying storage system to check for updates made by other processes. If another process updates a table, the database will not see the changes until the next synchronization.

View File

@@ -0,0 +1,128 @@
## Improving retriever performance
VectorDBs are used as retreivers in recommender or chatbot-based systems for retrieving relevant data based on user queries. For example, retriever is a critical component of Retrieval Augmented Generation (RAG) acrhitectures. In this section, we will discuss how to improve the performance of retrievers.
There are serveral ways to improve the performance of retrievers. Some of the common techniques are:
* Using different query types
* Using hybrid search
* Fine-tuning the embedding models
* Using different embedding models
Using different embedding models is something that's very specific to the use case and the data. So we will not discuss it here. In this section, we will discuss the first three techniques.
!!! note "Note"
We'll be using a simple metric called "hit-rate" for evaluating the performance of the retriever across this guide. Hit-rate is the percentage of queries for which the retriever returned the correct answer in the top-k results. For example, if the retriever returned the correct answer in the top-3 results for 70% of the queries, then the hit-rate@3 is 0.7.
## The dataset
We'll be using a QA dataset generated using a LLama2 review paper. The dataset contains 221 query, context and answer triplets. The queries and answers are generated using GPT-4 based on a given query. Full script used to generate the dataset can be found on this [repo](https://github.com/lancedb/ragged). It can be downloaded from [here](https://github.com/AyushExel/assets/blob/main/data_qa.csv)
### Using different query types
Let's setup the embeddings and the dataset first. We'll use the LanceDB's `huggingface` embeddings integration for this guide.
```python
import lancedb
import pandas as pd
from lancedb.embeddings import get_registry
from lancedb.pydantic import Vector, LanceModel
db = lancedb.connect("~/lancedb/query_types")
df = pd.read_csv("data_qa.csv")
embed_fcn = get_registry().get("huggingface").create(name="BAAI/bge-small-en-v1.")
class Schema(LanceModel):
context: str = embed_fcn.SourceField()
vector: Vector(embed_fcn.ndims()) = embed_fcn.VectorField()
table = db.create_table("qa", schema=Schema)
table.add(df[["context"]].to_dict(orient="records"))
queries = df["query"].tolist()
```
Now that we have the dataset and embeddings table set up, here's how you can run different query types on the dataset.
* <b> Vector Search: </b>
```python
table.search(quries[0], query_type="vector").limit(5).to_pandas()
```
By default, LanceDB uses vector search query type for searching and it automatically converts the input query to a vector before searching when using embedding API. So, the following statement is equivalent to the above statement.
```python
table.search(quries[0]).limit(5).to_pandas()
```
Vector or semantic search is useful when you want to find documents that are similar to the query in terms of meaning.
---
* <b> Full-text Search: </b>
FTS requires creating an index on the column you want to search on. `replace=True` will replace the existing index if it exists.
Once the index is created, you can search using the `fts` query type.
```python
table.create_fts_index("context", replace=True)
table.search(quries[0], query_type="fts").limit(5).to_pandas()
```
Full-text search is useful when you want to find documents that contain the query terms.
---
* <b> Hybrid Search: </b>
Hybrid search is a combination of vector and full-text search. Here's how you can run a hybrid search query on the dataset.
```python
table.search(quries[0], query_type="hybrid").limit(5).to_pandas()
```
Hybrid search requires a reranker to combine and rank the results from vector and full-text search. We'll cover reranking as a concept in the next section.
Hybrid search is useful when you want to combine the benefits of both vector and full-text search.
!!! note "Note"
By default, it uses `LinearCombinationReranker` that combines the scores from vector and full-text search using a weighted linear combination. It is the simplest reranker implementation available in LanceDB. You can also use other rerankers like `CrossEncoderReranker` or `CohereReranker` for reranking the results.
Learn more about rerankers [here](https://lancedb.github.io/lancedb/reranking/)
### Hit rate evaluation results
Now that we have seen how to run different query types on the dataset, let's evaluate the hit-rate of each query type on the dataset.
For brevity, the entire evaluation script is not shown here. You can find the complete evaluation and benchmarking utility scripts [here](https://github.com/lancedb/ragged).
Here are the hit-rate results for the dataset:
| Query Type | Hit-rate@5 |
| --- | --- |
| Vector Search | 0.640 |
| Full-text Search | 0.595 |
| Hybrid Search (w/ LinearCombinationReranker) | 0.645 |
**Choosing query type** is very specific to the use case and the data. This synthetic dataset has been generated to be semantically challenging, i.e, the queries don't have a lot of keywords in common with the context. So, vector search performs better than full-text search. However, in real-world scenarios, full-text search might perform better than vector search. Hybrid search is a good choice when you want to combine the benefits of both vector and full-text search.
### Evaluation results on other datasets
The hit-rate results can vary based on the dataset and the query type. Here are the hit-rate results for the other datasets using the same embedding function.
* <b> SQuAD Dataset: </b>
| Query Type | Hit-rate@5 |
| --- | --- |
| Vector Search | 0.822 |
| Full-text Search | 0.835 |
| Hybrid Search (w/ LinearCombinationReranker) | 0.8874 |
* <b> Uber10K sec filing Dataset: </b>
| Query Type | Hit-rate@5 |
| --- | --- |
| Vector Search | 0.608 |
| Full-text Search | 0.82 |
| Hybrid Search (w/ LinearCombinationReranker) | 0.80 |
In these standard datasets, FTS seems to perform much better than vector search because the queries have a lot of keywords in common with the context. So, in general choosing the query type is very specific to the use case and the data.

View File

@@ -0,0 +1,78 @@
Continuing from the previous example, we can now rerank the results using more complex rerankers.
## Reranking search results
You can rerank any search results using a reranker. The syntax for reranking is as follows:
```python
from lancedb.rerankers import LinearCombinationReranker
reranker = LinearCombinationReranker()
table.search(quries[0], query_type="hybrid").rerank(reranker=reranker).limit(5).to_pandas()
```
Based on the `query_type`, the `rerank()` function can accept other arguments as well. For example, hybrid search accepts a `normalize` param to determine the score normalization method.
!!! note "Note"
LanceDB provides a `Reranker` base class that can be extended to implement custom rerankers. Each reranker must implement the `rerank_hybrid` method. `rerank_vector` and `rerank_fts` methods are optional. For example, the `LinearCombinationReranker` only implements the `rerank_hybrid` method and so it can only be used for reranking hybrid search results.
## Choosing a Reranker
There are many rerankers available in LanceDB like `CrossEncoderReranker`, `CohereReranker`, and `ColBERT`. The choice of reranker depends on the dataset and the application. You can even implement you own custom reranker by extending the `Reranker` class. For more details about each available reranker and performance comparison, refer to the [rerankers](https://lancedb.github.io/lancedb/reranking/) documentation.
In this example, we'll use the `CohereReranker` to rerank the search results. It requires `cohere` to be installed and `COHERE_API_KEY` to be set in the environment. To get your API key, sign up on [Cohere](https://cohere.ai/).
```python
from lancedb.rerankers import CohereReranker
# use Cohere reranker v3
reranker = CohereReranker(model_name="rerank-english-v3.0") # default model is "rerank-english-v2.0"
```
### Reranking search results
Now we can rerank all query type results using the `CohereReranker`:
```python
# rerank hybrid search results
table.search(quries[0], query_type="hybrid").rerank(reranker=reranker).limit(5).to_pandas()
# rerank vector search results
table.search(quries[0], query_type="vector").rerank(reranker=reranker).limit(5).to_pandas()
# rerank fts search results
table.search(quries[0], query_type="fts").rerank(reranker=reranker).limit(5).to_pandas()
```
Each reranker can accept additional arguments. For example, `CohereReranker` accepts `top_k` and `batch_size` params to control the number of documents to rerank and the batch size for reranking respectively. Similarly, a custom reranker can accept any number of arguments based on the implementation. For example, a reranker can accept a `filter` that implements some custom logic to filter out documents before reranking.
## Results
Let us take a look at the same datasets from the previous sections, using the same embedding table but with Cohere reranker applied to all query types.
!!! note "Note"
When reranking fts or vector search results, the search results are over-fetched by a factor of 2 and then reranked. From the reranked set, `top_k` (5 in this case) results are taken. This is done because reranking will have no effect on the hit-rate if we only fetch the `top_k` results.
### Synthetic LLama2 paper dataset
| Query Type | Hit-rate@5 |
| --- | --- |
| Vector | 0.640 |
| FTS | 0.595 |
| Reranked vector | 0.677 |
| Reranked fts | 0.672 |
| Hybrid | 0.759 |
### SQuAD Dataset
### Uber10K sec filing Dataset
| Query Type | Hit-rate@5 |
| --- | --- |
| Vector | 0.608 |
| FTS | 0.824 |
| Reranked vector | 0.671 |
| Reranked fts | 0.843 |
| Hybrid | 0.849 |

View File

@@ -5,7 +5,9 @@ Hybrid Search is a broad (often misused) term. It can mean anything from combini
## The challenge of (re)ranking search results ## The challenge of (re)ranking search results
Once you have a group of the most relevant search results from multiple search sources, you'd likely standardize the score and rank them accordingly. This process can also be seen as another independent step-reranking. Once you have a group of the most relevant search results from multiple search sources, you'd likely standardize the score and rank them accordingly. This process can also be seen as another independent step-reranking.
There are two approaches for reranking search results from multiple sources. There are two approaches for reranking search results from multiple sources.
* <b>Score-based</b>: Calculate final relevance scores based on a weighted linear combination of individual search algorithm scores. Example-Weighted linear combination of semantic search & keyword-based search results. * <b>Score-based</b>: Calculate final relevance scores based on a weighted linear combination of individual search algorithm scores. Example-Weighted linear combination of semantic search & keyword-based search results.
* <b>Relevance-based</b>: Discards the existing scores and calculates the relevance of each search result-query pair. Example-Cross Encoder models * <b>Relevance-based</b>: Discards the existing scores and calculates the relevance of each search result-query pair. Example-Cross Encoder models
Even though there are many strategies for reranking search results, none works for all cases. Moreover, evaluating them itself is a challenge. Also, reranking can be dataset, application specific so it's hard to generalize. Even though there are many strategies for reranking search results, none works for all cases. Moreover, evaluating them itself is a challenge. Also, reranking can be dataset, application specific so it's hard to generalize.

View File

@@ -0,0 +1,139 @@
# Llama-Index
![Illustration](../assets/llama-index.jpg)
## Quick start
You would need to install the integration via `pip install llama-index-vector-stores-lancedb` in order to use it. You can run the below script to try it out :
```python
import logging
import sys
# Uncomment to see debug logs
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
from llama_index.core import SimpleDirectoryReader, Document, StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
import textwrap
import openai
openai.api_key = "sk-..."
documents = SimpleDirectoryReader("./data/your-data-dir/").load_data()
print("Document ID:", documents[0].doc_id, "Document Hash:", documents[0].hash)
## For LanceDB cloud :
# vector_store = LanceDBVectorStore(
# uri="db://db_name", # your remote DB URI
# api_key="sk_..", # lancedb cloud api key
# region="your-region" # the region you configured
# ...
# )
vector_store = LanceDBVectorStore(
uri="./lancedb", mode="overwrite", query_type="vector"
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents, storage_context=storage_context
)
lance_filter = "metadata.file_name = 'paul_graham_essay.txt' "
retriever = index.as_retriever(vector_store_kwargs={"where": lance_filter})
response = retriever.retrieve("What did the author do growing up?")
```
### Filtering
For metadata filtering, you can use a Lance SQL-like string filter as demonstrated in the example above. Additionally, you can also filter using the `MetadataFilters` class from LlamaIndex:
```python
from llama_index.core.vector_stores import (
MetadataFilters,
FilterOperator,
FilterCondition,
MetadataFilter,
)
query_filters = MetadataFilters(
filters=[
MetadataFilter(
key="creation_date", operator=FilterOperator.EQ, value="2024-05-23"
),
MetadataFilter(
key="file_size", value=75040, operator=FilterOperator.GT
),
],
condition=FilterCondition.AND,
)
```
### Hybrid Search
For complete documentation, refer [here](https://lancedb.github.io/lancedb/hybrid_search/hybrid_search/). This example uses the `colbert` reranker. Make sure to install necessary dependencies for the reranker you choose.
```python
from lancedb.rerankers import ColbertReranker
reranker = ColbertReranker()
vector_store._add_reranker(reranker)
query_engine = index.as_query_engine(
filters=query_filters,
vector_store_kwargs={
"query_type": "hybrid",
}
)
response = query_engine.query("How much did Viaweb charge per month?")
```
In the above snippet, you can change/specify query_type again when creating the engine/retriever.
## API reference
The exhaustive list of parameters for `LanceDBVectorStore` vector store are :
- `connection`: Optional, `lancedb.db.LanceDBConnection` connection object to use. If not provided, a new connection will be created.
- `uri`: Optional[str], the uri of your database. Defaults to `"/tmp/lancedb"`.
- `table_name` : Optional[str], Name of your table in the database. Defaults to `"vectors"`.
- `table`: Optional[Any], `lancedb.db.LanceTable` object to be passed. Defaults to `None`.
- `vector_column_name`: Optional[Any], Column name to use for vector's in the table. Defaults to `'vector'`.
- `doc_id_key`: Optional[str], Column name to use for document id's in the table. Defaults to `'doc_id'`.
- `text_key`: Optional[str], Column name to use for text in the table. Defaults to `'text'`.
- `api_key`: Optional[str], API key to use for LanceDB cloud database. Defaults to `None`.
- `region`: Optional[str], Region to use for LanceDB cloud database. Only for LanceDB Cloud, defaults to `None`.
- `nprobes` : Optional[int], Set the number of probes to use. Only applicable if ANN index is created on the table else its ignored. Defaults to `20`.
- `refine_factor` : Optional[int], Refine the results by reading extra elements and re-ranking them in memory. Defaults to `None`.
- `reranker`: Optional[Any], The reranker to use for LanceDB.
Defaults to `None`.
- `overfetch_factor`: Optional[int], The factor by which to fetch more results.
Defaults to `1`.
- `mode`: Optional[str], The mode to use for LanceDB.
Defaults to `"overwrite"`.
- `query_type`:Optional[str], The type of query to use for LanceDB.
Defaults to `"vector"`.
### Methods
- __from_table(cls, table: lancedb.db.LanceTable) -> `LanceDBVectorStore`__ : (class method) Creates instance from lancedb table.
- **_add_reranker(self, reranker: lancedb.rerankers.Reranker) -> `None`** : Add a reranker to an existing vector store.
- Usage :
```python
from lancedb.rerankers import ColbertReranker
reranker = ColbertReranker()
vector_store._add_reranker(reranker)
```
- **_table_exists(self, tbl_name: `Optional[str]` = `None`) -> `bool`** : Returns `True` if `tbl_name` exists in database.
- __create_index(
self, scalar: `Optional[bool]` = False, col_name: `Optional[str]` = None, num_partitions: `Optional[int]` = 256, num_sub_vectors: `Optional[int]` = 96, index_cache_size: `Optional[int]` = None, metric: `Optional[str]` = "L2",
) -> `None`__ : Creates a scalar(for non-vector cols) or a vector index on a table.
Make sure your vector column has enough data before creating an index on it.
- __add(self, nodes: `List[BaseNode]`, **add_kwargs: `Any`, ) -> `List[str]`__ :
adds Nodes to the table
- **delete(self, ref_doc_id: `str`) -> `None`**: Delete nodes using with node_ids.
- **delete_nodes(self, node_ids: `List[str]`) -> `None`** : Delete nodes using with node_ids.
- __query(
self,
query: `VectorStoreQuery`,
**kwargs: `Any`,
) -> `VectorStoreQueryResult`__:
Query index(`VectorStoreIndex`) for top k most similar nodes. Accepts llamaIndex `VectorStoreQuery` object.

View File

@@ -7,8 +7,7 @@ excluded_globs = [
"../src/fts.md", "../src/fts.md",
"../src/embedding.md", "../src/embedding.md",
"../src/examples/*.md", "../src/examples/*.md",
"../src/integrations/voxel51.md", "../src/integrations/*.md",
"../src/integrations/langchain.md",
"../src/guides/tables.md", "../src/guides/tables.md",
"../src/python/duckdb.md", "../src/python/duckdb.md",
"../src/embeddings/*.md", "../src/embeddings/*.md",
@@ -17,6 +16,7 @@ excluded_globs = [
"../src/basic.md", "../src/basic.md",
"../src/hybrid_search/hybrid_search.md", "../src/hybrid_search/hybrid_search.md",
"../src/reranking/*.md", "../src/reranking/*.md",
"../src/guides/tuning_retrievers/*.md",
] ]
python_prefix = "py" python_prefix = "py"

View File

@@ -175,8 +175,8 @@ impl JNIEnvExt for JNIEnv<'_> {
if obj.is_null() { if obj.is_null() {
return Ok(None); return Ok(None);
} }
let is_empty = self.call_method(obj, "isEmpty", "()Z", &[])?; let is_present = self.call_method(obj, "isPresent", "()Z", &[])?;
if is_empty.z()? { if !is_present.z()? {
// TODO(lu): put get java object into here cuz can only get java Object // TODO(lu): put get java object into here cuz can only get java Object
Ok(None) Ok(None)
} else { } else {

View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.5.0", "version": "0.5.2",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.5.0", "version": "0.5.2",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"

View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.5.0", "version": "0.5.2-final.1",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js", "main": "dist/index.js",
"types": "dist/index.d.ts", "types": "dist/index.d.ts",
"scripts": { "scripts": {
"tsc": "tsc -b", "tsc": "tsc -b",
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb-node index.node -- cargo build --message-format=json", "build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb_node index.node -- cargo build --message-format=json",
"build-release": "npm run build -- --release", "build-release": "npm run build -- --release",
"test": "npm run tsc && mocha -recursive dist/test", "test": "npm run tsc && mocha -recursive dist/test",
"integration-test": "npm run tsc && mocha -recursive dist/integration_test", "integration-test": "npm run tsc && mocha -recursive dist/integration_test",

View File

@@ -695,15 +695,26 @@ export interface MergeInsertArgs {
whenNotMatchedBySourceDelete?: string | boolean whenNotMatchedBySourceDelete?: string | boolean
} }
export enum IndexStatus {
Pending = "pending",
Indexing = "indexing",
Done = "done",
Failed = "failed"
}
export interface VectorIndex { export interface VectorIndex {
columns: string[] columns: string[]
name: string name: string
uuid: string uuid: string
status: IndexStatus
} }
export interface IndexStats { export interface IndexStats {
numIndexedRows: number | null numIndexedRows: number | null
numUnindexedRows: number | null numUnindexedRows: number | null
indexType: string | null
distanceType: string | null
completedAt: string | null
} }
/** /**

View File

@@ -509,7 +509,8 @@ export class RemoteTable<T = number[]> implements Table<T> {
return (await results.body()).indexes?.map((index: any) => ({ return (await results.body()).indexes?.map((index: any) => ({
columns: index.columns, columns: index.columns,
name: index.index_name, name: index.index_name,
uuid: index.index_uuid uuid: index.index_uuid,
status: index.status
})) }))
} }
@@ -520,7 +521,10 @@ export class RemoteTable<T = number[]> implements Table<T> {
const body = await results.body() const body = await results.body()
return { return {
numIndexedRows: body?.num_indexed_rows, numIndexedRows: body?.num_indexed_rows,
numUnindexedRows: body?.num_unindexed_rows numUnindexedRows: body?.num_unindexed_rows,
indexType: body?.index_type,
distanceType: body?.distance_type,
completedAt: body?.completed_at
} }
} }

View File

@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import { Field, Float64, Schema } from "apache-arrow";
import * as tmp from "tmp"; import * as tmp from "tmp";
import { Connection, connect } from "../lancedb"; import { Connection, Table, connect } from "../lancedb";
describe("when connecting", () => { describe("when connecting", () => {
let tmpDir: tmp.DirResult; let tmpDir: tmp.DirResult;
@@ -56,6 +57,18 @@ describe("given a connection", () => {
expect(db.isOpen()).toBe(false); expect(db.isOpen()).toBe(false);
await expect(db.tableNames()).rejects.toThrow("Connection is closed"); await expect(db.tableNames()).rejects.toThrow("Connection is closed");
}); });
it("should be able to create a table from an object arg `createTable(options)`, or args `createTable(name, data, options)`", async () => {
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
await expect(tbl.countRows()).resolves.toBe(2);
tbl = await db.createTable({
name: "test",
data: [{ id: 3 }],
mode: "overwrite",
});
await expect(tbl.countRows()).resolves.toBe(1);
});
it("should fail if creating table twice, unless overwrite is true", async () => { it("should fail if creating table twice, unless overwrite is true", async () => {
let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]); let tbl = await db.createTable("test", [{ id: 1 }, { id: 2 }]);
@@ -86,4 +99,39 @@ describe("given a connection", () => {
tables = await db.tableNames({ startAfter: "a" }); tables = await db.tableNames({ startAfter: "a" });
expect(tables).toEqual(["b", "c"]); expect(tables).toEqual(["b", "c"]);
}); });
it("should create tables in v2 mode", async () => {
const db = await connect(tmpDir.name);
const data = [...Array(10000).keys()].map((i) => ({ id: i }));
// Create in v1 mode
let table = await db.createTable("test", data);
const isV2 = async (table: Table) => {
const data = await table.query().toArrow({ maxBatchLength: 100000 });
console.log(data.batches.length);
return data.batches.length < 5;
};
await expect(isV2(table)).resolves.toBe(false);
// Create in v2 mode
table = await db.createTable("test_v2", data, { useLegacyFormat: false });
await expect(isV2(table)).resolves.toBe(true);
await table.add(data);
await expect(isV2(table)).resolves.toBe(true);
// Create empty in v2 mode
const schema = new Schema([new Field("id", new Float64(), true)]);
table = await db.createEmptyTable("test_v2_empty", schema, {
useLegacyFormat: false,
});
await table.add(data);
await expect(isV2(table)).resolves.toBe(true);
});
}); });

View File

@@ -0,0 +1,314 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import * as tmp from "tmp";
import { connect } from "../lancedb";
import {
Field,
FixedSizeList,
Float,
Float16,
Float32,
Float64,
Schema,
Utf8,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => {
tmpDir.removeCallback();
getRegistry().reset();
});
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
schema,
},
);
getRegistry().reset();
const db2 = await connect(tmpDir.name);
const tbl = await db2.openTable("test");
expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
`Function "mock" not found in registry`,
);
});
test.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide manual embeddings with multiple float datatype",
async (floatType) => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const data = [{ text: "hello" }, { text: "hello world" }];
const schema = new Schema([
new Field("vector", new FixedSizeList(3, new Field("item", floatType))),
new Field("text", new Utf8()),
]);
const func = new MockEmbeddingFunction();
const name = "test";
const db = await connect(tmpDir.name);
const table = await db.createTable(name, data, {
schema,
embeddingFunction: {
sourceColumn: "text",
function: func,
},
});
const res = await table.query().toArray();
expect([...res[0].vector]).toEqual([1, 2, 3]);
},
);
test.each([new Float16(), new Float32(), new Float64()])(
"should be able to provide auto embeddings with multiple float datatypes",
async (floatType) => {
@register("test1")
class MockEmbeddingFunctionWithoutNDims extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
@register("test")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return floatType;
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("test")!.create();
const func2 = getRegistry()
.get<MockEmbeddingFunctionWithoutNDims>("test1")!
.create();
const schema = LanceSchema({
text: func.sourceField(new Utf8()),
vector: func.vectorField(floatType),
});
const schema2 = LanceSchema({
text: func2.sourceField(new Utf8()),
vector: func2.vectorField({ datatype: floatType, dims: 3 }),
});
const schema3 = LanceSchema({
text: func2.sourceField(new Utf8()),
vector: func.vectorField({
datatype: new FixedSizeList(3, new Field("item", floatType, true)),
dims: 3,
}),
});
const expectedSchema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", floatType, true)),
true,
),
]);
const stringSchema = JSON.stringify(schema, null, 2);
const stringSchema2 = JSON.stringify(schema2, null, 2);
const stringSchema3 = JSON.stringify(schema3, null, 2);
const stringExpectedSchema = JSON.stringify(expectedSchema, null, 2);
expect(stringSchema).toEqual(stringExpectedSchema);
expect(stringSchema2).toEqual(stringExpectedSchema);
expect(stringSchema3).toEqual(stringExpectedSchema);
},
);
});

View File

@@ -21,19 +21,17 @@ import * as arrowOld from "apache-arrow-old";
import { Table, connect } from "../lancedb"; import { Table, connect } from "../lancedb";
import { import {
Table as ArrowTable,
Field, Field,
FixedSizeList, FixedSizeList,
Float,
Float32, Float32,
Float64, Float64,
Int32, Int32,
Int64, Int64,
Schema, Schema,
Utf8,
makeArrowTable, makeArrowTable,
} from "../lancedb/arrow"; } from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding"; import { EmbeddingFunction, LanceSchema, register } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
import { Index } from "../lancedb/indices"; import { Index } from "../lancedb/indices";
// biome-ignore lint/suspicious/noExplicitAny: <explanation> // biome-ignore lint/suspicious/noExplicitAny: <explanation>
@@ -44,6 +42,7 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
const schema = new arrow.Schema([ const schema = new arrow.Schema([
new arrow.Field("id", new arrow.Float64(), true), new arrow.Field("id", new arrow.Float64(), true),
]); ]);
beforeEach(async () => { beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true }); tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name); const conn = await connect(tmpDir.name);
@@ -94,6 +93,177 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
expect(await table.countRows("id == 7")).toBe(1); expect(await table.countRows("id == 7")).toBe(1);
expect(await table.countRows("id == 10")).toBe(1); expect(await table.countRows("id == 10")).toBe(1);
}); });
// https://github.com/lancedb/lancedb/issues/1293
test.each([new arrow.Float16(), new arrow.Float32(), new arrow.Float64()])(
"can create empty table with non default float type: %s",
async (floatType) => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello", vector: Array(512).fill(1.0) },
{ text: "hello world", vector: Array(512).fill(1.0) },
];
const f64Schema = new arrow.Schema([
new arrow.Field("text", new arrow.Utf8(), true),
new arrow.Field(
"vector",
new arrow.FixedSizeList(512, new arrow.Field("item", floatType)),
true,
),
]);
const f64Table = await db.createEmptyTable("f64", f64Schema, {
mode: "overwrite",
});
try {
await f64Table.add(data);
const res = await f64Table.query().toArray();
expect(res.length).toBe(2);
} catch (e) {
expect(e).toBeUndefined();
}
},
);
it("should return the table as an instance of an arrow table", async () => {
const arrowTbl = await table.toArrow();
expect(arrowTbl).toBeInstanceOf(ArrowTable);
});
});
describe("merge insert", () => {
let tmpDir: tmp.DirResult;
let table: Table;
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name);
table = await conn.createTable("some_table", [
{ a: 1, b: "a" },
{ a: 2, b: "b" },
{ a: 3, b: "c" },
]);
});
afterEach(() => tmpDir.removeCallback());
test("upsert", async () => {
const newData = [
{ a: 2, b: "x" },
{ a: 3, b: "y" },
{ a: 4, b: "z" },
];
await table
.mergeInsert("a")
.whenMatchedUpdateAll()
.whenNotMatchedInsertAll()
.execute(newData);
const expected = [
{ a: 1, b: "a" },
{ a: 2, b: "x" },
{ a: 3, b: "y" },
{ a: 4, b: "z" },
];
expect(
JSON.parse(JSON.stringify((await table.toArrow()).toArray())),
).toEqual(expected);
});
test("conditional update", async () => {
const newData = [
{ a: 2, b: "x" },
{ a: 3, b: "y" },
{ a: 4, b: "z" },
];
await table
.mergeInsert("a")
.whenMatchedUpdateAll({ where: "target.b = 'b'" })
.execute(newData);
const expected = [
{ a: 1, b: "a" },
{ a: 2, b: "x" },
{ a: 3, b: "c" },
];
// round trip to arrow and back to json to avoid comparing arrow objects to js object
// biome-ignore lint/suspicious/noExplicitAny: test
let res: any[] = JSON.parse(
JSON.stringify((await table.toArrow()).toArray()),
);
res = res.sort((a, b) => a.a - b.a);
expect(res).toEqual(expected);
});
test("insert if not exists", async () => {
const newData = [
{ a: 2, b: "x" },
{ a: 3, b: "y" },
{ a: 4, b: "z" },
];
await table.mergeInsert("a").whenNotMatchedInsertAll().execute(newData);
const expected = [
{ a: 1, b: "a" },
{ a: 2, b: "b" },
{ a: 3, b: "c" },
{ a: 4, b: "z" },
];
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
let res: any[] = JSON.parse(
JSON.stringify((await table.toArrow()).toArray()),
);
res = res.sort((a, b) => a.a - b.a);
expect(res).toEqual(expected);
});
test("replace range", async () => {
const newData = [
{ a: 2, b: "x" },
{ a: 4, b: "z" },
];
await table
.mergeInsert("a")
.whenMatchedUpdateAll()
.whenNotMatchedInsertAll()
.whenNotMatchedBySourceDelete({ where: "a > 2" })
.execute(newData);
const expected = [
{ a: 1, b: "a" },
{ a: 2, b: "x" },
{ a: 4, b: "z" },
];
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
let res: any[] = JSON.parse(
JSON.stringify((await table.toArrow()).toArray()),
);
res = res.sort((a, b) => a.a - b.a);
expect(res).toEqual(expected);
});
test("replace range no condition", async () => {
const newData = [
{ a: 2, b: "x" },
{ a: 4, b: "z" },
];
await table
.mergeInsert("a")
.whenMatchedUpdateAll()
.whenNotMatchedInsertAll()
.whenNotMatchedBySourceDelete()
.execute(newData);
const expected = [
{ a: 2, b: "x" },
{ a: 4, b: "z" },
];
// biome-ignore lint/suspicious/noExplicitAny: test
let res: any[] = JSON.parse(
JSON.stringify((await table.toArrow()).toArray()),
);
res = res.sort((a, b) => a.a - b.a);
expect(res).toEqual(expected);
});
}); });
describe("When creating an index", () => { describe("When creating an index", () => {
@@ -135,6 +305,7 @@ describe("When creating an index", () => {
const indices = await tbl.listIndices(); const indices = await tbl.listIndices();
expect(indices.length).toBe(1); expect(indices.length).toBe(1);
expect(indices[0]).toEqual({ expect(indices[0]).toEqual({
name: "vec_idx",
indexType: "IvfPq", indexType: "IvfPq",
columns: ["vec"], columns: ["vec"],
}); });
@@ -191,6 +362,24 @@ describe("When creating an index", () => {
for await (const r of tbl.query().where("id > 1").select(["id"])) { for await (const r of tbl.query().where("id > 1").select(["id"])) {
expect(r.numRows).toBe(298); expect(r.numRows).toBe(298);
} }
// should also work with 'filter' alias
for await (const r of tbl.query().filter("id > 1").select(["id"])) {
expect(r.numRows).toBe(298);
}
});
test("should be able to get index stats", async () => {
await tbl.createIndex("id");
const stats = await tbl.indexStats("id_idx");
expect(stats).toBeDefined();
expect(stats?.numIndexedRows).toEqual(300);
expect(stats?.numUnindexedRows).toEqual(0);
});
test("when getting stats on non-existent index", async () => {
const stats = await tbl.indexStats("some non-existent index");
expect(stats).toBeUndefined();
}); });
// TODO: Move this test to the query API test (making sure we can reject queries // TODO: Move this test to the query API test (making sure we can reject queries
@@ -431,161 +620,6 @@ describe("when dealing with versioning", () => {
}); });
}); });
describe("embedding functions", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
it("should be able to create a table with an embedding function", async () => {
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
embeddingFunction: {
function: func,
sourceColumn: "text",
},
},
);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should be able to create an empty table with an embedding function", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const schema = new Schema([
new Field("text", new Utf8(), true),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32(), true)),
true,
),
]);
const func = new MockEmbeddingFunction();
const db = await connect(tmpDir.name);
const table = await db.createEmptyTable("test", schema, {
embeddingFunction: {
function: func,
sourceColumn: "text",
},
});
const outSchema = await table.schema();
expect(outSchema.metadata.get("embedding_functions")).toBeDefined();
await table.add([{ text: "hello world" }]);
// biome-ignore lint/suspicious/noExplicitAny: test
const arr = (await table.query().toArray()) as any;
expect(arr[0].vector).toBeDefined();
// we round trip through JSON to make sure the vector properly gets converted to an array
// otherwise it'll be a TypedArray or Vector
const vector0 = JSON.parse(JSON.stringify(arr[0].vector));
expect(vector0).toEqual([1, 2, 3]);
});
it("should error when appending to a table with an unregistered embedding function", async () => {
@register("mock")
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
}
async computeQueryEmbeddings(_data: string) {
return [1, 2, 3];
}
async computeSourceEmbeddings(data: string[]) {
return Array.from({ length: data.length }).fill([
1, 2, 3,
]) as number[][];
}
}
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new arrow.Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
await db.createTable(
"test",
[
{ id: 1, text: "hello" },
{ id: 2, text: "world" },
],
{
schema,
},
);
getRegistry().reset();
const db2 = await connect(tmpDir.name);
const tbl = await db2.openTable("test");
expect(tbl.add([{ id: 3, text: "hello" }])).rejects.toThrow(
`Function "mock" not found in registry`,
);
});
});
describe("when optimizing a dataset", () => { describe("when optimizing a dataset", () => {
let tmpDir: tmp.DirResult; let tmpDir: tmp.DirResult;
let table: Table; let table: Table;
@@ -613,3 +647,99 @@ describe("when optimizing a dataset", () => {
expect(stats.prune.oldVersionsRemoved).toBe(3); expect(stats.prune.oldVersionsRemoved).toBe(3);
}); });
}); });
describe("table.search", () => {
let tmpDir: tmp.DirResult;
beforeEach(() => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
});
afterEach(() => tmpDir.removeCallback());
test("can search using a string", async () => {
@register()
class MockEmbeddingFunction extends EmbeddingFunction<string> {
toJSON(): object {
return {};
}
ndims() {
return 1;
}
embeddingDataType(): arrow.Float {
return new Float32();
}
// Hardcoded embeddings for the sake of testing
async computeQueryEmbeddings(_data: string) {
switch (_data) {
case "greetings":
return [0.1];
case "farewell":
return [0.2];
default:
return null as never;
}
}
// Hardcoded embeddings for the sake of testing
async computeSourceEmbeddings(data: string[]) {
return data.map((s) => {
switch (s) {
case "hello world":
return [0.1];
case "goodbye world":
return [0.2];
default:
return null as never;
}
});
}
}
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const db = await connect(tmpDir.name);
const data = [{ text: "hello world" }, { text: "goodbye world" }];
const table = await db.createTable("test", data, { schema });
const results = await table.search("greetings").then((r) => r.toArray());
expect(results[0].text).toBe(data[0].text);
const results2 = await table.search("farewell").then((r) => r.toArray());
expect(results2[0].text).toBe(data[1].text);
});
test("rejects if no embedding function provided", async () => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
expect(table.search("hello")).rejects.toThrow(
"No embedding functions are defined in the table",
);
});
test.each([
[0.4, 0.5, 0.599], // number[]
Float32Array.of(0.4, 0.5, 0.599), // Float32Array
Float64Array.of(0.4, 0.5, 0.599), // Float64Array
])("can search using vectorlike datatypes", async (vectorlike) => {
const db = await connect(tmpDir.name);
const data = [
{ text: "hello world", vector: [0.1, 0.2, 0.3] },
{ text: "goodbye world", vector: [0.4, 0.5, 0.6] },
];
const table = await db.createTable("test", data);
// biome-ignore lint/suspicious/noExplicitAny: test
const results: any[] = await table.search(vectorlike).toArray();
expect(results.length).toBe(2);
expect(results[0].text).toBe(data[1].text);
});
});

View File

@@ -77,7 +77,7 @@
"noDuplicateObjectKeys": "error", "noDuplicateObjectKeys": "error",
"noDuplicateParameters": "error", "noDuplicateParameters": "error",
"noEmptyBlockStatements": "error", "noEmptyBlockStatements": "error",
"noExplicitAny": "error", "noExplicitAny": "warn",
"noExtraNonNullAssertion": "error", "noExtraNonNullAssertion": "error",
"noFallthroughSwitchClause": "error", "noFallthroughSwitchClause": "error",
"noFunctionAssign": "error", "noFunctionAssign": "error",

View File

@@ -31,7 +31,7 @@ import {
Schema, Schema,
Struct, Struct,
Utf8, Utf8,
type Vector, Vector,
makeBuilder, makeBuilder,
makeData, makeData,
type makeTable, type makeTable,
@@ -42,6 +42,8 @@ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize"; import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
export * from "apache-arrow"; export * from "apache-arrow";
export type IntoVector = Float32Array | Float64Array | number[];
export function isArrowTable(value: object): value is ArrowTable { export function isArrowTable(value: object): value is ArrowTable {
if (value instanceof ArrowTable) return true; if (value instanceof ArrowTable) return true;
return "schema" in value && "batches" in value; return "schema" in value && "batches" in value;
@@ -182,6 +184,7 @@ export class MakeArrowTableOptions {
vector: new VectorColumnOptions(), vector: new VectorColumnOptions(),
}; };
embeddings?: EmbeddingFunction<unknown>; embeddings?: EmbeddingFunction<unknown>;
embeddingFunction?: EmbeddingFunctionConfig;
/** /**
* If true then string columns will be encoded with dictionary encoding * If true then string columns will be encoded with dictionary encoding
@@ -306,7 +309,11 @@ export function makeArrowTable(
const opt = new MakeArrowTableOptions(options !== undefined ? options : {}); const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
if (opt.schema !== undefined && opt.schema !== null) { if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema); opt.schema = sanitizeSchema(opt.schema);
opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings); opt.schema = validateSchemaEmbeddings(
opt.schema,
data,
options?.embeddingFunction,
);
} }
const columns: Record<string, Vector> = {}; const columns: Record<string, Vector> = {};
// TODO: sample dataset to find missing columns // TODO: sample dataset to find missing columns
@@ -545,7 +552,6 @@ async function applyEmbeddingsFromMetadata(
dtype, dtype,
); );
} }
const vector = makeVector(vectors, destType); const vector = makeVector(vectors, destType);
columns[destColumn] = vector; columns[destColumn] = vector;
} }
@@ -835,7 +841,7 @@ export function createEmptyTable(schema: Schema): ArrowTable {
function validateSchemaEmbeddings( function validateSchemaEmbeddings(
schema: Schema, schema: Schema,
data: Array<Record<string, unknown>>, data: Array<Record<string, unknown>>,
embeddings: EmbeddingFunction<unknown> | undefined, embeddings: EmbeddingFunctionConfig | undefined,
) { ) {
const fields = []; const fields = [];
const missingEmbeddingFields = []; const missingEmbeddingFields = [];

View File

@@ -12,38 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import { Table as ArrowTable, Schema } from "./arrow"; import { Table as ArrowTable, Data, Schema } from "./arrow";
import { import { fromTableToBuffer, makeEmptyTable } from "./arrow";
fromTableToBuffer,
isArrowTable,
makeArrowTable,
makeEmptyTable,
} from "./arrow";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry"; import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { ConnectionOptions, Connection as LanceDbConnection } from "./native"; import { Connection as LanceDbConnection } from "./native";
import { Table } from "./table"; import { LocalTable, Table } from "./table";
/**
* Connect to a LanceDB instance at the given URI.
*
* Accepted formats:
*
* - `/path/to/database` - local database
* - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
* - `db://host:port` - remote database (LanceDB cloud)
* @param {string} uri - The uri of the database. If the database uri starts
* with `db://` then it connects to a remote database.
* @see {@link ConnectionOptions} for more details on the URI format.
*/
export async function connect(
uri: string,
opts?: Partial<ConnectionOptions>,
): Promise<Connection> {
opts = opts ?? {};
opts.storageOptions = cleanseStorageOptions(opts.storageOptions);
const nativeConn = await LanceDbConnection.new(uri, opts);
return new Connection(nativeConn);
}
export interface CreateTableOptions { export interface CreateTableOptions {
/** /**
@@ -71,6 +44,12 @@ export interface CreateTableOptions {
* The available options are described at https://lancedb.github.io/lancedb/guides/storage/ * The available options are described at https://lancedb.github.io/lancedb/guides/storage/
*/ */
storageOptions?: Record<string, string>; storageOptions?: Record<string, string>;
/**
* If true then data files will be written with the legacy format
*
* The default is true while the new format is in beta
*/
useLegacyFormat?: boolean;
schema?: Schema; schema?: Schema;
embeddingFunction?: EmbeddingFunctionConfig; embeddingFunction?: EmbeddingFunctionConfig;
} }
@@ -111,7 +90,6 @@ export interface TableNamesOptions {
/** An optional limit to the number of results to return. */ /** An optional limit to the number of results to return. */
limit?: number; limit?: number;
} }
/** /**
* A LanceDB Connection that allows you to open tables and create new ones. * A LanceDB Connection that allows you to open tables and create new ones.
* *
@@ -130,17 +108,15 @@ export interface TableNamesOptions {
* Any created tables are independent and will continue to work even if * Any created tables are independent and will continue to work even if
* the underlying connection has been closed. * the underlying connection has been closed.
*/ */
export class Connection { export abstract class Connection {
readonly inner: LanceDbConnection; [Symbol.for("nodejs.util.inspect.custom")](): string {
return this.display();
constructor(inner: LanceDbConnection) {
this.inner = inner;
} }
/** Return true if the connection has not been closed */ /**
isOpen(): boolean { * Return true if the connection has not been closed
return this.inner.isOpen(); */
} abstract isOpen(): boolean;
/** /**
* Close the connection, releasing any underlying resources. * Close the connection, releasing any underlying resources.
@@ -149,14 +125,12 @@ export class Connection {
* *
* Any attempt to use the connection after it is closed will result in an error. * Any attempt to use the connection after it is closed will result in an error.
*/ */
close(): void { abstract close(): void;
this.inner.close();
}
/** Return a brief description of the connection */ /**
display(): string { * Return a brief description of the connection
return this.inner.display(); */
} abstract display(): string;
/** /**
* List all the table names in this database. * List all the table names in this database.
@@ -164,15 +138,86 @@ export class Connection {
* Tables will be returned in lexicographical order. * Tables will be returned in lexicographical order.
* @param {Partial<TableNamesOptions>} options - options to control the * @param {Partial<TableNamesOptions>} options - options to control the
* paging / start point * paging / start point
*
*/ */
async tableNames(options?: Partial<TableNamesOptions>): Promise<string[]> { abstract tableNames(options?: Partial<TableNamesOptions>): Promise<string[]>;
return this.inner.tableNames(options?.startAfter, options?.limit);
}
/** /**
* Open a table in the database. * Open a table in the database.
* @param {string} name - The name of the table * @param {string} name - The name of the table
*/ */
abstract openTable(
name: string,
options?: Partial<OpenTableOptions>,
): Promise<Table>;
/**
* Creates a new Table and initialize it with new data.
* @param {object} options - The options object.
* @param {string} options.name - The name of the table.
* @param {Data} options.data - Non-empty Array of Records to be inserted into the table
*
*/
abstract createTable(
options: {
name: string;
data: Data;
} & Partial<CreateTableOptions>,
): Promise<Table>;
/**
* Creates a new Table and initialize it with new data.
* @param {string} name - The name of the table.
* @param {Record<string, unknown>[] | ArrowTable} data - Non-empty Array of Records
* to be inserted into the table
*/
abstract createTable(
name: string,
data: Record<string, unknown>[] | ArrowTable,
options?: Partial<CreateTableOptions>,
): Promise<Table>;
/**
* Creates a new empty Table
* @param {string} name - The name of the table.
* @param {Schema} schema - The schema of the table
*/
abstract createEmptyTable(
name: string,
schema: Schema,
options?: Partial<CreateTableOptions>,
): Promise<Table>;
/**
* Drop an existing table.
* @param {string} name The name of the table to drop.
*/
abstract dropTable(name: string): Promise<void>;
}
export class LocalConnection extends Connection {
readonly inner: LanceDbConnection;
constructor(inner: LanceDbConnection) {
super();
this.inner = inner;
}
isOpen(): boolean {
return this.inner.isOpen();
}
close(): void {
this.inner.close();
}
display(): string {
return this.inner.display();
}
async tableNames(options?: Partial<TableNamesOptions>): Promise<string[]> {
return this.inner.tableNames(options?.startAfter, options?.limit);
}
async openTable( async openTable(
name: string, name: string,
options?: Partial<OpenTableOptions>, options?: Partial<OpenTableOptions>,
@@ -183,54 +228,35 @@ export class Connection {
options?.indexCacheSize, options?.indexCacheSize,
); );
return new Table(innerTable); return new LocalTable(innerTable);
} }
/**
* Creates a new Table and initialize it with new data.
* @param {string} name - The name of the table.
* @param {Record<string, unknown>[] | ArrowTable} data - Non-empty Array of Records
* to be inserted into the table
*/
async createTable( async createTable(
name: string, nameOrOptions:
data: Record<string, unknown>[] | ArrowTable, | string
| ({ name: string; data: Data } & Partial<CreateTableOptions>),
data?: Record<string, unknown>[] | ArrowTable,
options?: Partial<CreateTableOptions>, options?: Partial<CreateTableOptions>,
): Promise<Table> { ): Promise<Table> {
let mode: string = options?.mode ?? "create"; if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) {
const existOk = options?.existOk ?? false; const { name, data, ...options } = nameOrOptions;
return this.createTable(name, data, options);
if (mode === "create" && existOk) {
mode = "exist_ok";
} }
if (data === undefined) {
let table: ArrowTable; throw new Error("data is required");
if (isArrowTable(data)) {
table = data;
} else {
table = makeArrowTable(data, options);
} }
const { buf, mode } = await Table.parseTableData(data, options);
const buf = await fromTableToBuffer(
table,
options?.embeddingFunction,
options?.schema,
);
const innerTable = await this.inner.createTable( const innerTable = await this.inner.createTable(
name, nameOrOptions,
buf, buf,
mode, mode,
cleanseStorageOptions(options?.storageOptions), cleanseStorageOptions(options?.storageOptions),
options?.useLegacyFormat,
); );
return new Table(innerTable); return new LocalTable(innerTable);
} }
/**
* Creates a new empty Table
* @param {string} name - The name of the table.
* @param {Schema} schema - The schema of the table
*/
async createEmptyTable( async createEmptyTable(
name: string, name: string,
schema: Schema, schema: Schema,
@@ -256,14 +282,11 @@ export class Connection {
buf, buf,
mode, mode,
cleanseStorageOptions(options?.storageOptions), cleanseStorageOptions(options?.storageOptions),
options?.useLegacyFormat,
); );
return new Table(innerTable); return new LocalTable(innerTable);
} }
/**
* Drop an existing table.
* @param {string} name The name of the table to drop.
*/
async dropTable(name: string): Promise<void> { async dropTable(name: string): Promise<void> {
return this.inner.dropTable(name); return this.inner.dropTable(name);
} }
@@ -272,7 +295,7 @@ export class Connection {
/** /**
* Takes storage options and makes all the keys snake case. * Takes storage options and makes all the keys snake case.
*/ */
function cleanseStorageOptions( export function cleanseStorageOptions(
options?: Record<string, string>, options?: Record<string, string>,
): Record<string, string> | undefined { ): Record<string, string> | undefined {
if (options === undefined) { if (options === undefined) {

View File

@@ -19,6 +19,7 @@ import {
FixedSizeList, FixedSizeList,
Float, Float,
Float32, Float32,
type IntoVector,
isDataType, isDataType,
isFixedSizeList, isFixedSizeList,
isFloat, isFloat,
@@ -100,33 +101,55 @@ export abstract class EmbeddingFunction<
* @see {@link lancedb.LanceSchema} * @see {@link lancedb.LanceSchema}
*/ */
vectorField( vectorField(
options?: Partial<FieldOptions>, optionsOrDatatype?: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] { ): [DataType, Map<string, EmbeddingFunction>] {
let dtype: DataType; let dtype: DataType | undefined;
const dims = this.ndims() ?? options?.dims; let vectorType: DataType;
if (!options?.datatype) { let dims: number | undefined = this.ndims();
if (dims === undefined) {
throw new Error("ndims is required for vector field"); // `func.vectorField(new Float32())`
} if (isDataType(optionsOrDatatype)) {
dtype = new FixedSizeList(dims, new Field("item", new Float32(), true)); dtype = optionsOrDatatype;
} else { } else {
if (isFixedSizeList(options.datatype)) { // `func.vectorField({
dtype = options.datatype; // datatype: new Float32(),
} else if (isFloat(options.datatype)) { // dims: 10
// })`
dims = dims ?? optionsOrDatatype?.dims;
dtype = optionsOrDatatype?.datatype;
}
if (dtype !== undefined) {
// `func.vectorField(new FixedSizeList(dims, new Field("item", new Float32(), true)))`
// or `func.vectorField({datatype: new FixedSizeList(dims, new Field("item", new Float32(), true))})`
if (isFixedSizeList(dtype)) {
vectorType = dtype;
// `func.vectorField(new Float32())`
// or `func.vectorField({datatype: new Float32()})`
} else if (isFloat(dtype)) {
// No `ndims` impl and no `{dims: n}` provided;
if (dims === undefined) { if (dims === undefined) {
throw new Error("ndims is required for vector field"); throw new Error("ndims is required for vector field");
} }
dtype = newVectorType(dims, options.datatype); vectorType = newVectorType(dims, dtype);
} else { } else {
throw new Error( throw new Error(
"Expected FixedSizeList or Float as datatype for vector field", "Expected FixedSizeList or Float as datatype for vector field",
); );
} }
} else {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}
vectorType = new FixedSizeList(
dims,
new Field("item", new Float32(), true),
);
} }
const metadata = new Map<string, EmbeddingFunction>(); const metadata = new Map<string, EmbeddingFunction>();
metadata.set("vector_column_for", this); metadata.set("vector_column_for", this);
return [dtype, metadata]; return [vectorType, metadata];
} }
/** The number of dimensions of the embeddings */ /** The number of dimensions of the embeddings */
@@ -147,9 +170,7 @@ export abstract class EmbeddingFunction<
/** /**
Compute the embeddings for a single query Compute the embeddings for a single query
*/ */
async computeQueryEmbeddings( async computeQueryEmbeddings(data: T): Promise<IntoVector> {
data: T,
): Promise<number[] | Float32Array | Float64Array> {
return this.computeSourceEmbeddings([data]).then( return this.computeSourceEmbeddings([data]).then(
(embeddings) => embeddings[0], (embeddings) => embeddings[0],
); );

View File

@@ -42,6 +42,7 @@ export class EmbeddingFunctionRegistry {
* Register an embedding function * Register an embedding function
* @param name The name of the function * @param name The name of the function
* @param func The function to register * @param func The function to register
* @throws Error if the function is already registered
*/ */
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>( register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
this: EmbeddingFunctionRegistry, this: EmbeddingFunctionRegistry,
@@ -89,6 +90,9 @@ export class EmbeddingFunctionRegistry {
this.#functions.clear(); this.#functions.clear();
} }
/**
* @ignore
*/
parseFunctions( parseFunctions(
this: EmbeddingFunctionRegistry, this: EmbeddingFunctionRegistry,
metadata: Map<string, string>, metadata: Map<string, string>,

View File

@@ -12,25 +12,43 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import {
Connection,
LocalConnection,
cleanseStorageOptions,
} from "./connection";
import {
ConnectionOptions,
Connection as LanceDbConnection,
} from "./native.js";
import { RemoteConnection, RemoteConnectionOptions } from "./remote";
export { export {
WriteOptions, WriteOptions,
WriteMode, WriteMode,
AddColumnsSql, AddColumnsSql,
ColumnAlteration, ColumnAlteration,
ConnectionOptions, ConnectionOptions,
IndexStatistics,
IndexMetadata,
IndexConfig,
} from "./native.js"; } from "./native.js";
export { export {
makeArrowTable, makeArrowTable,
MakeArrowTableOptions, MakeArrowTableOptions,
Data, Data,
VectorColumnOptions, VectorColumnOptions,
} from "./arrow"; } from "./arrow";
export { export {
connect,
Connection, Connection,
CreateTableOptions, CreateTableOptions,
TableNamesOptions, TableNamesOptions,
} from "./connection"; } from "./connection";
export { export {
ExecutableQuery, ExecutableQuery,
Query, Query,
@@ -38,6 +56,87 @@ export {
VectorQuery, VectorQuery,
RecordBatchIterator, RecordBatchIterator,
} from "./query"; } from "./query";
export { Index, IndexOptions, IvfPqOptions } from "./indices"; export { Index, IndexOptions, IvfPqOptions } from "./indices";
export { Table, AddDataOptions, IndexConfig, UpdateOptions } from "./table";
export { Table, AddDataOptions, UpdateOptions } from "./table";
export * as embedding from "./embedding"; export * as embedding from "./embedding";
/**
* Connect to a LanceDB instance at the given URI.
*
* Accepted formats:
*
* - `/path/to/database` - local database
* - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
* - `db://host:port` - remote database (LanceDB cloud)
* @param {string} uri - The uri of the database. If the database uri starts
* with `db://` then it connects to a remote database.
* @see {@link ConnectionOptions} for more details on the URI format.
* @example
* ```ts
* const conn = await connect("/path/to/database");
* ```
* @example
* ```ts
* const conn = await connect(
* "s3://bucket/path/to/database",
* {storageOptions: {timeout: "60s"}
* });
* ```
*/
export async function connect(
uri: string,
opts?: Partial<ConnectionOptions | RemoteConnectionOptions>,
): Promise<Connection>;
/**
* Connect to a LanceDB instance at the given URI.
*
* Accepted formats:
*
* - `/path/to/database` - local database
* - `s3://bucket/path/to/database` or `gs://bucket/path/to/database` - database on cloud storage
* - `db://host:port` - remote database (LanceDB cloud)
* @param options - The options to use when connecting to the database
* @see {@link ConnectionOptions} for more details on the URI format.
* @example
* ```ts
* const conn = await connect({
* uri: "/path/to/database",
* storageOptions: {timeout: "60s"}
* });
* ```
*/
export async function connect(
opts: Partial<RemoteConnectionOptions | ConnectionOptions> & { uri: string },
): Promise<Connection>;
export async function connect(
uriOrOptions:
| string
| (Partial<RemoteConnectionOptions | ConnectionOptions> & { uri: string }),
opts: Partial<ConnectionOptions | RemoteConnectionOptions> = {},
): Promise<Connection> {
let uri: string | undefined;
if (typeof uriOrOptions !== "string") {
const { uri: uri_, ...options } = uriOrOptions;
uri = uri_;
opts = options;
} else {
uri = uriOrOptions;
}
if (!uri) {
throw new Error("uri is required");
}
if (uri?.startsWith("db://")) {
return new RemoteConnection(uri, opts as RemoteConnectionOptions);
}
opts = (opts as ConnectionOptions) ?? {};
(<ConnectionOptions>opts).storageOptions = cleanseStorageOptions(
(<ConnectionOptions>opts).storageOptions,
);
const nativeConn = await LanceDbConnection.new(uri, opts);
return new LocalConnection(nativeConn);
}

70
nodejs/lancedb/merge.ts Normal file
View File

@@ -0,0 +1,70 @@
import { Data, fromDataToBuffer } from "./arrow";
import { NativeMergeInsertBuilder } from "./native";
/** A builder used to create and run a merge insert operation */
export class MergeInsertBuilder {
#native: NativeMergeInsertBuilder;
/** Construct a MergeInsertBuilder. __Internal use only.__ */
constructor(native: NativeMergeInsertBuilder) {
this.#native = native;
}
/**
* Rows that exist in both the source table (new data) and
* the target table (old data) will be updated, replacing
* the old row with the corresponding matching row.
*
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*
* An optional condition may be specified. If it is, then only
* matched rows that satisfy the condtion will be updated. Any
* rows that do not satisfy the condition will be left as they
* are. Failing to satisfy the condition does not cause a
* "matched row" to become a "not matched" row.
*
* The condition should be an SQL string. Use the prefix
* target. to refer to rows in the target table (old data)
* and the prefix source. to refer to rows in the source
* table (new data).
*
* For example, "target.last_update < source.last_update"
*/
whenMatchedUpdateAll(options?: { where: string }): MergeInsertBuilder {
return new MergeInsertBuilder(
this.#native.whenMatchedUpdateAll(options?.where),
);
}
/**
* Rows that exist only in the source table (new data) should
* be inserted into the target table.
*/
whenNotMatchedInsertAll(): MergeInsertBuilder {
return new MergeInsertBuilder(this.#native.whenNotMatchedInsertAll());
}
/**
* Rows that exist only in the target table (old data) will be
* deleted. An optional condition can be provided to limit what
* data is deleted.
*
* @param options.where - An optional condition to limit what data is deleted
*/
whenNotMatchedBySourceDelete(options?: {
where: string;
}): MergeInsertBuilder {
return new MergeInsertBuilder(
this.#native.whenNotMatchedBySourceDelete(options?.where),
);
}
/**
* Executes the merge insert operation
*
* Nothing is returned but the `Table` is updated
*/
async execute(data: Data): Promise<void> {
const buffer = await fromDataToBuffer(data);
await this.#native.execute(buffer);
}
}

View File

@@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow"; import {
Table as ArrowTable,
type IntoVector,
RecordBatch,
tableFromIPC,
} from "./arrow";
import { type IvfPqOptions } from "./indices"; import { type IvfPqOptions } from "./indices";
import { import {
RecordBatchIterator as NativeBatchIterator, RecordBatchIterator as NativeBatchIterator,
@@ -50,6 +55,39 @@ export class RecordBatchIterator implements AsyncIterator<RecordBatch> {
} }
/* eslint-enable */ /* eslint-enable */
class RecordBatchIterable<
NativeQueryType extends NativeQuery | NativeVectorQuery,
> implements AsyncIterable<RecordBatch>
{
private inner: NativeQueryType;
private options?: QueryExecutionOptions;
constructor(inner: NativeQueryType, options?: QueryExecutionOptions) {
this.inner = inner;
this.options = options;
}
// biome-ignore lint/suspicious/noExplicitAny: skip
[Symbol.asyncIterator](): AsyncIterator<RecordBatch<any>, any, undefined> {
return new RecordBatchIterator(
this.inner.execute(this.options?.maxBatchLength),
);
}
}
/**
* Options that control the behavior of a particular query execution
*/
export interface QueryExecutionOptions {
/**
* The maximum number of rows to return in a single batch
*
* Batches may have fewer rows if the underlying data is stored
* in smaller chunks.
*/
maxBatchLength?: number;
}
/** Common methods supported by all query types */ /** Common methods supported by all query types */
export class QueryBase< export class QueryBase<
NativeQueryType extends NativeQuery | NativeVectorQuery, NativeQueryType extends NativeQuery | NativeVectorQuery,
@@ -76,6 +114,14 @@ export class QueryBase<
this.inner.onlyIf(predicate); this.inner.onlyIf(predicate);
return this as unknown as QueryType; return this as unknown as QueryType;
} }
/**
* A filter statement to be applied to this query.
* @alias where
* @deprecated Use `where` instead
*/
filter(predicate: string): QueryType {
return this.where(predicate);
}
/** /**
* Return only the specified columns. * Return only the specified columns.
@@ -108,9 +154,12 @@ export class QueryBase<
* object insertion order is easy to get wrong and `Map` is more foolproof. * object insertion order is easy to get wrong and `Map` is more foolproof.
*/ */
select( select(
columns: string[] | Map<string, string> | Record<string, string>, columns: string[] | Map<string, string> | Record<string, string> | string,
): QueryType { ): QueryType {
let columnTuples: [string, string][]; let columnTuples: [string, string][];
if (typeof columns === "string") {
columns = [columns];
}
if (Array.isArray(columns)) { if (Array.isArray(columns)) {
columnTuples = columns.map((c) => [c, c]); columnTuples = columns.map((c) => [c, c]);
} else if (columns instanceof Map) { } else if (columns instanceof Map) {
@@ -133,8 +182,10 @@ export class QueryBase<
return this as unknown as QueryType; return this as unknown as QueryType;
} }
protected nativeExecute(): Promise<NativeBatchIterator> { protected nativeExecute(
return this.inner.execute(); options?: Partial<QueryExecutionOptions>,
): Promise<NativeBatchIterator> {
return this.inner.execute(options?.maxBatchLength);
} }
/** /**
@@ -148,8 +199,10 @@ export class QueryBase<
* single query) * single query)
* *
*/ */
protected execute(): RecordBatchIterator { protected execute(
return new RecordBatchIterator(this.nativeExecute()); options?: Partial<QueryExecutionOptions>,
): RecordBatchIterator {
return new RecordBatchIterator(this.nativeExecute(options));
} }
// biome-ignore lint/suspicious/noExplicitAny: skip // biome-ignore lint/suspicious/noExplicitAny: skip
@@ -159,19 +212,18 @@ export class QueryBase<
} }
/** Collect the results as an Arrow @see {@link ArrowTable}. */ /** Collect the results as an Arrow @see {@link ArrowTable}. */
async toArrow(): Promise<ArrowTable> { async toArrow(options?: Partial<QueryExecutionOptions>): Promise<ArrowTable> {
const batches = []; const batches = [];
for await (const batch of this) { for await (const batch of new RecordBatchIterable(this.inner, options)) {
batches.push(batch); batches.push(batch);
} }
return new ArrowTable(batches); return new ArrowTable(batches);
} }
/** Collect the results as an array of objects. */ /** Collect the results as an array of objects. */
async toArray(): Promise<unknown[]> { // biome-ignore lint/suspicious/noExplicitAny: arrow.toArrow() returns any[]
const tbl = await this.toArrow(); async toArray(options?: Partial<QueryExecutionOptions>): Promise<any[]> {
const tbl = await this.toArrow(options);
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
return tbl.toArray(); return tbl.toArray();
} }
} }
@@ -370,9 +422,8 @@ export class Query extends QueryBase<NativeQuery, Query> {
* Vector searches always have a `limit`. If `limit` has not been called then * Vector searches always have a `limit`. If `limit` has not been called then
* a default `limit` of 10 will be used. @see {@link Query#limit} * a default `limit` of 10 will be used. @see {@link Query#limit}
*/ */
nearestTo(vector: unknown): VectorQuery { nearestTo(vector: IntoVector): VectorQuery {
// biome-ignore lint/suspicious/noExplicitAny: skip const vectorQuery = this.inner.nearestTo(Float32Array.from(vector));
const vectorQuery = this.inner.nearestTo(Float32Array.from(vector as any));
return new VectorQuery(vectorQuery); return new VectorQuery(vectorQuery);
} }
} }

View File

@@ -0,0 +1,221 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import axios, {
AxiosError,
type AxiosResponse,
type ResponseType,
} from "axios";
import { Table as ArrowTable } from "../arrow";
import { tableFromIPC } from "../arrow";
import { VectorQuery } from "../query";
export class RestfulLanceDBClient {
#dbName: string;
#region: string;
#apiKey: string;
#hostOverride?: string;
#closed: boolean = false;
#connectionTimeout: number = 12 * 1000; // 12 seconds;
#readTimeout: number = 30 * 1000; // 30 seconds;
#session?: import("axios").AxiosInstance;
constructor(
dbName: string,
apiKey: string,
region: string,
hostOverride?: string,
connectionTimeout?: number,
readTimeout?: number,
) {
this.#dbName = dbName;
this.#apiKey = apiKey;
this.#region = region;
this.#hostOverride = hostOverride ?? this.#hostOverride;
this.#connectionTimeout = connectionTimeout ?? this.#connectionTimeout;
this.#readTimeout = readTimeout ?? this.#readTimeout;
}
// todo: cache the session.
get session(): import("axios").AxiosInstance {
if (this.#session !== undefined) {
return this.#session;
} else {
return axios.create({
baseURL: this.url,
headers: {
// biome-ignore lint/style/useNamingConvention: external api
Authorization: `Bearer ${this.#apiKey}`,
},
transformResponse: decodeErrorData,
timeout: this.#connectionTimeout,
});
}
}
get url(): string {
return (
this.#hostOverride ??
`https://${this.#dbName}.${this.#region}.api.lancedb.com`
);
}
get headers(): { [key: string]: string } {
const headers: { [key: string]: string } = {
"x-api-key": this.#apiKey,
"x-request-id": "na",
};
if (this.#region == "local") {
headers["Host"] = `${this.#dbName}.${this.#region}.api.lancedb.com`;
}
if (this.#hostOverride) {
headers["x-lancedb-database"] = this.#dbName;
}
return headers;
}
isOpen(): boolean {
return !this.#closed;
}
private checkNotClosed(): void {
if (this.#closed) {
throw new Error("Connection is closed");
}
}
close(): void {
this.#session = undefined;
this.#closed = true;
}
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
async get(uri: string, params?: Record<string, any>): Promise<any> {
this.checkNotClosed();
uri = new URL(uri, this.url).toString();
let response;
try {
response = await this.session.get(uri, {
headers: this.headers,
params,
});
} catch (e) {
if (e instanceof AxiosError) {
response = e.response;
} else {
throw e;
}
}
RestfulLanceDBClient.checkStatus(response!);
return response!.data;
}
// biome-ignore lint/suspicious/noExplicitAny: api response
async post(uri: string, body?: any): Promise<any>;
async post(
uri: string,
// biome-ignore lint/suspicious/noExplicitAny: api request
body: any,
additional: {
config?: { responseType: "arraybuffer" };
headers?: Record<string, string>;
params?: Record<string, string>;
},
): Promise<Buffer>;
async post(
uri: string,
// biome-ignore lint/suspicious/noExplicitAny: api request
body?: any,
additional?: {
config?: { responseType: ResponseType };
headers?: Record<string, string>;
params?: Record<string, string>;
},
// biome-ignore lint/suspicious/noExplicitAny: api response
): Promise<any> {
this.checkNotClosed();
uri = new URL(uri, this.url).toString();
additional = Object.assign(
{ config: { responseType: "json" } },
additional,
);
const headers = { ...this.headers, ...additional.headers };
if (!headers["Content-Type"]) {
headers["Content-Type"] = "application/json";
}
let response;
try {
response = await this.session.post(uri, body, {
headers,
responseType: additional!.config!.responseType,
params: new Map(Object.entries(additional.params ?? {})),
});
} catch (e) {
if (e instanceof AxiosError) {
response = e.response;
} else {
throw e;
}
}
RestfulLanceDBClient.checkStatus(response!);
if (additional!.config!.responseType === "arraybuffer") {
return response!.data;
} else {
return JSON.parse(response!.data);
}
}
async listTables(limit = 10, pageToken = ""): Promise<string[]> {
const json = await this.get("/v1/table", { limit, pageToken });
return json.tables;
}
async query(tableName: string, query: VectorQuery): Promise<ArrowTable> {
const tbl = await this.post(`/v1/table/${tableName}/query`, query, {
config: {
responseType: "arraybuffer",
},
});
return tableFromIPC(tbl);
}
static checkStatus(response: AxiosResponse): void {
if (response.status === 404) {
throw new Error(`Not found: ${response.data}`);
} else if (response.status >= 400 && response.status < 500) {
throw new Error(
`Bad Request: ${response.status}, error: ${response.data}`,
);
} else if (response.status >= 500 && response.status < 600) {
throw new Error(
`Internal Server Error: ${response.status}, error: ${response.data}`,
);
} else if (response.status !== 200) {
throw new Error(
`Unknown Error: ${response.status}, error: ${response.data}`,
);
}
}
}
function decodeErrorData(data: unknown) {
if (Buffer.isBuffer(data)) {
const decoded = data.toString("utf-8");
return decoded;
}
return data;
}

View File

@@ -0,0 +1,196 @@
import { Schema } from "apache-arrow";
import { Data, fromTableToStreamBuffer, makeEmptyTable } from "../arrow";
import {
Connection,
CreateTableOptions,
OpenTableOptions,
TableNamesOptions,
} from "../connection";
import { Table } from "../table";
import { TTLCache } from "../util";
import { RestfulLanceDBClient } from "./client";
import { RemoteTable } from "./table";
export interface RemoteConnectionOptions {
apiKey?: string;
region?: string;
hostOverride?: string;
connectionTimeout?: number;
readTimeout?: number;
}
export class RemoteConnection extends Connection {
#dbName: string;
#apiKey: string;
#region: string;
#client: RestfulLanceDBClient;
#tableCache = new TTLCache(300_000);
constructor(
url: string,
{
apiKey,
region,
hostOverride,
connectionTimeout,
readTimeout,
}: RemoteConnectionOptions,
) {
super();
apiKey = apiKey ?? process.env.LANCEDB_API_KEY;
region = region ?? process.env.LANCEDB_REGION;
if (!apiKey) {
throw new Error("apiKey is required when connecting to LanceDB Cloud");
}
if (!region) {
throw new Error("region is required when connecting to LanceDB Cloud");
}
const parsed = new URL(url);
if (parsed.protocol !== "db:") {
throw new Error(
`invalid protocol: ${parsed.protocol}, only accepts db://`,
);
}
this.#dbName = parsed.hostname;
this.#apiKey = apiKey;
this.#region = region;
this.#client = new RestfulLanceDBClient(
this.#dbName,
this.#apiKey,
this.#region,
hostOverride,
connectionTimeout,
readTimeout,
);
}
isOpen(): boolean {
return this.#client.isOpen();
}
close(): void {
return this.#client.close();
}
display(): string {
return `RemoteConnection(${this.#dbName})`;
}
async tableNames(options?: Partial<TableNamesOptions>): Promise<string[]> {
const response = await this.#client.get("/v1/table/", {
limit: options?.limit ?? 10,
// biome-ignore lint/style/useNamingConvention: <explanation>
page_token: options?.startAfter ?? "",
});
const body = await response.body();
for (const table of body.tables) {
this.#tableCache.set(table, true);
}
return body.tables;
}
async openTable(
name: string,
_options?: Partial<OpenTableOptions> | undefined,
): Promise<Table> {
if (this.#tableCache.get(name) === undefined) {
await this.#client.post(
`/v1/table/${encodeURIComponent(name)}/describe/`,
);
this.#tableCache.set(name, true);
}
return new RemoteTable(this.#client, name, this.#dbName);
}
async createTable(
nameOrOptions:
| string
| ({ name: string; data: Data } & Partial<CreateTableOptions>),
data?: Data,
options?: Partial<CreateTableOptions> | undefined,
): Promise<Table> {
if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) {
const { name, data, ...options } = nameOrOptions;
return this.createTable(name, data, options);
}
if (data === undefined) {
throw new Error("data is required");
}
if (options?.mode) {
console.warn(
"option 'mode' is not supported in LanceDB Cloud",
"LanceDB Cloud only supports the default 'create' mode.",
"If the table already exists, an error will be thrown.",
);
}
if (options?.embeddingFunction) {
console.warn(
"embedding_functions is not yet supported on LanceDB Cloud.",
"Please vote https://github.com/lancedb/lancedb/issues/626 ",
"for this feature.",
);
}
const { buf } = await Table.parseTableData(
data,
options,
true /** streaming */,
);
await this.#client.post(
`/v1/table/${encodeURIComponent(nameOrOptions)}/create/`,
buf,
{
config: {
responseType: "arraybuffer",
},
headers: { "Content-Type": "application/vnd.apache.arrow.stream" },
},
);
this.#tableCache.set(nameOrOptions, true);
return new RemoteTable(this.#client, nameOrOptions, this.#dbName);
}
async createEmptyTable(
name: string,
schema: Schema,
options?: Partial<CreateTableOptions> | undefined,
): Promise<Table> {
if (options?.mode) {
console.warn(`mode is not supported on LanceDB Cloud`);
}
if (options?.embeddingFunction) {
console.warn(
"embeddingFunction is not yet supported on LanceDB Cloud.",
"Please vote https://github.com/lancedb/lancedb/issues/626 ",
"for this feature.",
);
}
const emptyTable = makeEmptyTable(schema);
const buf = await fromTableToStreamBuffer(emptyTable);
await this.#client.post(
`/v1/table/${encodeURIComponent(name)}/create/`,
buf,
{
config: {
responseType: "arraybuffer",
},
headers: { "Content-Type": "application/vnd.apache.arrow.stream" },
},
);
this.#tableCache.set(name, true);
return new RemoteTable(this.#client, name, this.#dbName);
}
async dropTable(name: string): Promise<void> {
await this.#client.post(`/v1/table/${encodeURIComponent(name)}/drop/`);
this.#tableCache.delete(name);
}
}

View File

@@ -0,0 +1,3 @@
export { RestfulLanceDBClient } from "./client";
export { type RemoteConnectionOptions, RemoteConnection } from "./connection";
export { RemoteTable } from "./table";

View File

@@ -0,0 +1,172 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { Table as ArrowTable } from "apache-arrow";
import { Data, IntoVector } from "../arrow";
import { IndexStatistics } from "..";
import { CreateTableOptions } from "../connection";
import { IndexOptions } from "../indices";
import { MergeInsertBuilder } from "../merge";
import { VectorQuery } from "../query";
import { AddDataOptions, Table, UpdateOptions } from "../table";
import { RestfulLanceDBClient } from "./client";
export class RemoteTable extends Table {
#client: RestfulLanceDBClient;
#name: string;
// Used in the display() method
#dbName: string;
get #tablePrefix() {
return `/v1/table/${encodeURIComponent(this.#name)}/`;
}
get name(): string {
return this.#name;
}
public constructor(
client: RestfulLanceDBClient,
tableName: string,
dbName: string,
) {
super();
this.#client = client;
this.#name = tableName;
this.#dbName = dbName;
}
isOpen(): boolean {
return !this.#client.isOpen();
}
close(): void {
this.#client.close();
}
display(): string {
return `RemoteTable(${this.#dbName}; ${this.#name})`;
}
async schema(): Promise<import("apache-arrow").Schema> {
const resp = await this.#client.post(`${this.#tablePrefix}/describe/`);
// TODO: parse this into a valid arrow schema
return resp.schema;
}
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
const { buf, mode } = await Table.parseTableData(
data,
options as CreateTableOptions,
true,
);
await this.#client.post(`${this.#tablePrefix}/insert/`, buf, {
params: {
mode,
},
headers: {
"Content-Type": "application/vnd.apache.arrow.stream",
},
});
}
async update(
updates: Map<string, string> | Record<string, string>,
options?: Partial<UpdateOptions>,
): Promise<void> {
await this.#client.post(`${this.#tablePrefix}/update/`, {
predicate: options?.where ?? null,
updates: Object.entries(updates).map(([key, value]) => [key, value]),
});
}
async countRows(filter?: unknown): Promise<number> {
const payload = { predicate: filter };
return await this.#client.post(`${this.#tablePrefix}/count_rows/`, payload);
}
async delete(predicate: unknown): Promise<void> {
const payload = { predicate };
await this.#client.post(`${this.#tablePrefix}/delete/`, payload);
}
async createIndex(
column: string,
options?: Partial<IndexOptions>,
): Promise<void> {
if (options !== undefined) {
console.warn("options are not yet supported on the LanceDB cloud");
}
const indexType = "vector";
const metric = "L2";
const data = {
column,
// biome-ignore lint/style/useNamingConvention: external API
index_type: indexType,
// biome-ignore lint/style/useNamingConvention: external API
metric_type: metric,
};
await this.#client.post(`${this.#tablePrefix}/create_index`, data);
}
query(): import("..").Query {
throw new Error("query() is not yet supported on the LanceDB cloud");
}
search(query: IntoVector): VectorQuery;
search(query: string): Promise<VectorQuery>;
search(_query: string | IntoVector): VectorQuery | Promise<VectorQuery> {
throw new Error("search() is not yet supported on the LanceDB cloud");
}
vectorSearch(_vector: unknown): import("..").VectorQuery {
throw new Error("vectorSearch() is not yet supported on the LanceDB cloud");
}
addColumns(_newColumnTransforms: unknown): Promise<void> {
throw new Error("addColumns() is not yet supported on the LanceDB cloud");
}
alterColumns(_columnAlterations: unknown): Promise<void> {
throw new Error("alterColumns() is not yet supported on the LanceDB cloud");
}
dropColumns(_columnNames: unknown): Promise<void> {
throw new Error("dropColumns() is not yet supported on the LanceDB cloud");
}
async version(): Promise<number> {
const resp = await this.#client.post(`${this.#tablePrefix}/describe/`);
return resp.version;
}
checkout(_version: unknown): Promise<void> {
throw new Error("checkout() is not yet supported on the LanceDB cloud");
}
checkoutLatest(): Promise<void> {
throw new Error(
"checkoutLatest() is not yet supported on the LanceDB cloud",
);
}
restore(): Promise<void> {
throw new Error("restore() is not yet supported on the LanceDB cloud");
}
optimize(_options?: unknown): Promise<import("../native").OptimizeStats> {
throw new Error("optimize() is not yet supported on the LanceDB cloud");
}
async listIndices(): Promise<import("../native").IndexConfig[]> {
return await this.#client.post(`${this.#tablePrefix}/index/list/`);
}
toArrow(): Promise<ArrowTable> {
throw new Error("toArrow() is not yet supported on the LanceDB cloud");
}
mergeInsert(_on: string | string[]): MergeInsertBuilder {
throw new Error("mergeInsert() is not yet supported on the LanceDB cloud");
}
async indexStats(_name: string): Promise<IndexStatistics | undefined> {
throw new Error("indexStats() is not yet supported on the LanceDB cloud");
}
}

View File

@@ -12,20 +12,33 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import { Data, Schema, fromDataToBuffer, tableFromIPC } from "./arrow"; import {
Table as ArrowTable,
Data,
IntoVector,
Schema,
fromDataToBuffer,
fromTableToBuffer,
fromTableToStreamBuffer,
isArrowTable,
makeArrowTable,
tableFromIPC,
} from "./arrow";
import { CreateTableOptions } from "./connection";
import { getRegistry } from "./embedding/registry"; import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { IndexOptions } from "./indices"; import { IndexOptions } from "./indices";
import { MergeInsertBuilder } from "./merge";
import { import {
AddColumnsSql, AddColumnsSql,
ColumnAlteration, ColumnAlteration,
IndexConfig, IndexConfig,
IndexStatistics,
OptimizeStats, OptimizeStats,
Table as _NativeTable, Table as _NativeTable,
} from "./native"; } from "./native";
import { Query, VectorQuery } from "./query"; import { Query, VectorQuery } from "./query";
export { IndexConfig } from "./native";
/** /**
* Options for adding data to a table. * Options for adding data to a table.
*/ */
@@ -81,19 +94,15 @@ export interface OptimizeOptions {
* Closing a table is optional. It not closed, it will be closed when it is garbage * Closing a table is optional. It not closed, it will be closed when it is garbage
* collected. * collected.
*/ */
export class Table { export abstract class Table {
private readonly inner: _NativeTable; [Symbol.for("nodejs.util.inspect.custom")](): string {
return this.display();
/** Construct a Table. Internal use only. */
constructor(inner: _NativeTable) {
this.inner = inner;
} }
/** Returns the name of the table */
abstract get name(): string;
/** Return true if the table has not been closed */ /** Return true if the table has not been closed */
isOpen(): boolean { abstract isOpen(): boolean;
return this.inner.isOpen();
}
/** /**
* Close the table, releasing any underlying resources. * Close the table, releasing any underlying resources.
* *
@@ -101,39 +110,16 @@ export class Table {
* *
* Any attempt to use the table after it is closed will result in an error. * Any attempt to use the table after it is closed will result in an error.
*/ */
close(): void { abstract close(): void;
this.inner.close();
}
/** Return a brief description of the table */ /** Return a brief description of the table */
display(): string { abstract display(): string;
return this.inner.display();
}
/** Get the schema of the table. */ /** Get the schema of the table. */
async schema(): Promise<Schema> { abstract schema(): Promise<Schema>;
const schemaBuf = await this.inner.schema();
const tbl = tableFromIPC(schemaBuf);
return tbl.schema;
}
/** /**
* Insert records into this Table. * Insert records into this Table.
* @param {Data} data Records to be inserted into the Table * @param {Data} data Records to be inserted into the Table
*/ */
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> { abstract add(data: Data, options?: Partial<AddDataOptions>): Promise<void>;
const mode = options?.mode ?? "append";
const schema = await this.schema();
const registry = getRegistry();
const functions = registry.parseFunctions(schema.metadata);
const buffer = await fromDataToBuffer(
data,
functions.values().next().value,
);
await this.inner.add(buffer, mode);
}
/** /**
* Update existing records in the Table * Update existing records in the Table
* *
@@ -159,30 +145,14 @@ export class Table {
* @param {Partial<UpdateOptions>} options - additional options to control * @param {Partial<UpdateOptions>} options - additional options to control
* the update behavior * the update behavior
*/ */
async update( abstract update(
updates: Map<string, string> | Record<string, string>, updates: Map<string, string> | Record<string, string>,
options?: Partial<UpdateOptions>, options?: Partial<UpdateOptions>,
) { ): Promise<void>;
const onlyIf = options?.where;
let columns: [string, string][];
if (updates instanceof Map) {
columns = Array.from(updates.entries());
} else {
columns = Object.entries(updates);
}
await this.inner.update(onlyIf, columns);
}
/** Count the total number of rows in the dataset. */ /** Count the total number of rows in the dataset. */
async countRows(filter?: string): Promise<number> { abstract countRows(filter?: string): Promise<number>;
return await this.inner.countRows(filter);
}
/** Delete the rows that satisfy the predicate. */ /** Delete the rows that satisfy the predicate. */
async delete(predicate: string): Promise<void> { abstract delete(predicate: string): Promise<void>;
await this.inner.delete(predicate);
}
/** /**
* Create an index to speed up queries. * Create an index to speed up queries.
* *
@@ -190,6 +160,9 @@ export class Table {
* Indices on vector columns will speed up vector searches. * Indices on vector columns will speed up vector searches.
* Indices on scalar columns will speed up filtering (in both * Indices on scalar columns will speed up filtering (in both
* vector and non-vector searches) * vector and non-vector searches)
*
* @note We currently don't support custom named indexes,
* The index name will always be `${column}_idx`
* @example * @example
* // If the column has a vector (fixed size list) data type then * // If the column has a vector (fixed size list) data type then
* // an IvfPq vector index will be created. * // an IvfPq vector index will be created.
@@ -209,13 +182,10 @@ export class Table {
* // Or create a Scalar index * // Or create a Scalar index
* await table.createIndex("my_float_col"); * await table.createIndex("my_float_col");
*/ */
async createIndex(column: string, options?: Partial<IndexOptions>) { abstract createIndex(
// Bit of a hack to get around the fact that TS has no package-scope. column: string,
// biome-ignore lint/suspicious/noExplicitAny: skip options?: Partial<IndexOptions>,
const nativeIndex = (options?.config as any)?.inner; ): Promise<void>;
await this.inner.createIndex(nativeIndex, column, options?.replace);
}
/** /**
* Create a {@link Query} Builder. * Create a {@link Query} Builder.
* *
@@ -266,10 +236,20 @@ export class Table {
* } * }
* @returns {Query} A builder that can be used to parameterize the query * @returns {Query} A builder that can be used to parameterize the query
*/ */
query(): Query { abstract query(): Query;
return new Query(this.inner); /**
} * Create a search query to find the nearest neighbors
* of the given query vector
* @param {string} query - the query. This will be converted to a vector using the table's provided embedding function
* @rejects {Error} If no embedding functions are defined in the table
*/
abstract search(query: string): Promise<VectorQuery>;
/**
* Create a search query to find the nearest neighbors
* of the given query vector
* @param {IntoVector} query - the query vector
*/
abstract search(query: IntoVector): VectorQuery;
/** /**
* Search the table with a given query vector. * Search the table with a given query vector.
* *
@@ -277,11 +257,7 @@ export class Table {
* is the same thing as calling `nearestTo` on the builder returned * is the same thing as calling `nearestTo` on the builder returned
* by `query`. @see {@link Query#nearestTo} for more details. * by `query`. @see {@link Query#nearestTo} for more details.
*/ */
vectorSearch(vector: unknown): VectorQuery { abstract vectorSearch(vector: IntoVector): VectorQuery;
return this.query().nearestTo(vector);
}
// TODO: Support BatchUDF
/** /**
* Add new columns with defined values. * Add new columns with defined values.
* @param {AddColumnsSql[]} newColumnTransforms pairs of column names and * @param {AddColumnsSql[]} newColumnTransforms pairs of column names and
@@ -289,19 +265,14 @@ export class Table {
* expressions will be evaluated for each row in the table, and can * expressions will be evaluated for each row in the table, and can
* reference existing columns in the table. * reference existing columns in the table.
*/ */
async addColumns(newColumnTransforms: AddColumnsSql[]): Promise<void> { abstract addColumns(newColumnTransforms: AddColumnsSql[]): Promise<void>;
await this.inner.addColumns(newColumnTransforms);
}
/** /**
* Alter the name or nullability of columns. * Alter the name or nullability of columns.
* @param {ColumnAlteration[]} columnAlterations One or more alterations to * @param {ColumnAlteration[]} columnAlterations One or more alterations to
* apply to columns. * apply to columns.
*/ */
async alterColumns(columnAlterations: ColumnAlteration[]): Promise<void> { abstract alterColumns(columnAlterations: ColumnAlteration[]): Promise<void>;
await this.inner.alterColumns(columnAlterations);
}
/** /**
* Drop one or more columns from the dataset * Drop one or more columns from the dataset
* *
@@ -313,15 +284,10 @@ export class Table {
* be nested column references (e.g. "a.b.c") or top-level column names * be nested column references (e.g. "a.b.c") or top-level column names
* (e.g. "a"). * (e.g. "a").
*/ */
async dropColumns(columnNames: string[]): Promise<void> { abstract dropColumns(columnNames: string[]): Promise<void>;
await this.inner.dropColumns(columnNames);
}
/** Retrieve the version of the table */ /** Retrieve the version of the table */
async version(): Promise<number> {
return await this.inner.version();
}
abstract version(): Promise<number>;
/** /**
* Checks out a specific version of the table _This is an in-place operation._ * Checks out a specific version of the table _This is an in-place operation._
* *
@@ -347,19 +313,14 @@ export class Table {
* console.log(await table.version()); // 2 * console.log(await table.version()); // 2
* ``` * ```
*/ */
async checkout(version: number): Promise<void> { abstract checkout(version: number): Promise<void>;
await this.inner.checkout(version);
}
/** /**
* Checkout the latest version of the table. _This is an in-place operation._ * Checkout the latest version of the table. _This is an in-place operation._
* *
* The table will be set back into standard mode, and will track the latest * The table will be set back into standard mode, and will track the latest
* version of the table. * version of the table.
*/ */
async checkoutLatest(): Promise<void> { abstract checkoutLatest(): Promise<void>;
await this.inner.checkoutLatest();
}
/** /**
* Restore the table to the currently checked out version * Restore the table to the currently checked out version
@@ -373,10 +334,7 @@ export class Table {
* Once the operation concludes the table will no longer be in a checked * Once the operation concludes the table will no longer be in a checked
* out state and the read_consistency_interval, if any, will apply. * out state and the read_consistency_interval, if any, will apply.
*/ */
async restore(): Promise<void> { abstract restore(): Promise<void>;
await this.inner.restore();
}
/** /**
* Optimize the on-disk data and indices for better performance. * Optimize the on-disk data and indices for better performance.
* *
@@ -407,6 +365,200 @@ export class Table {
* you have added or modified 100,000 or more records or run more than 20 data * you have added or modified 100,000 or more records or run more than 20 data
* modification operations. * modification operations.
*/ */
abstract optimize(options?: Partial<OptimizeOptions>): Promise<OptimizeStats>;
/** List all indices that have been created with {@link Table.createIndex} */
abstract listIndices(): Promise<IndexConfig[]>;
/** Return the table as an arrow table */
abstract toArrow(): Promise<ArrowTable>;
abstract mergeInsert(on: string | string[]): MergeInsertBuilder;
/** List all the stats of a specified index
*
* @param {string} name The name of the index.
* @returns {IndexStatistics | undefined} The stats of the index. If the index does not exist, it will return undefined
*/
abstract indexStats(name: string): Promise<IndexStatistics | undefined>;
static async parseTableData(
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
data: Record<string, unknown>[] | ArrowTable<any>,
options?: Partial<CreateTableOptions>,
streaming = false,
) {
let mode: string = options?.mode ?? "create";
const existOk = options?.existOk ?? false;
if (mode === "create" && existOk) {
mode = "exist_ok";
}
let table: ArrowTable;
if (isArrowTable(data)) {
table = data;
} else {
table = makeArrowTable(data, options);
}
if (streaming) {
const buf = await fromTableToStreamBuffer(
table,
options?.embeddingFunction,
options?.schema,
);
return { buf, mode };
} else {
const buf = await fromTableToBuffer(
table,
options?.embeddingFunction,
options?.schema,
);
return { buf, mode };
}
}
}
export class LocalTable extends Table {
private readonly inner: _NativeTable;
constructor(inner: _NativeTable) {
super();
this.inner = inner;
}
get name(): string {
return this.inner.name;
}
isOpen(): boolean {
return this.inner.isOpen();
}
close(): void {
this.inner.close();
}
display(): string {
return this.inner.display();
}
private async getEmbeddingFunctions(): Promise<
Map<string, EmbeddingFunctionConfig>
> {
const schema = await this.schema();
const registry = getRegistry();
return registry.parseFunctions(schema.metadata);
}
/** Get the schema of the table. */
async schema(): Promise<Schema> {
const schemaBuf = await this.inner.schema();
const tbl = tableFromIPC(schemaBuf);
return tbl.schema;
}
async add(data: Data, options?: Partial<AddDataOptions>): Promise<void> {
const mode = options?.mode ?? "append";
const schema = await this.schema();
const registry = getRegistry();
const functions = registry.parseFunctions(schema.metadata);
const buffer = await fromDataToBuffer(
data,
functions.values().next().value,
schema,
);
await this.inner.add(buffer, mode);
}
async update(
updates: Map<string, string> | Record<string, string>,
options?: Partial<UpdateOptions>,
) {
const onlyIf = options?.where;
let columns: [string, string][];
if (updates instanceof Map) {
columns = Array.from(updates.entries());
} else {
columns = Object.entries(updates);
}
await this.inner.update(onlyIf, columns);
}
async countRows(filter?: string): Promise<number> {
return await this.inner.countRows(filter);
}
async delete(predicate: string): Promise<void> {
await this.inner.delete(predicate);
}
async createIndex(column: string, options?: Partial<IndexOptions>) {
// Bit of a hack to get around the fact that TS has no package-scope.
// biome-ignore lint/suspicious/noExplicitAny: skip
const nativeIndex = (options?.config as any)?.inner;
await this.inner.createIndex(nativeIndex, column, options?.replace);
}
query(): Query {
return new Query(this.inner);
}
search(query: string): Promise<VectorQuery>;
search(query: IntoVector): VectorQuery;
search(query: string | IntoVector): Promise<VectorQuery> | VectorQuery {
if (typeof query !== "string") {
return this.vectorSearch(query);
} else {
return this.getEmbeddingFunctions().then(async (functions) => {
// TODO: Support multiple embedding functions
const embeddingFunc: EmbeddingFunctionConfig | undefined = functions
.values()
.next().value;
if (!embeddingFunc) {
return Promise.reject(
new Error("No embedding functions are defined in the table"),
);
}
const embeddings =
await embeddingFunc.function.computeQueryEmbeddings(query);
return this.query().nearestTo(embeddings);
});
}
}
vectorSearch(vector: IntoVector): VectorQuery {
return this.query().nearestTo(vector);
}
// TODO: Support BatchUDF
async addColumns(newColumnTransforms: AddColumnsSql[]): Promise<void> {
await this.inner.addColumns(newColumnTransforms);
}
async alterColumns(columnAlterations: ColumnAlteration[]): Promise<void> {
await this.inner.alterColumns(columnAlterations);
}
async dropColumns(columnNames: string[]): Promise<void> {
await this.inner.dropColumns(columnNames);
}
async version(): Promise<number> {
return await this.inner.version();
}
async checkout(version: number): Promise<void> {
await this.inner.checkout(version);
}
async checkoutLatest(): Promise<void> {
await this.inner.checkoutLatest();
}
async restore(): Promise<void> {
await this.inner.restore();
}
async optimize(options?: Partial<OptimizeOptions>): Promise<OptimizeStats> { async optimize(options?: Partial<OptimizeOptions>): Promise<OptimizeStats> {
let cleanupOlderThanMs; let cleanupOlderThanMs;
if ( if (
@@ -419,8 +571,23 @@ export class Table {
return await this.inner.optimize(cleanupOlderThanMs); return await this.inner.optimize(cleanupOlderThanMs);
} }
/** List all indices that have been created with {@link Table.createIndex} */
async listIndices(): Promise<IndexConfig[]> { async listIndices(): Promise<IndexConfig[]> {
return await this.inner.listIndices(); return await this.inner.listIndices();
} }
async toArrow(): Promise<ArrowTable> {
return await this.query().toArrow();
}
async indexStats(name: string): Promise<IndexStatistics | undefined> {
const stats = await this.inner.indexStats(name);
if (stats === null) {
return undefined;
}
return stats;
}
mergeInsert(on: string | string[]): MergeInsertBuilder {
on = Array.isArray(on) ? on : [on];
return new MergeInsertBuilder(this.inner.mergeInsert(on));
}
} }

35
nodejs/lancedb/util.ts Normal file
View File

@@ -0,0 +1,35 @@
export class TTLCache {
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
private readonly cache: Map<string, { value: any; expires: number }>;
/**
* @param ttl Time to live in milliseconds
*/
constructor(private readonly ttl: number) {
this.cache = new Map();
}
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
get(key: string): any | undefined {
const entry = this.cache.get(key);
if (entry === undefined) {
return undefined;
}
if (entry.expires < Date.now()) {
this.cache.delete(key);
return undefined;
}
return entry.value;
}
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
set(key: string, value: any): void {
this.cache.set(key, { value, expires: Date.now() + this.ttl });
}
delete(key: string): void {
this.cache.delete(key);
}
}

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-arm64", "name": "@lancedb/lancedb-darwin-arm64",
"version": "0.5.0", "version": "0.5.2-final.1",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.darwin-arm64.node", "main": "lancedb.darwin-arm64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-darwin-x64", "name": "@lancedb/lancedb-darwin-x64",
"version": "0.5.0", "version": "0.5.2-final.1",
"os": ["darwin"], "os": ["darwin"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.darwin-x64.node", "main": "lancedb.darwin-x64.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-arm64-gnu", "name": "@lancedb/lancedb-linux-arm64-gnu",
"version": "0.5.0", "version": "0.5.2-final.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["arm64"], "cpu": ["arm64"],
"main": "lancedb.linux-arm64-gnu.node", "main": "lancedb.linux-arm64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-linux-x64-gnu", "name": "@lancedb/lancedb-linux-x64-gnu",
"version": "0.5.0", "version": "0.5.2-final.1",
"os": ["linux"], "os": ["linux"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.linux-x64-gnu.node", "main": "lancedb.linux-x64-gnu.node",

View File

@@ -1,6 +1,6 @@
{ {
"name": "@lancedb/lancedb-win32-x64-msvc", "name": "@lancedb/lancedb-win32-x64-msvc",
"version": "0.5.0", "version": "0.5.2-final.1",
"os": ["win32"], "os": ["win32"],
"cpu": ["x64"], "cpu": ["x64"],
"main": "lancedb.win32-x64-msvc.node", "main": "lancedb.win32-x64-msvc.node",

View File

@@ -1,12 +1,12 @@
{ {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.5.0", "version": "0.5.2",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.5.0", "version": "0.5.2",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -19,6 +19,7 @@
], ],
"dependencies": { "dependencies": {
"apache-arrow": "^15.0.0", "apache-arrow": "^15.0.0",
"axios": "^1.7.2",
"openai": "^4.29.2", "openai": "^4.29.2",
"reflect-metadata": "^0.2.2" "reflect-metadata": "^0.2.2"
}, },
@@ -28,6 +29,7 @@
"@biomejs/biome": "^1.7.3", "@biomejs/biome": "^1.7.3",
"@jest/globals": "^29.7.0", "@jest/globals": "^29.7.0",
"@napi-rs/cli": "^2.18.0", "@napi-rs/cli": "^2.18.0",
"@types/axios": "^0.14.0",
"@types/jest": "^29.1.2", "@types/jest": "^29.1.2",
"@types/tmp": "^0.2.6", "@types/tmp": "^0.2.6",
"apache-arrow-old": "npm:apache-arrow@13.0.0", "apache-arrow-old": "npm:apache-arrow@13.0.0",
@@ -3123,6 +3125,16 @@
"tslib": "^2.4.0" "tslib": "^2.4.0"
} }
}, },
"node_modules/@types/axios": {
"version": "0.14.0",
"resolved": "https://registry.npmjs.org/@types/axios/-/axios-0.14.0.tgz",
"integrity": "sha512-KqQnQbdYE54D7oa/UmYVMZKq7CO4l8DEENzOKc4aBRwxCXSlJXGz83flFx5L7AWrOQnmuN3kVsRdt+GZPPjiVQ==",
"deprecated": "This is a stub types definition for axios (https://github.com/mzabriskie/axios). axios provides its own type definitions, so you don't need @types/axios installed!",
"dev": true,
"dependencies": {
"axios": "*"
}
},
"node_modules/@types/babel__core": { "node_modules/@types/babel__core": {
"version": "7.20.5", "version": "7.20.5",
"resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz",
@@ -3497,6 +3509,16 @@
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==" "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q=="
}, },
"node_modules/axios": {
"version": "1.7.2",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.2.tgz",
"integrity": "sha512-2A8QhOMrbomlDuiLeK9XibIBzuHeRcqqNOHp0Cyp5EoJ1IFDh+XZH3A6BkXtv0K4gFGCI0Y4BM7B1wOEi0Rmgw==",
"dependencies": {
"follow-redirects": "^1.15.6",
"form-data": "^4.0.0",
"proxy-from-env": "^1.1.0"
}
},
"node_modules/babel-jest": { "node_modules/babel-jest": {
"version": "29.7.0", "version": "29.7.0",
"resolved": "https://registry.npmjs.org/babel-jest/-/babel-jest-29.7.0.tgz", "resolved": "https://registry.npmjs.org/babel-jest/-/babel-jest-29.7.0.tgz",
@@ -4478,6 +4500,25 @@
"integrity": "sha512-36yxDn5H7OFZQla0/jFJmbIKTdZAQHngCedGxiMmpNfEZM0sdEeT+WczLQrjK6D7o2aiyLYDnkw0R3JK0Qv1RQ==", "integrity": "sha512-36yxDn5H7OFZQla0/jFJmbIKTdZAQHngCedGxiMmpNfEZM0sdEeT+WczLQrjK6D7o2aiyLYDnkw0R3JK0Qv1RQ==",
"dev": true "dev": true
}, },
"node_modules/follow-redirects": {
"version": "1.15.6",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
"integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
"funding": [
{
"type": "individual",
"url": "https://github.com/sponsors/RubenVerborgh"
}
],
"engines": {
"node": ">=4.0"
},
"peerDependenciesMeta": {
"debug": {
"optional": true
}
}
},
"node_modules/form-data": { "node_modules/form-data": {
"version": "4.0.0", "version": "4.0.0",
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz", "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz",
@@ -6359,6 +6400,11 @@
"node": ">= 6" "node": ">= 6"
} }
}, },
"node_modules/proxy-from-env": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz",
"integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg=="
},
"node_modules/punycode": { "node_modules/punycode": {
"version": "2.3.1", "version": "2.3.1",
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz",

View File

@@ -1,6 +1,16 @@
{ {
"name": "@lancedb/lancedb", "name": "@lancedb/lancedb",
"version": "0.5.0", "description": "LanceDB: A serverless, low-latency vector database for AI applications",
"keywords": [
"database",
"lance",
"lancedb",
"search",
"vector",
"vector database",
"ann"
],
"version": "0.5.2-final.1",
"main": "dist/index.js", "main": "dist/index.js",
"exports": { "exports": {
".": "./dist/index.js", ".": "./dist/index.js",
@@ -38,7 +48,8 @@
"typedoc": "^0.25.7", "typedoc": "^0.25.7",
"typedoc-plugin-markdown": "^3.17.1", "typedoc-plugin-markdown": "^3.17.1",
"typescript": "^5.3.3", "typescript": "^5.3.3",
"typescript-eslint": "^7.1.0" "typescript-eslint": "^7.1.0",
"@types/axios": "^0.14.0"
}, },
"ava": { "ava": {
"timeout": "3m" "timeout": "3m"
@@ -66,6 +77,7 @@
}, },
"dependencies": { "dependencies": {
"apache-arrow": "^15.0.0", "apache-arrow": "^15.0.0",
"axios": "^1.7.2",
"openai": "^4.29.2", "openai": "^4.29.2",
"reflect-metadata": "^0.2.2" "reflect-metadata": "^0.2.2"
} }

View File

@@ -56,12 +56,6 @@ impl Connection {
#[napi(factory)] #[napi(factory)]
pub async fn new(uri: String, options: ConnectionOptions) -> napi::Result<Self> { pub async fn new(uri: String, options: ConnectionOptions) -> napi::Result<Self> {
let mut builder = ConnectBuilder::new(&uri); let mut builder = ConnectBuilder::new(&uri);
if let Some(api_key) = options.api_key {
builder = builder.api_key(&api_key);
}
if let Some(host_override) = options.host_override {
builder = builder.host_override(&host_override);
}
if let Some(interval) = options.read_consistency_interval { if let Some(interval) = options.read_consistency_interval {
builder = builder =
builder.read_consistency_interval(std::time::Duration::from_secs_f64(interval)); builder.read_consistency_interval(std::time::Duration::from_secs_f64(interval));
@@ -126,6 +120,7 @@ impl Connection {
buf: Buffer, buf: Buffer,
mode: String, mode: String,
storage_options: Option<HashMap<String, String>>, storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> napi::Result<Table> { ) -> napi::Result<Table> {
let batches = ipc_file_to_batches(buf.to_vec()) let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?; .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
@@ -136,6 +131,9 @@ impl Connection {
builder = builder.storage_option(key, value); builder = builder.storage_option(key, value);
} }
} }
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
let tbl = builder let tbl = builder
.execute() .execute()
.await .await
@@ -150,6 +148,7 @@ impl Connection {
schema_buf: Buffer, schema_buf: Buffer,
mode: String, mode: String,
storage_options: Option<HashMap<String, String>>, storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> napi::Result<Table> { ) -> napi::Result<Table> {
let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| { let schema = ipc_file_to_schema(schema_buf.to_vec()).map_err(|e| {
napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e)) napi::Error::from_reason(format!("Failed to marshal schema from JS to Rust: {}", e))
@@ -164,6 +163,9 @@ impl Connection {
builder = builder.storage_option(key, value); builder = builder.storage_option(key, value);
} }
} }
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
let tbl = builder let tbl = builder
.execute() .execute()
.await .await

View File

@@ -20,6 +20,7 @@ mod connection;
mod error; mod error;
mod index; mod index;
mod iterator; mod iterator;
pub mod merge;
mod query; mod query;
mod table; mod table;
mod util; mod util;
@@ -27,8 +28,6 @@ mod util;
#[napi(object)] #[napi(object)]
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectionOptions { pub struct ConnectionOptions {
pub api_key: Option<String>,
pub host_override: Option<String>,
/// (For LanceDB OSS only): The interval, in seconds, at which to check for /// (For LanceDB OSS only): The interval, in seconds, at which to check for
/// updates to the table from other processes. If None, then consistency is not /// updates to the table from other processes. If None, then consistency is not
/// checked. For performance reasons, this is the default. For strong /// checked. For performance reasons, this is the default. For strong
@@ -56,6 +55,7 @@ pub enum WriteMode {
/// Write options when creating a Table. /// Write options when creating a Table.
#[napi(object)] #[napi(object)]
pub struct WriteOptions { pub struct WriteOptions {
/// Write mode for writing to a table.
pub mode: Option<WriteMode>, pub mode: Option<WriteMode>,
} }

53
nodejs/src/merge.rs Normal file
View File

@@ -0,0 +1,53 @@
use lancedb::{arrow::IntoArrow, ipc::ipc_file_to_batches, table::merge::MergeInsertBuilder};
use napi::bindgen_prelude::*;
use napi_derive::napi;
#[napi]
#[derive(Clone)]
/// A builder used to create and run a merge insert operation
pub struct NativeMergeInsertBuilder {
pub(crate) inner: MergeInsertBuilder,
}
#[napi]
impl NativeMergeInsertBuilder {
#[napi]
pub fn when_matched_update_all(&self, condition: Option<String>) -> Self {
let mut this = self.clone();
this.inner.when_matched_update_all(condition);
this
}
#[napi]
pub fn when_not_matched_insert_all(&self) -> Self {
let mut this = self.clone();
this.inner.when_not_matched_insert_all();
this
}
#[napi]
pub fn when_not_matched_by_source_delete(&self, filter: Option<String>) -> Self {
let mut this = self.clone();
this.inner.when_not_matched_by_source_delete(filter);
this
}
#[napi]
pub async fn execute(&self, buf: Buffer) -> napi::Result<()> {
let data = ipc_file_to_batches(buf.to_vec())
.and_then(IntoArrow::into_arrow)
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
let this = self.clone();
this.inner
.execute(data)
.await
.map_err(|e| napi::Error::from_reason(format!("Failed to execute merge insert: {}", e)))
}
}
impl From<MergeInsertBuilder> for NativeMergeInsertBuilder {
fn from(inner: MergeInsertBuilder) -> Self {
Self { inner }
}
}

View File

@@ -15,6 +15,7 @@
use lancedb::query::ExecutableQuery; use lancedb::query::ExecutableQuery;
use lancedb::query::Query as LanceDbQuery; use lancedb::query::Query as LanceDbQuery;
use lancedb::query::QueryBase; use lancedb::query::QueryBase;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::Select; use lancedb::query::Select;
use lancedb::query::VectorQuery as LanceDbVectorQuery; use lancedb::query::VectorQuery as LanceDbVectorQuery;
use napi::bindgen_prelude::*; use napi::bindgen_prelude::*;
@@ -62,10 +63,21 @@ impl Query {
} }
#[napi] #[napi]
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> { pub async fn execute(
let inner_stream = self.inner.execute().await.map_err(|e| { &self,
napi::Error::from_reason(format!("Failed to execute query stream: {}", e)) max_batch_length: Option<u32>,
})?; ) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)
.await
.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream)) Ok(RecordBatchIterator::new(inner_stream))
} }
} }
@@ -125,10 +137,21 @@ impl VectorQuery {
} }
#[napi] #[napi]
pub async fn execute(&self) -> napi::Result<RecordBatchIterator> { pub async fn execute(
let inner_stream = self.inner.execute().await.map_err(|e| { &self,
napi::Error::from_reason(format!("Failed to execute query stream: {}", e)) max_batch_length: Option<u32>,
})?; ) -> napi::Result<RecordBatchIterator> {
let mut execution_opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
execution_opts.max_batch_length = max_batch_length;
}
let inner_stream = self
.inner
.execute_with_options(execution_opts)
.await
.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(inner_stream)) Ok(RecordBatchIterator::new(inner_stream))
} }
} }

View File

@@ -23,13 +23,14 @@ use napi_derive::napi;
use crate::error::NapiErrorExt; use crate::error::NapiErrorExt;
use crate::index::Index; use crate::index::Index;
use crate::merge::NativeMergeInsertBuilder;
use crate::query::{Query, VectorQuery}; use crate::query::{Query, VectorQuery};
#[napi] #[napi]
pub struct Table { pub struct Table {
// We keep a duplicate of the table name so we can use it for error // We keep a duplicate of the table name so we can use it for error
// messages even if the table has been closed // messages even if the table has been closed
name: String, pub name: String,
pub(crate) inner: Option<LanceDbTable>, pub(crate) inner: Option<LanceDbTable>,
} }
@@ -328,16 +329,31 @@ impl Table {
.map(IndexConfig::from) .map(IndexConfig::from)
.collect::<Vec<_>>()) .collect::<Vec<_>>())
} }
#[napi]
pub async fn index_stats(&self, index_name: String) -> napi::Result<Option<IndexStatistics>> {
let tbl = self.inner_ref()?.as_native().unwrap();
let stats = tbl.index_stats(&index_name).await.default_error()?;
Ok(stats.map(IndexStatistics::from))
}
#[napi]
pub fn merge_insert(&self, on: Vec<String>) -> napi::Result<NativeMergeInsertBuilder> {
let on: Vec<_> = on.iter().map(String::as_str).collect();
Ok(self.inner_ref()?.merge_insert(on.as_slice()).into())
}
} }
#[napi(object)] #[napi(object)]
/// A description of an index currently configured on a column /// A description of an index currently configured on a column
pub struct IndexConfig { pub struct IndexConfig {
/// The name of the index
pub name: String,
/// The type of the index /// The type of the index
pub index_type: String, pub index_type: String,
/// The columns in the index /// The columns in the index
/// ///
/// Currently this is always an array of size 1. In the future there may /// Currently this is always an array of size 1. In the future there may
/// be more columns to represent composite indices. /// be more columns to represent composite indices.
pub columns: Vec<String>, pub columns: Vec<String>,
} }
@@ -348,6 +364,7 @@ impl From<lancedb::index::IndexConfig> for IndexConfig {
Self { Self {
index_type, index_type,
columns: value.columns, columns: value.columns,
name: value.name,
} }
} }
} }
@@ -430,3 +447,40 @@ pub struct AddColumnsSql {
/// The expression can reference other columns in the table. /// The expression can reference other columns in the table.
pub value_sql: String, pub value_sql: String,
} }
#[napi(object)]
pub struct IndexStatistics {
/// The number of rows indexed by the index
pub num_indexed_rows: f64,
/// The number of rows not indexed
pub num_unindexed_rows: f64,
/// The type of the index
pub index_type: Option<String>,
/// The metadata for each index
pub indices: Vec<IndexMetadata>,
}
impl From<lancedb::index::IndexStatistics> for IndexStatistics {
fn from(value: lancedb::index::IndexStatistics) -> Self {
Self {
num_indexed_rows: value.num_indexed_rows as f64,
num_unindexed_rows: value.num_unindexed_rows as f64,
index_type: value.index_type.map(|t| format!("{:?}", t)),
indices: value.indices.into_iter().map(Into::into).collect(),
}
}
}
#[napi(object)]
pub struct IndexMetadata {
pub metric_type: Option<String>,
pub index_type: Option<String>,
}
impl From<lancedb::index::IndexMetadata> for IndexMetadata {
fn from(value: lancedb::index::IndexMetadata) -> Self {
Self {
metric_type: value.metric_type,
index_type: value.index_type,
}
}
}

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.8.1" current_version = "0.9.0-beta.8"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-python" name = "lancedb-python"
version = "0.8.1" version = "0.9.0-beta.8"
edition.workspace = true edition.workspace = true
description = "Python bindings for LanceDB" description = "Python bindings for LanceDB"
license.workspace = true license.workspace = true
@@ -19,6 +19,8 @@ lancedb = { path = "../rust/lancedb" }
env_logger = "0.10" env_logger = "0.10"
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] } pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] } pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
base64ct = "=1.6.0" # workaround for https://github.com/RustCrypto/formats/issues/1684
chrono = "=0.4.39"
# Prevent dynamic linking of lzma, which comes from datafusion # Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] } lzma-sys = { version = "*", features = ["static"] }

View File

@@ -3,7 +3,7 @@ name = "lancedb"
# version in Cargo.toml # version in Cargo.toml
dependencies = [ dependencies = [
"deprecation", "deprecation",
"pylance==0.11.1", "pylance==0.13.0",
"ratelimiter~=1.0", "ratelimiter~=1.0",
"requests>=2.31.0", "requests>=2.31.0",
"retry>=0.9.2", "retry>=0.9.2",
@@ -13,6 +13,7 @@ dependencies = [
"packaging", "packaging",
"cachetools", "cachetools",
"overrides>=0.7", "overrides>=0.7",
"urllib3==1.26.19"
] ]
description = "lancedb" description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }] authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
@@ -57,15 +58,10 @@ tests = [
"duckdb", "duckdb",
"pytz", "pytz",
"polars>=0.19", "polars>=0.19",
"tantivy" "tantivy",
] ]
dev = ["ruff", "pre-commit"] dev = ["ruff", "pre-commit"]
docs = [ docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
"mkdocs",
"mkdocs-jupyter",
"mkdocs-material",
"mkdocstrings[python]",
]
clip = ["torch", "pillow", "open-clip"] clip = ["torch", "pillow", "open-clip"]
embeddings = [ embeddings = [
"openai>=1.6.1", "openai>=1.6.1",
@@ -100,5 +96,5 @@ addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
markers = [ markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')", "slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio", "asyncio",
"s3_test" "s3_test",
] ]

View File

@@ -35,6 +35,7 @@ def connect(
host_override: Optional[str] = None, host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None, read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
storage_options: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
) -> DBConnection: ) -> DBConnection:
"""Connect to a LanceDB database. """Connect to a LanceDB database.
@@ -70,6 +71,9 @@ def connect(
executor will be used for making requests. This is for LanceDB Cloud executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once). multiple queries to the search method at once).
storage_options: dict, optional
Additional options for the storage backend. See available options at
https://lancedb.github.io/lancedb/guides/storage/
Examples Examples
-------- --------
@@ -105,12 +109,16 @@ def connect(
region, region,
host_override, host_override,
request_thread_pool=request_thread_pool, request_thread_pool=request_thread_pool,
storage_options=storage_options,
**kwargs, **kwargs,
) )
if kwargs: if kwargs:
raise ValueError(f"Unknown keyword arguments: {kwargs}") raise ValueError(f"Unknown keyword arguments: {kwargs}")
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) return LanceDBConnection(
uri,
read_consistency_interval=read_consistency_interval,
)
async def connect_async( async def connect_async(

View File

@@ -24,6 +24,7 @@ class Connection(object):
mode: str, mode: str,
data: pa.RecordBatchReader, data: pa.RecordBatchReader,
storage_options: Optional[Dict[str, str]] = None, storage_options: Optional[Dict[str, str]] = None,
use_legacy_format: Optional[bool] = None,
) -> Table: ... ) -> Table: ...
async def create_empty_table( async def create_empty_table(
self, self,
@@ -31,6 +32,7 @@ class Connection(object):
mode: str, mode: str,
schema: pa.Schema, schema: pa.Schema,
storage_options: Optional[Dict[str, str]] = None, storage_options: Optional[Dict[str, str]] = None,
use_legacy_format: Optional[bool] = None,
) -> Table: ... ) -> Table: ...
class Table: class Table:
@@ -72,7 +74,7 @@ class Query:
def select(self, columns: Tuple[str, str]): ... def select(self, columns: Tuple[str, str]): ...
def limit(self, limit: int): ... def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
async def execute(self) -> RecordBatchStream: ... async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ...
class VectorQuery: class VectorQuery:
async def execute(self) -> RecordBatchStream: ... async def execute(self) -> RecordBatchStream: ...

View File

@@ -558,6 +558,8 @@ class AsyncConnection(object):
on_bad_vectors: Optional[str] = None, on_bad_vectors: Optional[str] = None,
fill_value: Optional[float] = None, fill_value: Optional[float] = None,
storage_options: Optional[Dict[str, str]] = None, storage_options: Optional[Dict[str, str]] = None,
*,
use_legacy_format: Optional[bool] = None,
) -> AsyncTable: ) -> AsyncTable:
"""Create an [AsyncTable][lancedb.table.AsyncTable] in the database. """Create an [AsyncTable][lancedb.table.AsyncTable] in the database.
@@ -600,6 +602,9 @@ class AsyncConnection(object):
connection will be inherited by the table, but can be overridden here. connection will be inherited by the table, but can be overridden here.
See available options at See available options at
https://lancedb.github.io/lancedb/guides/storage/ https://lancedb.github.io/lancedb/guides/storage/
use_legacy_format: bool, optional, default True
If True, use the legacy format for the table. If False, use the new format.
The default is True while the new format is in beta.
Returns Returns
@@ -761,7 +766,11 @@ class AsyncConnection(object):
if data is None: if data is None:
new_table = await self._inner.create_empty_table( new_table = await self._inner.create_empty_table(
name, mode, schema, storage_options=storage_options name,
mode,
schema,
storage_options=storage_options,
use_legacy_format=use_legacy_format,
) )
else: else:
data = data_to_reader(data, schema) data = data_to_reader(data, schema)
@@ -770,6 +779,7 @@ class AsyncConnection(object):
mode, mode,
data, data,
storage_options=storage_options, storage_options=storage_options,
use_legacy_format=use_legacy_format,
) )
return AsyncTable(new_table) return AsyncTable(new_table)

View File

@@ -153,7 +153,7 @@ class TextEmbeddingFunction(EmbeddingFunction):
@abstractmethod @abstractmethod
def generate_embeddings( def generate_embeddings(
self, texts: Union[List[str], np.ndarray] self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]: ) -> List[np.array]:
""" """
Generate the embeddings for the given texts Generate the embeddings for the given texts

View File

@@ -73,6 +73,8 @@ class BedRockText(TextEmbeddingFunction):
assumed_role: Union[str, None] = None assumed_role: Union[str, None] = None
profile_name: Union[str, None] = None profile_name: Union[str, None] = None
role_session_name: str = "lancedb-embeddings" role_session_name: str = "lancedb-embeddings"
source_input_type: str = "search_document"
query_input_type: str = "search_query"
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
@@ -87,21 +89,29 @@ class BedRockText(TextEmbeddingFunction):
# TODO: fix hardcoding # TODO: fix hardcoding
if self.name == "amazon.titan-embed-text-v1": if self.name == "amazon.titan-embed-text-v1":
return 1536 return 1536
elif self.name in {"cohere.embed-english-v3", "cohere.embed-multilingual-v3"}: elif self.name in [
"amazon.titan-embed-text-v2:0",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3",
]:
# TODO: "amazon.titan-embed-text-v2:0" model supports dynamic ndims
return 1024 return 1024
else: else:
raise ValueError(f"Unknown model name: {self.name}") raise ValueError(f"Model {self.name} not supported")
def compute_query_embeddings( def compute_query_embeddings(
self, query: str, *args, **kwargs self, query: str, *args, **kwargs
) -> List[List[float]]: ) -> List[List[float]]:
return self.compute_source_embeddings(query) return self.compute_source_embeddings(query, input_type=self.query_input_type)
def compute_source_embeddings( def compute_source_embeddings(
self, texts: TEXT, *args, **kwargs self, texts: TEXT, *args, **kwargs
) -> List[List[float]]: ) -> List[List[float]]:
texts = self.sanitize_input(texts) texts = self.sanitize_input(texts)
return self.generate_embeddings(texts) # assume source input type if not passed by `compute_query_embeddings`
kwargs["input_type"] = kwargs.get("input_type") or self.source_input_type
return self.generate_embeddings(texts, **kwargs)
def generate_embeddings( def generate_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs self, texts: Union[List[str], np.ndarray], *args, **kwargs
@@ -121,11 +131,11 @@ class BedRockText(TextEmbeddingFunction):
""" """
results = [] results = []
for text in texts: for text in texts:
response = self._generate_embedding(text) response = self._generate_embedding(text, *args, **kwargs)
results.append(response) results.append(response)
return results return results
def _generate_embedding(self, text: str) -> List[float]: def _generate_embedding(self, text: str, *args, **kwargs) -> List[float]:
""" """
Get the embeddings for the given texts Get the embeddings for the given texts
@@ -141,14 +151,12 @@ class BedRockText(TextEmbeddingFunction):
""" """
# format input body for provider # format input body for provider
provider = self.name.split(".")[0] provider = self.name.split(".")[0]
_model_kwargs = {} input_body = {**kwargs}
input_body = {**_model_kwargs}
if provider == "cohere": if provider == "cohere":
if "input_type" not in input_body.keys():
input_body["input_type"] = "search_document"
input_body["texts"] = [text] input_body["texts"] = [text]
else: else:
# includes common provider == "amazon" # includes common provider == "amazon"
input_body.pop("input_type", None)
input_body["inputText"] = text input_body["inputText"] = text
body = json.dumps(input_body) body = json.dumps(input_body)

View File

@@ -19,7 +19,7 @@ import numpy as np
from ..util import attempt_import_or_raise from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import api_key_not_found_help from .utils import api_key_not_found_help, TEXT
@register("cohere") @register("cohere")
@@ -32,8 +32,36 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
Parameters Parameters
---------- ----------
name: str, default "embed-multilingual-v2.0" name: str, default "embed-multilingual-v2.0"
The name of the model to use. See the Cohere documentation for The name of the model to use. List of acceptable models:
a list of available models.
* embed-english-v3.0
* embed-multilingual-v3.0
* embed-english-light-v3.0
* embed-multilingual-light-v3.0
* embed-english-v2.0
* embed-english-light-v2.0
* embed-multilingual-v2.0
source_input_type: str, default "search_document"
The input type for the source column in the database
query_input_type: str, default "search_query"
The input type for the query column in the database
Cohere supports following input types:
| Input Type | Description |
|-------------------------|---------------------------------------|
| "`search_document`" | Used for embeddings stored in a vector|
| | database for search use-cases. |
| "`search_query`" | Used for embeddings of search queries |
| | run against a vector DB |
| "`semantic_similarity`" | Specifies the given text will be used |
| | for Semantic Textual Similarity (STS) |
| "`classification`" | Used for embeddings passed through a |
| | text classifier. |
| "`clustering`" | Used for the embeddings run through a |
| | clustering algorithm |
Examples Examples
-------- --------
@@ -61,14 +89,39 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
""" """
name: str = "embed-multilingual-v2.0" name: str = "embed-multilingual-v2.0"
source_input_type: str = "search_document"
query_input_type: str = "search_query"
client: ClassVar = None client: ClassVar = None
def ndims(self): def ndims(self):
# TODO: fix hardcoding # TODO: fix hardcoding
return 768 if self.name in [
"embed-english-v3.0",
"embed-multilingual-v3.0",
"embed-english-light-v2.0",
]:
return 1024
elif self.name in ["embed-english-light-v3.0", "embed-multilingual-light-v3.0"]:
return 384
elif self.name == "embed-english-v2.0":
return 4096
elif self.name == "embed-multilingual-v2.0":
return 768
else:
raise ValueError(f"Model {self.name} not supported")
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
return self.compute_source_embeddings(query, input_type=self.query_input_type)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
input_type = (
kwargs.get("input_type") or self.source_input_type
) # assume source input type if not passed by `compute_query_embeddings`
return self.generate_embeddings(texts, input_type=input_type)
def generate_embeddings( def generate_embeddings(
self, texts: Union[List[str], np.ndarray] self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]: ) -> List[np.array]:
""" """
Get the embeddings for the given texts Get the embeddings for the given texts
@@ -78,9 +131,10 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
texts: list[str] or np.ndarray (of str) texts: list[str] or np.ndarray (of str)
The texts to embed The texts to embed
""" """
# TODO retry, rate limit, token limit
self._init_client() self._init_client()
rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name) rs = CohereEmbeddingFunction.client.embed(
texts=texts, model=self.name, **kwargs
)
return [emb for emb in rs.embeddings] return [emb for emb in rs.embeddings]

View File

@@ -29,7 +29,10 @@ from .table import LanceTable
def create_index( def create_index(
index_path: str, text_fields: List[str], ordering_fields: List[str] = None index_path: str,
text_fields: List[str],
ordering_fields: List[str] = None,
tokenizer_name: str = "default",
) -> tantivy.Index: ) -> tantivy.Index:
""" """
Create a new Index (not populated) Create a new Index (not populated)
@@ -42,6 +45,8 @@ def create_index(
List of text fields to index List of text fields to index
ordering_fields: List[str] ordering_fields: List[str]
List of unsigned type fields to order by at search time List of unsigned type fields to order by at search time
tokenizer_name : str, default "default"
The tokenizer to use
Returns Returns
------- -------
@@ -56,7 +61,7 @@ def create_index(
schema_builder.add_integer_field("doc_id", stored=True) schema_builder.add_integer_field("doc_id", stored=True)
# data fields # data fields
for name in text_fields: for name in text_fields:
schema_builder.add_text_field(name, stored=True) schema_builder.add_text_field(name, stored=True, tokenizer_name=tokenizer_name)
if ordering_fields: if ordering_fields:
for name in ordering_fields: for name in ordering_fields:
schema_builder.add_unsigned_field(name, fast=True) schema_builder.add_unsigned_field(name, fast=True)

View File

@@ -117,6 +117,8 @@ class Query(pydantic.BaseModel):
with_row_id: bool = False with_row_id: bool = False
fast_search: bool = False
class LanceQueryBuilder(ABC): class LanceQueryBuilder(ABC):
"""An abstract query builder. Subclasses are defined for vector search, """An abstract query builder. Subclasses are defined for vector search,
@@ -125,12 +127,14 @@ class LanceQueryBuilder(ABC):
@classmethod @classmethod
def create( def create(
cls, cls,
table: "Table", table: "Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str, query_type: str,
vector_column_name: str, vector_column_name: str,
ordering_field_name: str = None, ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
fast_search: bool = False,
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
""" """
Create a query builder based on the given query and query type. Create a query builder based on the given query and query type.
@@ -147,13 +151,18 @@ class LanceQueryBuilder(ABC):
If "auto", the query type is inferred based on the query. If "auto", the query type is inferred based on the query.
vector_column_name: str vector_column_name: str
The name of the vector column to use for vector search. The name of the vector column to use for vector search.
fast_search: bool
Skip flat search of unindexed data.
""" """
if query is None: # Check hybrid search first as it supports empty query pattern
return LanceEmptyQueryBuilder(table)
if query_type == "hybrid": if query_type == "hybrid":
# hybrid fts and vector query # hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name) return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
if query is None:
return LanceEmptyQueryBuilder(table)
# remember the string query for reranking purpose # remember the string query for reranking purpose
str_query = query if isinstance(query, str) else None str_query = query if isinstance(query, str) else None
@@ -165,12 +174,17 @@ class LanceQueryBuilder(ABC):
) )
if query_type == "hybrid": if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name) return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
if isinstance(query, str): if isinstance(query, str):
# fts # fts
return LanceFtsQueryBuilder( return LanceFtsQueryBuilder(
table, query, ordering_field_name=ordering_field_name table,
query,
ordering_field_name=ordering_field_name,
fts_columns=fts_columns,
) )
if isinstance(query, list): if isinstance(query, list):
@@ -180,7 +194,9 @@ class LanceQueryBuilder(ABC):
else: else:
raise TypeError(f"Unsupported query type: {type(query)}") raise TypeError(f"Unsupported query type: {type(query)}")
return LanceVectorQueryBuilder(table, query, vector_column_name, str_query) return LanceVectorQueryBuilder(
table, query, vector_column_name, str_query, fast_search
)
@classmethod @classmethod
def _resolve_query(cls, table, query, query_type, vector_column_name): def _resolve_query(cls, table, query, query_type, vector_column_name):
@@ -196,8 +212,6 @@ class LanceQueryBuilder(ABC):
elif query_type == "auto": elif query_type == "auto":
if isinstance(query, (list, np.ndarray)): if isinstance(query, (list, np.ndarray)):
return query, "vector" return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else: else:
conf = table.embedding_functions.get(vector_column_name) conf = table.embedding_functions.get(vector_column_name)
if conf is not None: if conf is not None:
@@ -224,9 +238,14 @@ class LanceQueryBuilder(ABC):
def __init__(self, table: "Table"): def __init__(self, table: "Table"):
self._table = table self._table = table
self._limit = 10 self._limit = 10
self._offset = 0
self._columns = None self._columns = None
self._where = None self._where = None
self._prefilter = False
self._with_row_id = False self._with_row_id = False
self._vector = None
self._text = None
self._ef = None
@deprecation.deprecated( @deprecation.deprecated(
deprecated_in="0.3.1", deprecated_in="0.3.1",
@@ -337,11 +356,13 @@ class LanceQueryBuilder(ABC):
---------- ----------
limit: int limit: int
The maximum number of results to return. The maximum number of results to return.
By default the query is limited to the first 10. The default query limit is 10 results.
Call this method and pass 0, a negative value, For ANN/KNN queries, you must specify a limit.
or None to remove the limit. Entering 0, a negative number, or None will reset
*WARNING* if you have a large dataset, removing the limit to the default value of 10.
the limit can potentially result in reading a *WARNING* if you have a large dataset, setting
the limit to a large number, e.g. the table size,
can potentially result in reading a
large amount of data into memory and cause large amount of data into memory and cause
out of memory issues. out of memory issues.
@@ -351,11 +372,33 @@ class LanceQueryBuilder(ABC):
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
if limit is None or limit <= 0: if limit is None or limit <= 0:
self._limit = None if isinstance(self, LanceVectorQueryBuilder):
raise ValueError("Limit is required for ANN/KNN queries")
else:
self._limit = None
else: else:
self._limit = limit self._limit = limit
return self return self
def offset(self, offset: int) -> LanceQueryBuilder:
"""Set the offset for the results.
Parameters
----------
offset: int
The offset to start fetching results from.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
if offset is None or offset <= 0:
self._offset = 0
else:
self._offset = offset
return self
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder: def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
"""Set the columns to return. """Set the columns to return.
@@ -417,6 +460,80 @@ class LanceQueryBuilder(ABC):
self._with_row_id = with_row_id self._with_row_id = with_row_id
return self return self
def explain_plan(self, verbose: Optional[bool] = False) -> str:
"""Return the execution plan for this query.
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [99, 99]}])
>>> query = [100, 100]
>>> plan = table.search(query).explain_plan(True)
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
GlobalLimitExec: skip=0, fetch=10
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
Parameters
----------
verbose : bool, default False
Use a verbose output format.
Returns
-------
plan : str
""" # noqa: E501
ds = self._table.to_lance()
return ds.scanner(
nearest={
"column": self._vector_column,
"q": self._query,
"k": self._limit,
"metric": self._metric,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
},
prefilter=self._prefilter,
filter=self._str_query,
limit=self._limit,
with_row_id=self._with_row_id,
offset=self._offset,
).explain_plan(verbose)
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
"""Set the vector to search for.
Parameters
----------
vector: np.ndarray or list
The vector to search for.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError
def text(self, text: str) -> LanceQueryBuilder:
"""Set the text to search for.
Parameters
----------
text: str
The text to search for.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError
class LanceVectorQueryBuilder(LanceQueryBuilder): class LanceVectorQueryBuilder(LanceQueryBuilder):
""" """
@@ -440,11 +557,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
""" """
def __init__( def __init__(
self, self,
table: "Table", table: "Table",
query: Union[np.ndarray, list, "PIL.Image.Image"], query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str, vector_column: str,
str_query: Optional[str] = None, str_query: Optional[str] = None,
fast_search: bool = False,
): ):
super().__init__(table) super().__init__(table)
self._query = query self._query = query
@@ -455,13 +573,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._prefilter = False self._prefilter = False
self._reranker = None self._reranker = None
self._str_query = str_query self._str_query = str_query
self._fast_search = fast_search
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use. """Set the distance metric to use.
Parameters Parameters
---------- ----------
metric: "L2" or "cosine" metric: "L2" or "cosine" or "dot"
The distance metric to use. By default "L2" is used. The distance metric to use. By default "L2" is used.
Returns Returns
@@ -469,7 +588,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
LanceVectorQueryBuilder LanceVectorQueryBuilder
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
self._metric = metric self._metric = metric.lower()
return self return self
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder: def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
@@ -494,6 +613,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes self._nprobes = nprobes
return self return self
def ef(self, ef: int) -> LanceVectorQueryBuilder:
"""Set the number of candidates to consider during search.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
This only applies to the HNSW-related index.
The default value is 1.5 * limit.
Parameters
----------
ef: int
The number of candidates to consider during search.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._ef = ef
return self
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder: def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
"""Set the refine factor to use, increasing the number of vectors sampled. """Set the refine factor to use, increasing the number of vectors sampled.
@@ -554,15 +695,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
refine_factor=self._refine_factor, refine_factor=self._refine_factor,
vector_column=self._vector_column, vector_column=self._vector_column,
with_row_id=self._with_row_id, with_row_id=self._with_row_id,
offset=self._offset,
fast_search=self._fast_search,
ef=self._ef,
) )
result_set = self._table._execute_query(query, batch_size) result_set = self._table._execute_query(query, batch_size)
if self._reranker is not None:
rs_table = result_set.read_all()
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
# convert result_set back to RecordBatchReader
result_set = pa.RecordBatchReader.from_batches(
result_set.schema, result_set.to_batches()
)
return result_set return result_set
@@ -591,7 +728,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
return self return self
def rerank( def rerank(
self, reranker: Reranker, query_string: Optional[str] = None self, reranker: Reranker, query_string: Optional[str] = None
) -> LanceVectorQueryBuilder: ) -> LanceVectorQueryBuilder:
"""Rerank the results using the specified reranker. """Rerank the results using the specified reranker.
@@ -756,12 +893,34 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
class LanceEmptyQueryBuilder(LanceQueryBuilder): class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
ds = self._table.to_lance() return self.to_batches().read_all()
return ds.to_table(
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
query = Query(
columns=self._columns, columns=self._columns,
filter=self._where, filter=self._where,
limit=self._limit, k=self._limit or 10,
with_row_id=self._with_row_id,
vector=[],
# not actually respected in remote query
offset=self._offset or 0,
) )
return self._table._execute_query(query)
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker.
Parameters
----------
reranker: Reranker
The reranker to use.
Returns
-------
LanceEmptyQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError("Reranking is not yet supported.")
class LanceHybridQueryBuilder(LanceQueryBuilder): class LanceHybridQueryBuilder(LanceQueryBuilder):
@@ -1113,11 +1272,22 @@ class AsyncQueryBase(object):
self._inner.limit(limit) self._inner.limit(limit)
return self return self
async def to_batches(self) -> AsyncRecordBatchReader: async def to_batches(
self, *, max_batch_length: Optional[int] = None
) -> AsyncRecordBatchReader:
""" """
Execute the query and return the results as an Apache Arrow RecordBatchReader. Execute the query and return the results as an Apache Arrow RecordBatchReader.
Parameters
----------
max_batch_length: Optional[int]
The maximum number of selected records in a single RecordBatch object.
If not specified, a default batch length is used.
It is possible for batches to be smaller than the provided length if the
underlying data is stored in smaller chunks.
""" """
return AsyncRecordBatchReader(await self._inner.execute()) return AsyncRecordBatchReader(await self._inner.execute(max_batch_length))
async def to_arrow(self) -> pa.Table: async def to_arrow(self) -> pa.Table:
""" """

View File

@@ -55,11 +55,13 @@ class RestfulLanceDBClient:
region: str region: str
api_key: Credential api_key: Credential
host_override: Optional[str] = attrs.field(default=None) host_override: Optional[str] = attrs.field(default=None)
db_prefix: Optional[str] = attrs.field(default=None)
closed: bool = attrs.field(default=False, init=False) closed: bool = attrs.field(default=False, init=False)
connection_timeout: float = attrs.field(default=120.0, kw_only=True) connection_timeout: float = attrs.field(default=120.0, kw_only=True)
read_timeout: float = attrs.field(default=300.0, kw_only=True) read_timeout: float = attrs.field(default=300.0, kw_only=True)
storage_options: Optional[Dict[str, str]] = attrs.field(default=None, kw_only=True)
@functools.cached_property @functools.cached_property
def session(self) -> requests.Session: def session(self) -> requests.Session:
@@ -92,6 +94,18 @@ class RestfulLanceDBClient:
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com" headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
if self.host_override: if self.host_override:
headers["x-lancedb-database"] = self.db_name headers["x-lancedb-database"] = self.db_name
if self.storage_options:
if self.storage_options.get("account_name") is not None:
headers["x-azure-storage-account-name"] = self.storage_options[
"account_name"
]
if self.storage_options.get("azure_storage_account_name") is not None:
headers["x-azure-storage-account-name"] = self.storage_options[
"azure_storage_account_name"
]
if self.db_prefix:
headers["x-lancedb-database-prefix"] = self.db_prefix
return headers return headers
@staticmethod @staticmethod
@@ -158,6 +172,7 @@ class RestfulLanceDBClient:
headers["content-type"] = content_type headers["content-type"] = content_type
if request_id is not None: if request_id is not None:
headers["x-request-id"] = request_id headers["x-request-id"] = request_id
with self.session.post( with self.session.post(
urljoin(self.url, uri), urljoin(self.url, uri),
headers=headers, headers=headers,
@@ -245,7 +260,6 @@ def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
connect=connect_retries, connect=connect_retries,
read=read_retries, read=read_retries,
backoff_factor=backoff_factor, backoff_factor=backoff_factor,
backoff_jitter=backoff_jitter,
status_forcelist=statuses, status_forcelist=statuses,
allowed_methods=methods, allowed_methods=methods,
) )

View File

@@ -15,7 +15,7 @@ import inspect
import logging import logging
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from cachetools import TTLCache from cachetools import TTLCache
@@ -44,20 +44,25 @@ class RemoteDBConnection(DBConnection):
request_thread_pool: Optional[ThreadPoolExecutor] = None, request_thread_pool: Optional[ThreadPoolExecutor] = None,
connection_timeout: float = 120.0, connection_timeout: float = 120.0,
read_timeout: float = 300.0, read_timeout: float = 300.0,
storage_options: Optional[Dict[str, str]] = None,
): ):
"""Connect to a remote LanceDB database.""" """Connect to a remote LanceDB database."""
parsed = urlparse(db_url) parsed = urlparse(db_url)
if parsed.scheme != "db": if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self.db_name = parsed.netloc self.db_name = parsed.netloc
prefix = parsed.path.lstrip("/")
self.db_prefix = None if not prefix else prefix
self.api_key = api_key self.api_key = api_key
self._client = RestfulLanceDBClient( self._client = RestfulLanceDBClient(
self.db_name, self.db_name,
region, region,
api_key, api_key,
host_override, host_override,
self.db_prefix,
connection_timeout=connection_timeout, connection_timeout=connection_timeout,
read_timeout=read_timeout, read_timeout=read_timeout,
storage_options=storage_options,
) )
self._request_thread_pool = request_thread_pool self._request_thread_pool = request_thread_pool
self._table_cache = TTLCache(maxsize=10000, ttl=300) self._table_cache = TTLCache(maxsize=10000, ttl=300)

View File

@@ -15,13 +15,14 @@ import logging
import uuid import uuid
from concurrent.futures import Future from concurrent.futures import Future
from functools import cached_property from functools import cached_property
from typing import Dict, Iterable, Optional, Union from typing import Dict, Iterable, Optional, Union, Literal
import pyarrow as pa import pyarrow as pa
from lance import json_to_schema from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder from lancedb.merge import LanceMergeInsertBuilder
from lancedb.query import LanceQueryBuilder
from ..query import LanceVectorQueryBuilder from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
@@ -81,6 +82,7 @@ class RemoteTable(Table):
def create_scalar_index( def create_scalar_index(
self, self,
column: str, column: str,
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
): ):
"""Creates a scalar index """Creates a scalar index
Parameters Parameters
@@ -89,8 +91,6 @@ class RemoteTable(Table):
The column to be indexed. Must be a boolean, integer, float, The column to be indexed. Must be a boolean, integer, float,
or string column. or string column.
""" """
index_type = "scalar"
data = { data = {
"column": column, "column": column,
"index_type": index_type, "index_type": index_type,
@@ -228,10 +228,21 @@ class RemoteTable(Table):
content_type=ARROW_STREAM_CONTENT_TYPE, content_type=ARROW_STREAM_CONTENT_TYPE,
) )
def query(
self,
query: Union[VEC, str] = None,
query_type: str = "vector",
vector_column_name: Optional[str] = None,
fast_search: bool = False,
) -> LanceVectorQueryBuilder:
return self.search(query, query_type, vector_column_name, fast_search)
def search( def search(
self, self,
query: Union[VEC, str], query: Union[VEC, str] = None,
query_type: str = "vector",
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
fast_search: bool = False,
) -> LanceVectorQueryBuilder: ) -> LanceVectorQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search] of the given query vector. We currently support [vector search][search]
@@ -278,6 +289,11 @@ class RemoteTable(Table):
- If the table has multiple vector columns then the *vector_column_name* - If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised. needs to be specified. Otherwise, an error is raised.
fast_search: bool, optional
Skip a flat search of unindexed data. This may improve
search performance but search results will not include unindexed data.
- *default False*.
Returns Returns
------- -------
LanceQueryBuilder LanceQueryBuilder
@@ -293,7 +309,14 @@ class RemoteTable(Table):
""" """
if vector_column_name is None: if vector_column_name is None:
vector_column_name = inf_vector_column_query(self.schema) vector_column_name = inf_vector_column_query(self.schema)
return LanceVectorQueryBuilder(self, query, vector_column_name)
return LanceQueryBuilder.create(
self,
query,
query_type,
vector_column_name=vector_column_name,
fast_search=fast_search,
)
def _execute_query( def _execute_query(
self, query: Query, batch_size: Optional[int] = None self, query: Query, batch_size: Optional[int] = None

View File

@@ -337,7 +337,6 @@ class Table(ABC):
For example, the following scan will be faster if the column ``my_col`` has For example, the following scan will be faster if the column ``my_col`` has
a scalar index: a scalar index:
.. code-block:: python
import lancedb import lancedb
@@ -348,8 +347,6 @@ class Table(ABC):
Scalar indices can also speed up scans containing a vector search and a Scalar indices can also speed up scans containing a vector search and a
prefilter: prefilter:
.. code-block::python
import lancedb import lancedb
db = lancedb.connect("/data/lance") db = lancedb.connect("/data/lance")
@@ -385,7 +382,6 @@ class Table(ABC):
Examples Examples
-------- --------
.. code-block:: python
import lance import lance
@@ -1175,6 +1171,7 @@ class LanceTable(Table):
*, *,
replace: bool = False, replace: bool = False,
writer_heap_size: Optional[int] = 1024 * 1024 * 1024, writer_heap_size: Optional[int] = 1024 * 1024 * 1024,
tokenizer_name: str = "default",
): ):
"""Create a full-text search index on the table. """Create a full-text search index on the table.
@@ -1193,6 +1190,10 @@ class LanceTable(Table):
ordering_field_names: ordering_field_names:
A list of unsigned type fields to index to optionally order A list of unsigned type fields to index to optionally order
results on at search time results on at search time
tokenizer_name: str, default "default"
The tokenizer to use for the index. Can be "raw", "default" or the 2 letter
language code followed by "_stem". So for english it would be "en_stem".
For available languages see: https://docs.rs/tantivy/latest/tantivy/tokenizer/enum.Language.html
""" """
from .fts import create_index, populate_index from .fts import create_index, populate_index
@@ -1218,6 +1219,7 @@ class LanceTable(Table):
self._get_fts_index_path(), self._get_fts_index_path(),
field_names, field_names,
ordering_fields=ordering_field_names, ordering_fields=ordering_field_names,
tokenizer_name=tokenizer_name,
) )
populate_index( populate_index(
index, index,

View File

@@ -0,0 +1,27 @@
import lancedb
# --8<-- [start:imports]
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
# --8<-- [end:imports]
import pytest
@pytest.mark.slow
def test_embeddings_openai():
# --8<-- [start:openai_embeddings]
db = lancedb.connect("/tmp/db")
func = get_registry().get("openai").create(name="text-embedding-ada-002")
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("words", schema=Words, mode="overwrite")
table.add([{"text": "hello world"}, {"text": "goodbye world"}])
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
# --8<-- [end:openai_embeddings]

View File

@@ -507,6 +507,52 @@ def test_empty_or_nonexistent_table(tmp_path):
assert test.schema == test2.schema assert test.schema == test2.schema
@pytest.mark.asyncio
async def test_create_in_v2_mode(tmp_path):
def make_data():
for i in range(10):
yield pa.record_batch([pa.array([x for x in range(1024)])], names=["x"])
def make_table():
return pa.table([pa.array([x for x in range(10 * 1024)])], names=["x"])
schema = pa.schema([pa.field("x", pa.int64())])
db = await lancedb.connect_async(tmp_path)
# Create table in v1 mode
tbl = await db.create_table("test", data=make_data(), schema=schema)
async def is_in_v2_mode(tbl):
batches = await tbl.query().to_batches(max_batch_length=1024 * 10)
num_batches = 0
async for batch in batches:
num_batches += 1
return num_batches < 10
assert not await is_in_v2_mode(tbl)
# Create table in v2 mode
tbl = await db.create_table(
"test_v2", data=make_data(), schema=schema, use_legacy_format=False
)
assert await is_in_v2_mode(tbl)
# Add data (should remain in v2 mode)
await tbl.add(make_table())
assert await is_in_v2_mode(tbl)
# Create empty table in v2 mode and add data
tbl = await db.create_table(
"test_empty_v2", data=None, schema=schema, use_legacy_format=False
)
await tbl.add(make_table())
assert await is_in_v2_mode(tbl)
def test_replace_index(tmp_path): def test_replace_index(tmp_path):
db = lancedb.connect(uri=tmp_path) db = lancedb.connect(uri=tmp_path)
table = db.create_table( table = db.create_table(

View File

@@ -66,6 +66,17 @@ def test_create_index(tmp_path):
assert os.path.exists(str(tmp_path / "index")) assert os.path.exists(str(tmp_path / "index"))
def test_create_index_with_stemming(tmp_path, table):
index = ldb.fts.create_index(
str(tmp_path / "index"), ["text"], tokenizer_name="en_stem"
)
assert isinstance(index, tantivy.Index)
assert os.path.exists(str(tmp_path / "index"))
# Check stemming by running tokenizer on non empty table
table.create_fts_index("text", tokenizer_name="en_stem")
def test_populate_index(tmp_path, table): def test_populate_index(tmp_path, table):
index = ldb.fts.create_index(str(tmp_path / "index"), ["text"]) index = ldb.fts.create_index(str(tmp_path / "index"), ["text"])
assert ldb.fts.populate_index(index, table, ["text"]) == len(table) assert ldb.fts.populate_index(index, table, ["text"]) == len(table)

View File

@@ -21,6 +21,7 @@ class FakeLanceDBClient:
pass pass
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
print(f"{query=}")
assert table_name == "test" assert table_name == "test"
t = pa.schema([]).empty_table() t = pa.schema([]).empty_table()
return VectorQueryResult(t) return VectorQueryResult(t)
@@ -39,3 +40,21 @@ def test_remote_db():
table = conn["test"] table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas() table.search([1.0, 2.0]).to_pandas()
def test_empty_query_with_filter():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
print(table.query().select(["vector"]).where("foo == bar").to_arrow())
def test_fast_search_query_with_filter():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
print(table.query([0, 0], fast_search=True).select(["vector"]).where("foo == bar").to_arrow())

View File

@@ -735,7 +735,7 @@ def test_create_scalar_index(db):
indices = table.to_lance().list_indices() indices = table.to_lance().list_indices()
assert len(indices) == 1 assert len(indices) == 1
scalar_index = indices[0] scalar_index = indices[0]
assert scalar_index["type"] == "Scalar" assert scalar_index["type"] == "BTree"
# Confirm that prefiltering still works with the scalar index column # Confirm that prefiltering still works with the scalar index column
results = table.search().where("x = 'c'").to_arrow() results = table.search().where("x = 'c'").to_arrow()

View File

@@ -91,6 +91,7 @@ impl Connection {
mode: &str, mode: &str,
data: &PyAny, data: &PyAny,
storage_options: Option<HashMap<String, String>>, storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> PyResult<&'a PyAny> { ) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone(); let inner = self_.get_inner()?.clone();
@@ -103,6 +104,10 @@ impl Connection {
builder = builder.storage_options(storage_options); builder = builder.storage_options(storage_options);
} }
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let table = builder.execute().await.infer_error()?; let table = builder.execute().await.infer_error()?;
Ok(Table::new(table)) Ok(Table::new(table))
@@ -115,6 +120,7 @@ impl Connection {
mode: &str, mode: &str,
schema: &PyAny, schema: &PyAny,
storage_options: Option<HashMap<String, String>>, storage_options: Option<HashMap<String, String>>,
use_legacy_format: Option<bool>,
) -> PyResult<&'a PyAny> { ) -> PyResult<&'a PyAny> {
let inner = self_.get_inner()?.clone(); let inner = self_.get_inner()?.clone();
@@ -128,6 +134,10 @@ impl Connection {
builder = builder.storage_options(storage_options); builder = builder.storage_options(storage_options);
} }
if let Some(use_legacy_format) = use_legacy_format {
builder = builder.use_legacy_format(use_legacy_format);
}
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let table = builder.execute().await.infer_error()?; let table = builder.execute().await.infer_error()?;
Ok(Table::new(table)) Ok(Table::new(table))

View File

@@ -15,6 +15,7 @@
use arrow::array::make_array; use arrow::array::make_array;
use arrow::array::ArrayData; use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow; use arrow::pyarrow::FromPyArrow;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::{ use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery, ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
}; };
@@ -61,10 +62,14 @@ impl Query {
Ok(VectorQuery { inner }) Ok(VectorQuery { inner })
} }
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> { pub fn execute(self_: PyRef<'_, Self>, max_batch_length: Option<u32>) -> PyResult<&PyAny> {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let inner_stream = inner.execute().await.infer_error()?; let mut opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream)) Ok(RecordBatchStream::new(inner_stream))
}) })
} }
@@ -115,10 +120,14 @@ impl VectorQuery {
self.inner = self.inner.clone().bypass_vector_index() self.inner = self.inner.clone().bypass_vector_index()
} }
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> { pub fn execute(self_: PyRef<'_, Self>, max_batch_length: Option<u32>) -> PyResult<&PyAny> {
let inner = self_.inner.clone(); let inner = self_.inner.clone();
future_into_py(self_.py(), async move { future_into_py(self_.py(), async move {
let inner_stream = inner.execute().await.infer_error()?; let mut opts = QueryExecutionOptions::default();
if let Some(max_batch_length) = max_batch_length {
opts.max_batch_length = max_batch_length;
}
let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream)) Ok(RecordBatchStream::new(inner_stream))
}) })
} }

2
rust-toolchain.toml Normal file
View File

@@ -0,0 +1,2 @@
[toolchain]
channel = "1.79.0"

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-node" name = "lancedb-node"
version = "0.5.0" version = "0.5.2-final.1"
description = "Serverless, low-latency vector database for AI applications" description = "Serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true
edition.workspace = true edition.workspace = true

View File

@@ -463,6 +463,7 @@ impl JsTable {
Ok(promise) Ok(promise)
} }
#[allow(deprecated)]
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb" name = "lancedb"
version = "0.5.0" version = "0.5.2-final.1"
edition.workspace = true edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true license.workspace = true
@@ -19,11 +19,13 @@ arrow-ord = { workspace = true }
arrow-cast = { workspace = true } arrow-cast = { workspace = true }
arrow-ipc.workspace = true arrow-ipc.workspace = true
chrono = { workspace = true } chrono = { workspace = true }
datafusion-physical-plan.workspace = true
object_store = { workspace = true } object_store = { workspace = true }
snafu = { workspace = true } snafu = { workspace = true }
half = { workspace = true } half = { workspace = true }
lazy_static.workspace = true lazy_static.workspace = true
lance = { workspace = true } lance = { workspace = true }
lance-datafusion.workspace = true
lance-index = { workspace = true } lance-index = { workspace = true }
lance-linalg = { workspace = true } lance-linalg = { workspace = true }
lance-testing = { workspace = true } lance-testing = { workspace = true }
@@ -38,11 +40,12 @@ url.workspace = true
regex.workspace = true regex.workspace = true
serde = { version = "^1" } serde = { version = "^1" }
serde_json = { version = "1" } serde_json = { version = "1" }
async-openai = { version = "0.20.0", optional = true }
serde_with = { version = "3.8.1" } serde_with = { version = "3.8.1" }
# For remote feature # For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true } reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
polars-arrow = { version = ">=0.37,<0.40.0", optional = true } polars-arrow = { version = ">=0.37,<0.40.0", optional = true }
polars = { version = ">=0.37,<0.40.0", optional = true} polars = { version = ">=0.37,<0.40.0", optional = true }
[dev-dependencies] [dev-dependencies]
tempfile = "3.5.0" tempfile = "3.5.0"
@@ -62,4 +65,10 @@ default = []
remote = ["dep:reqwest"] remote = ["dep:reqwest"]
fp16kernels = ["lance-linalg/fp16kernels"] fp16kernels = ["lance-linalg/fp16kernels"]
s3-test = [] s3-test = []
openai = ["dep:async-openai", "dep:reqwest"]
polars = ["dep:polars-arrow", "dep:polars"] polars = ["dep:polars-arrow", "dep:polars"]
[[example]]
name = "openai"
required-features = ["openai"]

View File

@@ -0,0 +1,82 @@
use std::{iter::once, sync::Arc};
use arrow_array::{Float64Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures::StreamExt;
use lancedb::{
arrow::IntoArrow,
connect,
embeddings::{openai::OpenAIEmbeddingFunction, EmbeddingDefinition, EmbeddingFunction},
query::{ExecutableQuery, QueryBase},
Result,
};
#[tokio::main]
async fn main() -> Result<()> {
let tempdir = tempfile::tempdir().unwrap();
let tempdir = tempdir.path().to_str().unwrap();
let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY is not set");
let embedding = Arc::new(OpenAIEmbeddingFunction::new_with_model(
api_key,
"text-embedding-3-large",
)?);
let db = connect(tempdir).execute().await?;
db.embedding_registry()
.register("openai", embedding.clone())?;
let table = db
.create_table("vectors", make_data())
.add_embedding(EmbeddingDefinition::new(
"text",
"openai",
Some("embeddings"),
))?
.execute()
.await?;
// there is no equivalent to '.search(<query>)' yet
let query = Arc::new(StringArray::from_iter_values(once("something warm")));
let query_vector = embedding.compute_query_embeddings(query)?;
let mut results = table
.vector_search(query_vector)?
.limit(1)
.execute()
.await?;
let rb = results.next().await.unwrap()?;
let out = rb
.column_by_name("text")
.unwrap()
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let text = out.iter().next().unwrap().unwrap();
println!("Closest match: {}", text);
Ok(())
}
fn make_data() -> impl IntoArrow {
let schema = Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("text", DataType::Utf8, false),
Field::new("price", DataType::Float64, false),
]);
let id = Int32Array::from(vec![1, 2, 3, 4]);
let text = StringArray::from_iter_values(vec![
"Black T-Shirt",
"Leather Jacket",
"Winter Parka",
"Hooded Sweatshirt",
]);
let price = Float64Array::from(vec![10.0, 50.0, 100.0, 30.0]);
let schema = Arc::new(schema);
let rb = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(id), Arc::new(text), Arc::new(price)],
)
.unwrap();
Box::new(RecordBatchIterator::new(vec![Ok(rb)], schema))
}

View File

@@ -140,6 +140,7 @@ pub struct CreateTableBuilder<const HAS_DATA: bool, T: IntoArrow> {
pub(crate) write_options: WriteOptions, pub(crate) write_options: WriteOptions,
pub(crate) table_definition: Option<TableDefinition>, pub(crate) table_definition: Option<TableDefinition>,
pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>, pub(crate) embeddings: Vec<(EmbeddingDefinition, Arc<dyn EmbeddingFunction>)>,
pub(crate) use_legacy_format: bool,
} }
// Builder methods that only apply when we have initial data // Builder methods that only apply when we have initial data
@@ -153,6 +154,7 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
write_options: WriteOptions::default(), write_options: WriteOptions::default(),
table_definition: None, table_definition: None,
embeddings: Vec::new(), embeddings: Vec::new(),
use_legacy_format: true,
} }
} }
@@ -184,6 +186,7 @@ impl<T: IntoArrow> CreateTableBuilder<true, T> {
mode: self.mode, mode: self.mode,
write_options: self.write_options, write_options: self.write_options,
embeddings: self.embeddings, embeddings: self.embeddings,
use_legacy_format: self.use_legacy_format,
}; };
Ok((data, builder)) Ok((data, builder))
} }
@@ -217,6 +220,7 @@ impl CreateTableBuilder<false, NoData> {
mode: CreateTableMode::default(), mode: CreateTableMode::default(),
write_options: WriteOptions::default(), write_options: WriteOptions::default(),
embeddings: Vec::new(), embeddings: Vec::new(),
use_legacy_format: false,
} }
} }
@@ -278,6 +282,20 @@ impl<const HAS_DATA: bool, T: IntoArrow> CreateTableBuilder<HAS_DATA, T> {
} }
self self
} }
/// Set to true to use the v1 format for data files
///
/// This is currently defaulted to true and can be set to false to opt-in
/// to the new format. This should only be used for experimentation and
/// evaluation. The new format is still in beta and may change in ways that
/// are not backwards compatible.
///
/// Once the new format is stable, the default will change to `false` for
/// several releases and then eventually this option will be removed.
pub fn use_legacy_format(mut self, use_legacy_format: bool) -> Self {
self.use_legacy_format = use_legacy_format;
self
}
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -943,6 +961,7 @@ impl ConnectionInternal for Database {
if matches!(&options.mode, CreateTableMode::Overwrite) { if matches!(&options.mode, CreateTableMode::Overwrite) {
write_params.mode = WriteMode::Overwrite; write_params.mode = WriteMode::Overwrite;
} }
write_params.use_legacy_format = options.use_legacy_format;
match NativeTable::create( match NativeTable::create(
&table_uri, &table_uri,
@@ -1040,8 +1059,12 @@ impl ConnectionInternal for Database {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use arrow_schema::{DataType, Field, Schema}; use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lance_testing::datagen::{BatchGenerator, IncrementingInt32};
use tempfile::tempdir; use tempfile::tempdir;
use crate::query::{ExecutableQuery, QueryExecutionOptions};
use super::*; use super::*;
#[tokio::test] #[tokio::test]
@@ -1146,6 +1169,58 @@ mod tests {
assert_eq!(tables, vec!["table1".to_owned()]); assert_eq!(tables, vec!["table1".to_owned()]);
} }
fn make_data() -> impl RecordBatchReader + Send + 'static {
let id = Box::new(IncrementingInt32::new().named("id".to_string()));
BatchGenerator::new().col(id).batches(10, 2000)
}
#[tokio::test]
async fn test_create_table_v2() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let tbl = db
.create_table("v1_test", make_data())
.execute()
.await
.unwrap();
// In v1 the row group size will trump max_batch_length
let batches = tbl
.query()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 50000,
})
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 20);
let tbl = db
.create_table("v2_test", make_data())
.use_legacy_format(false)
.execute()
.await
.unwrap();
// In v2 the page size is much bigger than 50k so we should get a single batch
let batches = tbl
.query()
.execute_with_options(QueryExecutionOptions {
max_batch_length: 50000,
})
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(batches.len(), 1);
}
#[tokio::test] #[tokio::test]
async fn drop_table() { async fn drop_table() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();

View File

@@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#[cfg(feature = "openai")]
pub mod openai;
use lance::arrow::RecordBatchExt; use lance::arrow::RecordBatchExt;
use std::{ use std::{
@@ -51,8 +53,10 @@ pub trait EmbeddingFunction: std::fmt::Debug + Send + Sync {
/// The type of the output data /// The type of the output data
/// This should **always** match the output of the `embed` function /// This should **always** match the output of the `embed` function
fn dest_type(&self) -> Result<Cow<DataType>>; fn dest_type(&self) -> Result<Cow<DataType>>;
/// Embed the input /// Compute the embeddings for the source column in the database
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>; fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
/// Compute the embeddings for a given user query
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>>;
} }
/// Defines an embedding from input data into a lower-dimensional space /// Defines an embedding from input data into a lower-dimensional space
@@ -266,7 +270,7 @@ impl<R: RecordBatchReader> Iterator for WithEmbeddings<R> {
// todo: parallelize this // todo: parallelize this
for (fld, func) in self.embeddings.iter() { for (fld, func) in self.embeddings.iter() {
let src_column = batch.column_by_name(&fld.source_column).unwrap(); let src_column = batch.column_by_name(&fld.source_column).unwrap();
let embedding = match func.embed(src_column.clone()) { let embedding = match func.compute_source_embeddings(src_column.clone()) {
Ok(embedding) => embedding, Ok(embedding) => embedding,
Err(e) => { Err(e) => {
return Some(Err(arrow_schema::ArrowError::ComputeError(format!( return Some(Err(arrow_schema::ArrowError::ComputeError(format!(

View File

@@ -0,0 +1,257 @@
use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc};
use arrow::array::{AsArray, Float32Builder};
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use async_openai::{
config::OpenAIConfig,
types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat},
Client,
};
use tokio::{runtime::Handle, task};
use crate::{Error, Result};
use super::EmbeddingFunction;
#[derive(Debug)]
pub enum EmbeddingModel {
TextEmbeddingAda002,
TextEmbedding3Small,
TextEmbedding3Large,
}
impl EmbeddingModel {
fn ndims(&self) -> usize {
match self {
Self::TextEmbeddingAda002 => 1536,
Self::TextEmbedding3Small => 1536,
Self::TextEmbedding3Large => 3072,
}
}
}
impl FromStr for EmbeddingModel {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002),
"text-embedding-3-small" => Ok(Self::TextEmbedding3Small),
"text-embedding-3-large" => Ok(Self::TextEmbedding3Large),
_ => Err(Error::InvalidInput {
message: "Invalid input. Available models are: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' ".to_string()
}),
}
}
}
impl std::fmt::Display for EmbeddingModel {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
match self {
Self::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"),
Self::TextEmbedding3Small => write!(f, "text-embedding-3-small"),
Self::TextEmbedding3Large => write!(f, "text-embedding-3-large"),
}
}
}
impl TryFrom<&str> for EmbeddingModel {
type Error = Error;
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
value.parse()
}
}
pub struct OpenAIEmbeddingFunction {
model: EmbeddingModel,
api_key: String,
api_base: Option<String>,
org_id: Option<String>,
}
impl std::fmt::Debug for OpenAIEmbeddingFunction {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
// let's be safe and not print the full API key
let creds_display = if self.api_key.len() > 6 {
format!(
"{}***{}",
&self.api_key[0..2],
&self.api_key[self.api_key.len() - 4..]
)
} else {
"[INVALID]".to_string()
};
f.debug_struct("OpenAI")
.field("model", &self.model)
.field("api_key", &creds_display)
.field("api_base", &self.api_base)
.field("org_id", &self.org_id)
.finish()
}
}
impl OpenAIEmbeddingFunction {
/// Create a new OpenAIEmbeddingFunction
pub fn new<A: Into<String>>(api_key: A) -> Self {
Self::new_impl(api_key.into(), EmbeddingModel::TextEmbeddingAda002)
}
pub fn new_with_model<A: Into<String>, M: TryInto<EmbeddingModel>>(
api_key: A,
model: M,
) -> crate::Result<Self>
where
M::Error: Into<crate::Error>,
{
Ok(Self::new_impl(
api_key.into(),
model.try_into().map_err(|e| e.into())?,
))
}
/// concrete implementation to reduce monomorphization
fn new_impl(api_key: String, model: EmbeddingModel) -> Self {
Self {
model,
api_key,
api_base: None,
org_id: None,
}
}
/// To use a API base url different from default "https://api.openai.com/v1"
pub fn api_base<S: Into<String>>(mut self, api_base: S) -> Self {
self.api_base = Some(api_base.into());
self
}
/// To use a different OpenAI organization id other than default
pub fn org_id<S: Into<String>>(mut self, org_id: S) -> Self {
self.org_id = Some(org_id.into());
self
}
}
impl EmbeddingFunction for OpenAIEmbeddingFunction {
fn name(&self) -> &str {
"openai"
}
fn source_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Owned(DataType::Utf8))
}
fn dest_type(&self) -> Result<Cow<DataType>> {
let n_dims = self.model.ndims();
Ok(Cow::Owned(DataType::new_fixed_size_list(
DataType::Float32,
n_dims as i32,
false,
)))
}
fn compute_source_embeddings(&self, source: ArrayRef) -> crate::Result<ArrayRef> {
let len = source.len();
let n_dims = self.model.ndims();
let inner = self.compute_inner(source)?;
let fsl = DataType::new_fixed_size_list(DataType::Float32, n_dims as i32, false);
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays.
let array_data = ArrayData::builder(fsl)
.len(len)
.add_child_data(inner.into_data())
.build()?;
Ok(Arc::new(FixedSizeListArray::from(array_data)))
}
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
let arr = self.compute_inner(input)?;
Ok(Arc::new(arr))
}
}
impl OpenAIEmbeddingFunction {
fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
// OpenAI only supports non-nullable string arrays
if source.is_nullable() {
return Err(crate::Error::InvalidInput {
message: "Expected non-nullable data type".to_string(),
});
}
// OpenAI only supports string arrays
if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
return Err(crate::Error::InvalidInput {
message: "Expected Utf8 data type".to_string(),
});
};
let mut creds = OpenAIConfig::new().with_api_key(self.api_key.clone());
if let Some(api_base) = &self.api_base {
creds = creds.with_api_base(api_base.clone());
}
if let Some(org_id) = &self.org_id {
creds = creds.with_org_id(org_id.clone());
}
let input = match source.data_type() {
DataType::Utf8 => {
let array = source
.as_string::<i32>()
.into_iter()
.map(|s| {
s.expect("we already asserted that the array is non-nullable")
.to_string()
})
.collect::<Vec<String>>();
EmbeddingInput::StringArray(array)
}
DataType::LargeUtf8 => {
let array = source
.as_string::<i64>()
.into_iter()
.map(|s| {
s.expect("we already asserted that the array is non-nullable")
.to_string()
})
.collect::<Vec<String>>();
EmbeddingInput::StringArray(array)
}
_ => unreachable!("This should not happen. We already checked the data type."),
};
let client = Client::with_config(creds);
let embed = client.embeddings();
let req = CreateEmbeddingRequest {
model: self.model.to_string(),
input,
encoding_format: Some(EncodingFormat::Float),
user: None,
dimensions: None,
};
// TODO: request batching and retry logic
task::block_in_place(move || {
Handle::current().block_on(async {
let mut builder = Float32Builder::new();
let res = embed.create(req).await.map_err(|e| crate::Error::Runtime {
message: format!("OpenAI embed request failed: {e}"),
})?;
for Embedding { embedding, .. } in res.data.iter() {
builder.append_slice(embedding);
}
Ok(builder.finish())
})
})
}
}

View File

@@ -80,6 +80,8 @@ pub enum IndexType {
/// A description of an index currently configured on a column /// A description of an index currently configured on a column
pub struct IndexConfig { pub struct IndexConfig {
/// The name of the index
pub name: String,
/// The type of the index /// The type of the index
pub index_type: IndexType, pub index_type: IndexType,
/// The columns in the index /// The columns in the index

View File

@@ -84,7 +84,8 @@ pub fn convert_polars_arrow_array_to_arrow_rs_array(
arrow_datatype: arrow_schema::DataType, arrow_datatype: arrow_schema::DataType,
) -> std::result::Result<arrow_array::ArrayRef, arrow_schema::ArrowError> { ) -> std::result::Result<arrow_array::ArrayRef, arrow_schema::ArrowError> {
let polars_c_array = polars_arrow::ffi::export_array_to_c(polars_array); let polars_c_array = polars_arrow::ffi::export_array_to_c(polars_array);
let arrow_c_array = unsafe { mem::transmute(polars_c_array) }; // Safety: `polars_arrow::ffi::ArrowArray` has the same memory layout as `arrow::ffi::FFI_ArrowArray`.
let arrow_c_array: arrow_data::ffi::FFI_ArrowArray = unsafe { mem::transmute(polars_c_array) };
Ok(arrow_array::make_array(unsafe { Ok(arrow_array::make_array(unsafe {
arrow::ffi::from_ffi_and_data_type(arrow_c_array, arrow_datatype) arrow::ffi::from_ffi_and_data_type(arrow_c_array, arrow_datatype)
}?)) }?))
@@ -96,7 +97,8 @@ fn convert_arrow_rs_array_to_polars_arrow_array(
polars_arrow_dtype: polars::datatypes::ArrowDataType, polars_arrow_dtype: polars::datatypes::ArrowDataType,
) -> Result<Box<dyn polars_arrow::array::Array>> { ) -> Result<Box<dyn polars_arrow::array::Array>> {
let arrow_c_array = arrow::ffi::FFI_ArrowArray::new(&arrow_rs_array.to_data()); let arrow_c_array = arrow::ffi::FFI_ArrowArray::new(&arrow_rs_array.to_data());
let polars_c_array = unsafe { mem::transmute(arrow_c_array) }; // Safety: `polars_arrow::ffi::ArrowArray` has the same memory layout as `arrow::ffi::FFI_ArrowArray`.
let polars_c_array: polars_arrow::ffi::ArrowArray = unsafe { mem::transmute(arrow_c_array) };
Ok(unsafe { polars_arrow::ffi::import_array_from_c(polars_c_array, polars_arrow_dtype) }?) Ok(unsafe { polars_arrow::ffi::import_array_from_c(polars_c_array, polars_arrow_dtype) }?)
} }
@@ -104,7 +106,9 @@ fn convert_polars_arrow_field_to_arrow_rs_field(
polars_arrow_field: polars_arrow::datatypes::Field, polars_arrow_field: polars_arrow::datatypes::Field,
) -> Result<arrow_schema::Field> { ) -> Result<arrow_schema::Field> {
let polars_c_schema = polars_arrow::ffi::export_field_to_c(&polars_arrow_field); let polars_c_schema = polars_arrow::ffi::export_field_to_c(&polars_arrow_field);
let arrow_c_schema: arrow::ffi::FFI_ArrowSchema = unsafe { mem::transmute(polars_c_schema) }; // Safety: `polars_arrow::ffi::ArrowSchema` has the same memory layout as `arrow::ffi::FFI_ArrowSchema`.
let arrow_c_schema: arrow::ffi::FFI_ArrowSchema =
unsafe { mem::transmute::<_, _>(polars_c_schema) };
let arrow_rs_dtype = arrow_schema::DataType::try_from(&arrow_c_schema)?; let arrow_rs_dtype = arrow_schema::DataType::try_from(&arrow_c_schema)?;
Ok(arrow_schema::Field::new( Ok(arrow_schema::Field::new(
polars_arrow_field.name, polars_arrow_field.name,
@@ -118,6 +122,8 @@ fn convert_arrow_rs_field_to_polars_arrow_field(
) -> Result<polars_arrow::datatypes::Field> { ) -> Result<polars_arrow::datatypes::Field> {
let arrow_rs_dtype = arrow_rs_field.data_type(); let arrow_rs_dtype = arrow_rs_field.data_type();
let arrow_c_schema = arrow::ffi::FFI_ArrowSchema::try_from(arrow_rs_dtype)?; let arrow_c_schema = arrow::ffi::FFI_ArrowSchema::try_from(arrow_rs_dtype)?;
let polars_c_schema: polars_arrow::ffi::ArrowSchema = unsafe { mem::transmute(arrow_c_schema) }; // Safety: `polars_arrow::ffi::ArrowSchema` has the same memory layout as `arrow::ffi::FFI_ArrowSchema`.
let polars_c_schema: polars_arrow::ffi::ArrowSchema =
unsafe { mem::transmute::<_, _>(arrow_c_schema) };
Ok(unsafe { polars_arrow::ffi::import_field_from_c(&polars_c_schema) }?) Ok(unsafe { polars_arrow::ffi::import_field_from_c(&polars_c_schema) }?)
} }

View File

@@ -17,7 +17,10 @@ use std::sync::Arc;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array}; use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType; use arrow_schema::DataType;
use datafusion_physical_plan::ExecutionPlan;
use half::f16; use half::f16;
use lance::dataset::scanner::DatasetRecordBatchStream;
use lance_datafusion::exec::execute_plan;
use crate::arrow::SendableRecordBatchStream; use crate::arrow::SendableRecordBatchStream;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
@@ -425,6 +428,15 @@ impl Default for QueryExecutionOptions {
/// There are various kinds of queries but they all return results /// There are various kinds of queries but they all return results
/// in the same way. /// in the same way.
pub trait ExecutableQuery { pub trait ExecutableQuery {
/// Return the Datafusion [ExecutionPlan].
///
/// The caller can further optimize the plan or execute it.
///
fn create_plan(
&self,
options: QueryExecutionOptions,
) -> impl Future<Output = Result<Arc<dyn ExecutionPlan>>> + Send;
/// Execute the query with default options and return results /// Execute the query with default options and return results
/// ///
/// See [`ExecutableQuery::execute_with_options`] for more details. /// See [`ExecutableQuery::execute_with_options`] for more details.
@@ -545,6 +557,13 @@ impl HasQuery for Query {
} }
impl ExecutableQuery for Query { impl ExecutableQuery for Query {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
self.parent
.clone()
.create_plan(&self.clone().into_vector(), options)
.await
}
async fn execute_with_options( async fn execute_with_options(
&self, &self,
options: QueryExecutionOptions, options: QueryExecutionOptions,
@@ -718,12 +737,19 @@ impl VectorQuery {
} }
impl ExecutableQuery for VectorQuery { impl ExecutableQuery for VectorQuery {
async fn create_plan(&self, options: QueryExecutionOptions) -> Result<Arc<dyn ExecutionPlan>> {
self.base.parent.clone().create_plan(self, options).await
}
async fn execute_with_options( async fn execute_with_options(
&self, &self,
options: QueryExecutionOptions, options: QueryExecutionOptions,
) -> Result<SendableRecordBatchStream> { ) -> Result<SendableRecordBatchStream> {
Ok(SendableRecordBatchStream::from( Ok(SendableRecordBatchStream::from(
self.base.parent.clone().vector_query(self, options).await?, DatasetRecordBatchStream::new(execute_plan(
self.create_plan(options).await?,
Default::default(),
)?),
)) ))
} }
} }
@@ -972,6 +998,30 @@ mod tests {
} }
} }
fn assert_plan_exists(plan: &Arc<dyn ExecutionPlan>, name: &str) -> bool {
if plan.name() == name {
return true;
}
plan.children()
.iter()
.any(|child| assert_plan_exists(child, name))
}
#[tokio::test]
async fn test_create_execute_plan() {
let tmp_dir = tempdir().unwrap();
let table = make_test_table(&tmp_dir).await;
let plan = table
.query()
.nearest_to(vec![0.1, 0.2, 0.3, 0.4])
.unwrap()
.create_plan(QueryExecutionOptions::default())
.await
.unwrap();
assert_plan_exists(&plan, "KNNFlatSearch");
assert_plan_exists(&plan, "ProjectionExec");
}
#[tokio::test] #[tokio::test]
async fn query_base_methods_on_vector_query() { async fn query_base_methods_on_vector_query() {
// Make sure VectorQuery can be used as a QueryBase // Make sure VectorQuery can be used as a QueryBase
@@ -989,5 +1039,18 @@ mod tests {
let first_batch = results.next().await.unwrap().unwrap(); let first_batch = results.next().await.unwrap().unwrap();
assert_eq!(first_batch.num_rows(), 1); assert_eq!(first_batch.num_rows(), 1);
assert!(results.next().await.is_none()); assert!(results.next().await.is_none());
// query with wrong vector dimension
let error_result = table
.vector_search(&[1.0, 2.0, 3.0])
.unwrap()
.limit(1)
.execute()
.await;
assert!(error_result
.err()
.unwrap()
.to_string()
.contains("No vector column found to match with the query vector dimension: 3"));
} }
} }

View File

@@ -1,6 +1,9 @@
use std::sync::Arc;
use arrow_array::RecordBatchReader; use arrow_array::RecordBatchReader;
use arrow_schema::SchemaRef; use arrow_schema::SchemaRef;
use async_trait::async_trait; use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform}; use lance::dataset::{scanner::DatasetRecordBatchStream, ColumnAlteration, NewColumnTransform};
use crate::{ use crate::{
@@ -71,6 +74,13 @@ impl TableInternal for RemoteTable {
) -> Result<()> { ) -> Result<()> {
todo!() todo!()
} }
async fn create_plan(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!()
}
async fn plain_query( async fn plain_query(
&self, &self,
_query: &Query, _query: &Query,
@@ -78,13 +88,6 @@ impl TableInternal for RemoteTable {
) -> Result<DatasetRecordBatchStream> { ) -> Result<DatasetRecordBatchStream> {
todo!() todo!()
} }
async fn vector_query(
&self,
_query: &VectorQuery,
_options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
todo!()
}
async fn update(&self, _update: UpdateBuilder) -> Result<()> { async fn update(&self, _update: UpdateBuilder) -> Result<()> {
todo!() todo!()
} }

View File

@@ -23,6 +23,7 @@ use arrow::datatypes::Float32Type;
use arrow_array::{RecordBatchIterator, RecordBatchReader}; use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait; use async_trait::async_trait;
use datafusion_physical_plan::ExecutionPlan;
use lance::dataset::builder::DatasetBuilder; use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats; use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions}; use lance::dataset::optimize::{compact_files, CompactionMetrics, IndexRemapperOptions};
@@ -35,6 +36,7 @@ use lance::dataset::{
}; };
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource}; use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore; use lance::io::WrappingObjectStore;
use lance_datafusion::exec::execute_plan;
use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::vector::ivf::IvfBuildParams; use lance_index::vector::ivf::IvfBuildParams;
use lance_index::vector::pq::PQBuildParams; use lance_index::vector::pq::PQBuildParams;
@@ -231,7 +233,8 @@ pub struct WriteOptions {
// pub on_bad_vectors: BadVectorHandling, // pub on_bad_vectors: BadVectorHandling,
/// Advanced parameters that can be used to customize table creation /// Advanced parameters that can be used to customize table creation
/// ///
/// If set, these will take precedence over any overlapping `OpenTableBuilder` options /// Overlapping `OpenTableBuilder` options (e.g. [AddDataBuilder::mode]) will take
/// precedence over their counterparts in `WriteOptions` (e.g. [WriteParams::mode]).
pub lance_write_params: Option<WriteParams>, pub lance_write_params: Option<WriteParams>,
} }
@@ -366,16 +369,16 @@ pub(crate) trait TableInternal: std::fmt::Display + std::fmt::Debug + Send + Syn
async fn schema(&self) -> Result<SchemaRef>; async fn schema(&self) -> Result<SchemaRef>;
/// Count the number of rows in this table. /// Count the number of rows in this table.
async fn count_rows(&self, filter: Option<String>) -> Result<usize>; async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
async fn create_plan(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>>;
async fn plain_query( async fn plain_query(
&self, &self,
query: &Query, query: &Query,
options: QueryExecutionOptions, options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>; ) -> Result<DatasetRecordBatchStream>;
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream>;
async fn add( async fn add(
&self, &self,
add: AddDataBuilder<NoData>, add: AddDataBuilder<NoData>,
@@ -1203,28 +1206,36 @@ impl NativeTable {
.await) .await)
} }
#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> { pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
#[allow(deprecated)]
match self.load_index_stats(index_uuid).await? { match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(stats.num_indexed_rows)), Some(stats) => Ok(Some(stats.num_indexed_rows)),
None => Ok(None), None => Ok(None),
} }
} }
#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn count_unindexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> { pub async fn count_unindexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
#[allow(deprecated)]
match self.load_index_stats(index_uuid).await? { match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(stats.num_unindexed_rows)), Some(stats) => Ok(Some(stats.num_unindexed_rows)),
None => Ok(None), None => Ok(None),
} }
} }
#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn get_index_type(&self, index_uuid: &str) -> Result<Option<String>> { pub async fn get_index_type(&self, index_uuid: &str) -> Result<Option<String>> {
#[allow(deprecated)]
match self.load_index_stats(index_uuid).await? { match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some(stats.index_type.unwrap_or_default())), Some(stats) => Ok(Some(stats.index_type.unwrap_or_default())),
None => Ok(None), None => Ok(None),
} }
} }
#[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
pub async fn get_distance_type(&self, index_uuid: &str) -> Result<Option<String>> { pub async fn get_distance_type(&self, index_uuid: &str) -> Result<Option<String>> {
#[allow(deprecated)]
match self.load_index_stats(index_uuid).await? { match self.load_index_stats(index_uuid).await? {
Some(stats) => Ok(Some( Some(stats) => Ok(Some(
stats stats
@@ -1237,16 +1248,8 @@ impl NativeTable {
} }
} }
pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> { #[deprecated(since = "0.5.2", note = "Please use `index_stats` instead")]
let dataset = self.dataset.get().await?; pub async fn load_index_stats(&self, index_uuid: &str) -> Result<Option<IndexStatistics>> {
let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;
Ok(indices
.iter()
.map(|i| VectorIndex::new_from_format(&mf, i))
.collect())
}
async fn load_index_stats(&self, index_uuid: &str) -> Result<Option<IndexStatistics>> {
let index = self let index = self
.load_indices() .load_indices()
.await? .await?
@@ -1265,6 +1268,35 @@ impl NativeTable {
Ok(Some(index_stats)) Ok(Some(index_stats))
} }
/// Get statistics about an index.
/// Returns an error if the index does not exist.
pub async fn index_stats<S: AsRef<str>>(
&self,
index_name: S,
) -> Result<Option<IndexStatistics>> {
self.dataset
.get()
.await?
.index_statistics(index_name.as_ref())
.await
.ok()
.map(|stats| {
serde_json::from_str(&stats).map_err(|e| Error::InvalidInput {
message: format!("error deserializing index statistics: {}", e),
})
})
.transpose()
}
pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
let dataset = self.dataset.get().await?;
let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?;
Ok(indices
.iter()
.map(|i| VectorIndex::new_from_format(&mf, i))
.collect())
}
async fn create_ivf_pq_index( async fn create_ivf_pq_index(
&self, &self,
index: IvfPqIndexBuilder, index: IvfPqIndexBuilder,
@@ -1479,79 +1511,11 @@ impl NativeTable {
query: &VectorQuery, query: &VectorQuery,
options: QueryExecutionOptions, options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> { ) -> Result<DatasetRecordBatchStream> {
let ds_ref = self.dataset.get().await?; let plan = self.create_plan(query, options).await?;
let mut scanner: Scanner = ds_ref.scan(); Ok(DatasetRecordBatchStream::new(execute_plan(
plan,
if let Some(query_vector) = query.query_vector.as_ref() { Default::default(),
// If there is a vector query, default to limit=10 if unspecified )?))
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
if !f.data_type().is_floating() {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
message: format!(
"The dimension of the query vector does not match with the dimension of the vector column '{}':
query dim={}, expected vector dim={}",
column,
query_vector.len(),
dim,
),
});
}
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.batch_size(options.max_batch_length as usize);
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
}
Select::Dynamic(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type.into());
}
Ok(scanner.try_into_stream().await?)
} }
} }
@@ -1703,6 +1667,86 @@ impl TableInternal for NativeTable {
Ok(()) Ok(())
} }
async fn create_plan(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
let ds_ref = self.dataset.get().await?;
let mut scanner: Scanner = ds_ref.scan();
if let Some(query_vector) = query.query_vector.as_ref() {
// If there is a vector query, default to limit=10 if unspecified
let column = if let Some(col) = query.column.as_ref() {
col.clone()
} else {
// Infer a vector column with the same dimension of the query vector.
let arrow_schema = Schema::from(ds_ref.schema());
default_vector_column(&arrow_schema, Some(query_vector.len() as i32))?
};
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
if !f.data_type().is_floating() {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
message: format!(
"The dimension of the query vector does not match with the dimension of the vector column '{}': \
query dim={}, expected vector dim={}",
column,
query_vector.len(),
dim,
),
});
}
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(
&column,
query_vector,
query.base.limit.unwrap_or(DEFAULT_TOP_K),
)?;
} else {
// If there is no vector query, it's ok to not have a limit
scanner.limit(query.base.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(query.nprobes);
scanner.use_index(query.use_index);
scanner.prefilter(query.prefilter);
scanner.batch_size(options.max_batch_length as usize);
match &query.base.select {
Select::Columns(select) => {
scanner.project(select.as_slice())?;
}
Select::Dynamic(select_with_transform) => {
scanner.project_with_transform(select_with_transform.as_slice())?;
}
Select::All => { /* Do nothing */ }
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
}
if let Some(refine_factor) = query.refine_factor {
scanner.refine(refine_factor);
}
if let Some(distance_type) = query.distance_type {
scanner.distance_metric(distance_type.into());
}
Ok(scanner.create_plan().await?)
}
async fn plain_query( async fn plain_query(
&self, &self,
query: &Query, query: &Query,
@@ -1712,14 +1756,6 @@ impl TableInternal for NativeTable {
.await .await
} }
async fn vector_query(
&self,
query: &VectorQuery,
options: QueryExecutionOptions,
) -> Result<DatasetRecordBatchStream> {
self.generic_query(query, options).await
}
async fn merge_insert( async fn merge_insert(
&self, &self,
params: MergeInsertBuilder, params: MergeInsertBuilder,
@@ -1853,12 +1889,20 @@ impl TableInternal for NativeTable {
} }
columns.push(field.name.clone()); columns.push(field.name.clone());
} }
Ok(IndexConfig { index_type: if is_vector { crate::index::IndexType::IvfPq } else { crate::index::IndexType::BTree }, columns }) let index_type = if is_vector {
crate::index::IndexType::IvfPq
} else {
crate::index::IndexType::BTree
};
let name = idx.name.clone();
Ok(IndexConfig { index_type, columns, name })
}).collect::<Result<Vec<_>>>() }).collect::<Result<Vec<_>>>()
} }
} }
#[cfg(test)] #[cfg(test)]
#[allow(deprecated)]
mod tests { mod tests {
use std::iter; use std::iter;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
@@ -2550,8 +2594,7 @@ mod tests {
.unwrap() .unwrap()
.get_index_type(index_uuid) .get_index_type(index_uuid)
.await .await
.unwrap() .unwrap(),
.map(|index_type| index_type.to_string()),
Some("IVF".to_string()) Some("IVF".to_string())
); );
assert_eq!( assert_eq!(

View File

@@ -66,6 +66,19 @@ impl DatasetRef {
Ok(()) Ok(())
} }
fn is_latest(&self) -> bool {
matches!(self, Self::Latest { .. })
}
async fn need_reload(&self) -> Result<bool> {
Ok(match self {
Self::Latest { dataset, .. } => {
dataset.latest_version_id().await? != dataset.version().version
}
Self::TimeTravel { dataset, version } => dataset.version().version != *version,
})
}
async fn as_latest(&mut self, read_consistency_interval: Option<Duration>) -> Result<()> { async fn as_latest(&mut self, read_consistency_interval: Option<Duration>) -> Result<()> {
match self { match self {
Self::Latest { .. } => Ok(()), Self::Latest { .. } => Ok(()),
@@ -129,7 +142,7 @@ impl DatasetConsistencyWrapper {
Self(Arc::new(RwLock::new(DatasetRef::Latest { Self(Arc::new(RwLock::new(DatasetRef::Latest {
dataset, dataset,
read_consistency_interval, read_consistency_interval,
last_consistency_check: None, last_consistency_check: Some(Instant::now()),
}))) })))
} }
@@ -163,11 +176,16 @@ impl DatasetConsistencyWrapper {
/// Convert into a wrapper in latest version mode /// Convert into a wrapper in latest version mode
pub async fn as_latest(&self, read_consistency_interval: Option<Duration>) -> Result<()> { pub async fn as_latest(&self, read_consistency_interval: Option<Duration>) -> Result<()> {
self.0 if self.0.read().await.is_latest() {
.write() return Ok(());
.await }
.as_latest(read_consistency_interval)
.await let mut write_guard = self.0.write().await;
if write_guard.is_latest() {
return Ok(());
}
write_guard.as_latest(read_consistency_interval).await
} }
pub async fn as_time_travel(&self, target_version: u64) -> Result<()> { pub async fn as_time_travel(&self, target_version: u64) -> Result<()> {
@@ -183,7 +201,18 @@ impl DatasetConsistencyWrapper {
} }
pub async fn reload(&self) -> Result<()> { pub async fn reload(&self) -> Result<()> {
self.0.write().await.reload().await if !self.0.read().await.need_reload().await? {
return Ok(());
}
let mut write_guard = self.0.write().await;
// on lock escalation -- check if someone else has already reloaded
if !write_guard.need_reload().await? {
return Ok(());
}
// actually need reloading
write_guard.reload().await
} }
/// Returns the version, if in time travel mode, or None otherwise /// Returns the version, if in time travel mode, or None otherwise

View File

@@ -23,6 +23,7 @@ use super::TableInternal;
/// A builder used to create and run a merge insert operation /// A builder used to create and run a merge insert operation
/// ///
/// See [`super::Table::merge_insert`] for more context /// See [`super::Table::merge_insert`] for more context
#[derive(Debug, Clone)]
pub struct MergeInsertBuilder { pub struct MergeInsertBuilder {
table: Arc<dyn TableInternal>, table: Arc<dyn TableInternal>,
pub(super) on: Vec<String>, pub(super) on: Vec<String>,

View File

@@ -101,7 +101,7 @@ pub fn validate_table_name(name: &str) -> Result<()> {
Ok(()) Ok(())
} }
/// Find one default column to create index. /// Find one default column to create index or perform vector query.
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> { pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
// Try to find one fixed size list array column. // Try to find one fixed size list array column.
let candidates = schema let candidates = schema
@@ -118,14 +118,17 @@ pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if candidates.is_empty() { if candidates.is_empty() {
Err(Error::Schema { Err(Error::InvalidInput {
message: "No vector column found to create index".to_string(), message: format!(
"No vector column found to match with the query vector dimension: {}",
dim.unwrap_or_default()
),
}) })
} else if candidates.len() != 1 { } else if candidates.len() != 1 {
Err(Error::Schema { Err(Error::Schema {
message: format!( message: format!(
"More than one vector columns found, \ "More than one vector columns found, \
please specify which column to create index: {:?}", please specify which column to create index or query: {:?}",
candidates candidates
), ),
}) })

View File

@@ -302,7 +302,7 @@ impl EmbeddingFunction for MockEmbed {
fn dest_type(&self) -> Result<Cow<DataType>> { fn dest_type(&self) -> Result<Cow<DataType>> {
Ok(Cow::Borrowed(&self.dest_type)) Ok(Cow::Borrowed(&self.dest_type))
} }
fn embed(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> { fn compute_source_embeddings(&self, source: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
// We can't use the FixedSizeListBuilder here because it always adds a null bitmap // We can't use the FixedSizeListBuilder here because it always adds a null bitmap
// and we want to explicitly work with non-nullable arrays. // and we want to explicitly work with non-nullable arrays.
let len = source.len(); let len = source.len();
@@ -317,4 +317,9 @@ impl EmbeddingFunction for MockEmbed {
Ok(Arc::new(arr)) Ok(Arc::new(arr))
} }
#[allow(unused_variables)]
fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
unimplemented!()
}
} }