Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 21:39:57 +00:00)

Compare commits: v0.3.9 ... python-v0.

25 commits
| Author | SHA1 | Date |
|---|---|---|
|  | a0608044a1 |  |
|  | 2e4ea7d2bc |  |
|  | 57e5695a54 |  |
|  | ce58ea7c38 |  |
|  | 57207eff4a |  |
|  | 2d78bff120 |  |
|  | 7c09b9b9a9 |  |
|  | bd0034a157 |  |
|  | 144b3b5d83 |  |
|  | b6f0a31686 |  |
|  | 9ec526f73f |  |
|  | 600bfd7237 |  |
|  | d087e7891d |  |
|  | 098e397cf0 |  |
|  | 63ee8fa6a1 |  |
|  | 693091db29 |  |
|  | dca4533dbe |  |
|  | f6bbe199dc |  |
|  | 366e522c2b |  |
|  | 244b6919cc |  |
|  | aca785ff98 |  |
|  | bbdebf2c38 |  |
|  | 1336cce0dc |  |
|  | 6c83b6a513 |  |
|  | 6bec4bec51 |  |
```diff
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.3.9
+current_version = 0.3.11
 commit = True
 message = Bump version: {current_version} → {new_version}
 tag = True
```
.github/workflows/npm-publish.yml (vendored, 20 changes)
```diff
@@ -38,13 +38,17 @@ jobs:
           node/vectordb-*.tgz

   node-macos:
-    runs-on: macos-13
+    strategy:
+      matrix:
+        config:
+          - arch: x86_64-apple-darwin
+            runner: macos-13
+          - arch: aarch64-apple-darwin
+            # xlarge is implicitly arm64.
+            runner: macos-13-xlarge
+    runs-on: ${{ matrix.config.runner }}
     # Only runs on tags that matches the make-release action
     if: startsWith(github.ref, 'refs/tags/v')
-    strategy:
-      fail-fast: false
-      matrix:
-        target: [x86_64-apple-darwin, aarch64-apple-darwin]
     steps:
       - name: Checkout
         uses: actions/checkout@v3
```
```diff
@@ -54,11 +58,8 @@ jobs:
         run: |
           cd node
           npm ci
-      - name: Install rustup target
-        if: ${{ matrix.target == 'aarch64-apple-darwin' }}
-        run: rustup target add aarch64-apple-darwin
       - name: Build MacOS native node modules
-        run: bash ci/build_macos_artifacts.sh ${{ matrix.target }}
+        run: bash ci/build_macos_artifacts.sh ${{ matrix.config.arch }}
       - name: Upload Darwin Artifacts
         uses: actions/upload-artifact@v3
         with:
```
```diff
@@ -66,6 +67,7 @@ jobs:
           path: |
             node/dist/lancedb-vectordb-darwin*.tgz

+
   node-linux:
     name: node-linux (${{ matrix.config.arch}}-unknown-linux-gnu
     runs-on: ${{ matrix.config.runner }}
```
.github/workflows/python.yml (vendored, 6 changes)
```diff
@@ -91,11 +91,7 @@ jobs:
           pip install "pydantic<2"
           pip install -e .[tests]
           pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
-          pip install pytest pytest-mock black isort
-      - name: Black
-        run: black --check --diff --no-color --quiet .
-      - name: isort
-        run: isort --check --diff --quiet .
+          pip install pytest pytest-mock
       - name: Run tests
         run: pytest -m "not slow" -x -v --durations=30 tests
       - name: doctest
```
```diff
@@ -5,10 +5,10 @@ exclude = ["python"]
 resolver = "2"

 [workspace.dependencies]
-lance = { "version" = "=0.8.17", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.8.17" }
-lance-linalg = { "version" = "=0.8.17" }
-lance-testing = { "version" = "=0.8.17" }
+lance = { "version" = "=0.8.20", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.8.20" }
+lance-linalg = { "version" = "=0.8.20" }
+lance-testing = { "version" = "=0.8.20" }
 # Note that this one does not include pyarrow
 arrow = { version = "47.0.0", optional = false }
 arrow-array = "47.0"
```
```diff
@@ -5,10 +5,11 @@

 **Developer-friendly, serverless vector database for AI applications**

-<a href="https://lancedb.github.io/lancedb/">Documentation</a> •
-<a href="https://blog.lancedb.com/">Blog</a> •
-<a href="https://discord.gg/zMM32dvNtd">Discord</a> •
-<a href="https://twitter.com/lancedb">Twitter</a>
+<a href='https://github.com/lancedb/vectordb-recipes/tree/main' target="_blank"><img alt='LanceDB' src='https://img.shields.io/badge/VectorDB_Recipes-100000?style=for-the-badge&logo=LanceDB&logoColor=white&labelColor=645cfb&color=645cfb'/></a>
+<a href='https://lancedb.github.io/lancedb/' target="_blank"><img alt='lancdb' src='https://img.shields.io/badge/DOCS-100000?style=for-the-badge&logo=lancdb&logoColor=white&labelColor=645cfb&color=645cfb'/></a>
+[](https://blog.lancedb.com/)
+[](https://discord.gg/zMM32dvNtd)
+[](https://twitter.com/lancedb)

 </p>

```
```diff
@@ -1,6 +1,7 @@
 # Builds the macOS artifacts (node binaries).
 # Usage: ./ci/build_macos_artifacts.sh [target]
 # Targets supported: x86_64-apple-darwin aarch64-apple-darwin
+set -e

 prebuild_rust() {
   # Building here for the sake of easier debugging.
```
```diff
@@ -80,7 +80,6 @@ nav:
     - Ingest Embedding Functions: embeddings/embedding_functions.md
     - Available Functions: embeddings/default_embedding_functions.md
     - Create Custom Embedding Functions: embeddings/api.md
-    - Example - Calculate CLIP Embeddings with Roboflow Inference: examples/image_embeddings_roboflow.md
     - Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
     - Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
   - 🔍 Python full-text search: fts.md
@@ -99,6 +98,7 @@ nav:
     - YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
     - Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb
     - Multimodal search using CLIP: notebooks/multimodal_search.ipynb
+    - Example - Calculate CLIP Embeddings with Roboflow Inference: examples/image_embeddings_roboflow.md
     - Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md
     - Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md
   - 🌐 Javascript examples:
@@ -146,7 +146,8 @@ nav:
     - Serverless Chatbot from any website: examples/serverless_website_chatbot.md
     - TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
   - API references:
-    - Python API: python/python.md
+    - OSS Python API: python/python.md
+    - SaaS Python API: python/saas-python.md
     - Javascript API: javascript/modules.md
   - LanceDB Cloud↗: https://noteforms.com/forms/lancedb-mailing-list-cloud-kty1o5?notionforms=1&utm_source=notionforms

```
docs/src/python/saas-python.md (new file, 18 lines)

```diff
@@ -0,0 +1,18 @@
+# LanceDB Python API Reference
+
+## Installation
+
+```shell
+pip install lancedb
+```
+
+## Connection
+
+::: lancedb.connect
+
+::: lancedb.remote.db.RemoteDBConnection
+
+## Table
+
+::: lancedb.remote.table.RemoteTable
+
```
node/package-lock.json (generated, 86 changes)

```diff
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.3.8",
+  "version": "0.3.11",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.3.8",
+      "version": "0.3.11",
       "cpu": [
         "x64",
         "arm64"
@@ -53,11 +53,11 @@
       "uuid": "^9.0.0"
     },
     "optionalDependencies": {
-      "@lancedb/vectordb-darwin-arm64": "0.3.8",
-      "@lancedb/vectordb-darwin-x64": "0.3.8",
-      "@lancedb/vectordb-linux-arm64-gnu": "0.3.8",
-      "@lancedb/vectordb-linux-x64-gnu": "0.3.8",
-      "@lancedb/vectordb-win32-x64-msvc": "0.3.8"
+      "@lancedb/vectordb-darwin-arm64": "0.3.11",
+      "@lancedb/vectordb-darwin-x64": "0.3.11",
+      "@lancedb/vectordb-linux-arm64-gnu": "0.3.11",
+      "@lancedb/vectordb-linux-x64-gnu": "0.3.11",
+      "@lancedb/vectordb-win32-x64-msvc": "0.3.11"
     }
   },
   "node_modules/@apache-arrow/ts": {
@@ -316,54 +316,6 @@
         "@jridgewell/sourcemap-codec": "^1.4.10"
       }
     },
-    "node_modules/@lancedb/vectordb-darwin-x64": {
-      "version": "0.3.8",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.8.tgz",
-      "integrity": "sha512-PCJwJ2oV0yTq0XJryMMjLad14i9s6xRolFZ1M4EZtgN16X/n/m0xTZjU8Y95Fj28tPFMgd4Pmgtc/TWuEBxW8A==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "darwin"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.3.8",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.8.tgz",
-      "integrity": "sha512-P+ZvI9+g8MDjPvz5+HPNFCCPNAvCDpCfIvKqEiTGEm2Sk5I0meIxRX4VGEnLGcQZmF1LUnVzhKV9+Rkiqd4JIQ==",
-      "cpu": [
-        "arm64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.3.8",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.8.tgz",
-      "integrity": "sha512-E7opS6JuNpyvej0JZ+DJxplnnFp543dlPW0hNxoxsndflo9NeeAa1AIsNQSCIABWlfsQbGxXPYrvsOKHbzAIdw==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "linux"
-      ]
-    },
-    "node_modules/@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.3.8",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.8.tgz",
-      "integrity": "sha512-RAL8U46UE12ksO3VAnnLlfxDd4wxZpJNFYtXjkacKL4ud9PkAJC4FBJpD7EFP9c7LEY3IlJPtvAp5Ax9LGWFeA==",
-      "cpu": [
-        "x64"
-      ],
-      "optional": true,
-      "os": [
-        "win32"
-      ]
-    },
     "node_modules/@neon-rs/cli": {
       "version": "0.0.160",
       "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
@@ -4856,30 +4808,6 @@
         "@jridgewell/sourcemap-codec": "^1.4.10"
       }
     },
-    "@lancedb/vectordb-darwin-x64": {
-      "version": "0.3.8",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.8.tgz",
-      "integrity": "sha512-PCJwJ2oV0yTq0XJryMMjLad14i9s6xRolFZ1M4EZtgN16X/n/m0xTZjU8Y95Fj28tPFMgd4Pmgtc/TWuEBxW8A==",
-      "optional": true
-    },
-    "@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.3.8",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.8.tgz",
-      "integrity": "sha512-P+ZvI9+g8MDjPvz5+HPNFCCPNAvCDpCfIvKqEiTGEm2Sk5I0meIxRX4VGEnLGcQZmF1LUnVzhKV9+Rkiqd4JIQ==",
-      "optional": true
-    },
-    "@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.3.8",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.8.tgz",
-      "integrity": "sha512-E7opS6JuNpyvej0JZ+DJxplnnFp543dlPW0hNxoxsndflo9NeeAa1AIsNQSCIABWlfsQbGxXPYrvsOKHbzAIdw==",
-      "optional": true
-    },
-    "@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.3.8",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.8.tgz",
-      "integrity": "sha512-RAL8U46UE12ksO3VAnnLlfxDd4wxZpJNFYtXjkacKL4ud9PkAJC4FBJpD7EFP9c7LEY3IlJPtvAp5Ax9LGWFeA==",
-      "optional": true
-    },
     "@neon-rs/cli": {
       "version": "0.0.160",
       "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
```
```diff
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.3.9",
+  "version": "0.3.11",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -81,10 +81,10 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-arm64": "0.3.9",
-    "@lancedb/vectordb-darwin-x64": "0.3.9",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.3.9",
-    "@lancedb/vectordb-linux-x64-gnu": "0.3.9",
-    "@lancedb/vectordb-win32-x64-msvc": "0.3.9"
+    "@lancedb/vectordb-darwin-arm64": "0.3.11",
+    "@lancedb/vectordb-darwin-x64": "0.3.11",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.3.11",
+    "@lancedb/vectordb-linux-x64-gnu": "0.3.11",
+    "@lancedb/vectordb-win32-x64-msvc": "0.3.11"
   }
 }
```
```diff
@@ -21,9 +21,10 @@ import type { EmbeddingFunction } from './embedding/embedding_function'
 import { RemoteConnection } from './remote'
 import { Query } from './query'
 import { isEmbeddingFunction } from './embedding/embedding_function'
+import { type Literal, toSQL } from './util'

 // eslint-disable-next-line @typescript-eslint/no-var-requires
-const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
+const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableUpdate, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')

 export { Query }
 export type { EmbeddingFunction }
@@ -261,6 +262,39 @@ export interface Table<T = number[]> {
    */
  delete: (filter: string) => Promise<void>

+  /**
+   * Update rows in this table.
+   *
+   * This can be used to update a single row, many rows, all rows, or
+   * sometimes no rows (if your predicate matches nothing).
+   *
+   * @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
+   *
+   * @examples
+   *
+   * ```ts
+   * const con = await lancedb.connect("./.lancedb")
+   * const data = [
+   *   {id: 1, vector: [3, 3], name: 'Ye'},
+   *   {id: 2, vector: [4, 4], name: 'Mike'},
+   * ];
+   * const tbl = await con.createTable("my_table", data)
+   *
+   * await tbl.update({
+   *   filter: "id = 2",
+   *   updates: { vector: [2, 2], name: "Michael" },
+   * })
+   *
+   * let results = await tbl.search([1, 1]).execute();
+   * // Returns [
+   * //   {id: 2, vector: [2, 2], name: 'Michael'}
+   * //   {id: 1, vector: [3, 3], name: 'Ye'}
+   * // ]
+   * ```
+   *
+   */
+  update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
+
  /**
   * List the indicies on this table.
   */
@@ -272,6 +306,34 @@ export interface Table<T = number[]> {
   indexStats: (indexUuid: string) => Promise<IndexStats>
 }

+export interface UpdateArgs {
+  /**
+   * A filter in the same format used by a sql WHERE clause. The filter may be empty,
+   * in which case all rows will be updated.
+   */
+  where?: string
+
+  /**
+   * A key-value map of updates. The keys are the column names, and the values are the
+   * new values to set
+   */
+  values: Record<string, Literal>
+}
+
+export interface UpdateSqlArgs {
+  /**
+   * A filter in the same format used by a sql WHERE clause. The filter may be empty,
+   * in which case all rows will be updated.
+   */
+  where?: string
+
+  /**
+   * A key-value map of updates. The keys are the column names, and the values are the
+   * new values to set as SQL expressions.
+   */
+  valuesSql: Record<string, string>
+}
+
 export interface VectorIndex {
   columns: string[]
   name: string
@@ -426,6 +488,16 @@ export class LocalTable<T = number[]> implements Table<T> {
     return new Query(query, this._tbl, this._embeddings)
   }

+  /**
+   * Creates a filter query to find all rows matching the specified criteria
+   * @param value The filter criteria (like SQL where clause syntax)
+   */
+  filter (value: string): Query<T> {
+    return new Query(undefined, this._tbl, this._embeddings).filter(value)
+  }
+
+  where = this.filter
+
   /**
    * Insert records into this Table.
    *
@@ -481,6 +553,31 @@ export class LocalTable<T = number[]> implements Table<T> {
     return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable })
   }

+  /**
+   * Update rows in this table.
+   *
+   * @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
+   *
+   * @returns
+   */
+  async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
+    let filter: string | null
+    let updates: Record<string, string>
+
+    if ('valuesSql' in args) {
+      filter = args.where ?? null
+      updates = args.valuesSql
+    } else {
+      filter = args.where ?? null
+      updates = {}
+      for (const [key, value] of Object.entries(args.values)) {
+        updates[key] = toSQL(value)
+      }
+    }
+
+    return tableUpdate.call(this._tbl, filter, updates).then((newTable: any) => { this._tbl = newTable })
+  }
+
   /**
    * Clean up old versions of the table, freeing disk space.
    *
```
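The hunks above add row updates to the Node SDK: `UpdateArgs` takes literal `values` that are serialized through `toSQL`, while `UpdateSqlArgs` takes raw SQL strings via `valuesSql`, and in both forms an omitted `where` updates every row. A minimal usage sketch, assuming a local `./.lancedb` directory (the table and column names are illustrative):

```ts
import * as lancedb from 'vectordb'

async function main (): Promise<void> {
  const con = await lancedb.connect('./.lancedb')
  const tbl = await con.createTable('items', [
    { id: 1, price: 10, vector: [3, 3] },
    { id: 2, price: 11, vector: [4, 4] }
  ])

  // Literal form: each value is converted with toSQL() before the native call.
  await tbl.update({ where: 'id = 2', values: { price: 100 } })

  // SQL-expression form: strings are passed through as-is and may reference columns.
  await tbl.update({ where: 'id = 1', valuesSql: { price: 'price + 5' } })

  // No `where` clause: every row is updated.
  await tbl.update({ values: { price: 0 } })
}

main().catch(console.error)
```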
```diff
@@ -23,10 +23,10 @@ const { tableSearch } = require('../native.js')
  * A builder for nearest neighbor queries for LanceDB.
  */
 export class Query<T = number[]> {
-  private readonly _query: T
+  private readonly _query?: T
   private readonly _tbl?: any
   private _queryVector?: number[]
-  private _limit: number
+  private _limit?: number
   private _refineFactor?: number
   private _nprobes: number
   private _select?: string[]
@@ -35,10 +35,10 @@ export class Query<T = number[]> {
   private _prefilter: boolean
   protected readonly _embeddings?: EmbeddingFunction<T>

-  constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
+  constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
     this._tbl = tbl
     this._query = query
-    this._limit = 10
+    this._limit = undefined
     this._nprobes = 20
     this._refineFactor = undefined
     this._select = undefined
@@ -113,10 +113,12 @@ export class Query<T = number[]> {
    * Execute the query and return the results as an Array of Objects
    */
   async execute<T = Record<string, unknown>> (): Promise<T[]> {
-    if (this._embeddings !== undefined) {
-      this._queryVector = (await this._embeddings.embed([this._query]))[0]
-    } else {
-      this._queryVector = this._query as number[]
+    if (this._query !== undefined) {
+      if (this._embeddings !== undefined) {
+        this._queryVector = (await this._embeddings.embed([this._query]))[0]
+      } else {
+        this._queryVector = this._query as number[]
+      }
     }

     const isElectron = this.isElectron()
```
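Making `_query` and `_limit` optional is what lets a `Query` run without a vector at all: `LocalTable.filter` constructs a `Query(undefined, ...)`, and `execute` only computes a query vector when one was supplied. A short sketch of the resulting filter-only scan, assuming a table named `vectors` with a numeric `id` column (the cast mirrors the test suite, since `filter`/`where` live on `LocalTable`):

```ts
import * as lancedb from 'vectordb'

async function demo (): Promise<void> {
  const con = await lancedb.connect('./.lancedb')
  // filter()/where() are defined on LocalTable, the concrete type behind openTable().
  const table = (await con.openTable('vectors')) as lancedb.LocalTable

  // No query vector: rows are selected purely by the SQL-style predicate.
  const evens = await table.filter('id % 2 = 0').execute()

  // `where` is an alias for `filter`.
  const sameRows = await table.where('id % 2 = 0').execute()
  console.log(evens.length, sameRows.length)
}

demo().catch(console.error)
```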
```diff
@@ -16,7 +16,8 @@ import {
   type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
   type ConnectionOptions, type CreateTableOptions, type VectorIndex,
   type WriteOptions,
-  type IndexStats
+  type IndexStats,
+  type UpdateArgs, type UpdateSqlArgs
 } from '../index'
 import { Query } from '../query'

@@ -24,6 +25,7 @@ import { Vector, Table as ArrowTable } from 'apache-arrow'
 import { HttpLancedbClient } from './client'
 import { isEmbeddingFunction } from '../embedding/embedding_function'
 import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
+import { toSQL } from '../util'

 /**
  * Remote connection.
@@ -246,6 +248,26 @@ export class RemoteTable<T = number[]> implements Table<T> {
     await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
   }

+  async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
+    let filter: string | null
+    let updates: Record<string, string>
+
+    if ('valuesSql' in args) {
+      filter = args.where ?? null
+      updates = args.valuesSql
+    } else {
+      filter = args.where ?? null
+      updates = {}
+      for (const [key, value] of Object.entries(args.values)) {
+        updates[key] = toSQL(value)
+      }
+    }
+
+    await this._client.post(`/v1/table/${this._name}/update/`, {
+      predicate: filter,
+      updates: Object.entries(updates).map(([key, value]) => [key, value])
+    })
+  }
+
   async listIndices (): Promise<VectorIndex[]> {
     const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
     return results.data.indexes?.map((index: any) => ({
```
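`RemoteTable.update` accepts the same `UpdateArgs | UpdateSqlArgs` union but ships the update over HTTP instead of calling into the native module. Both the Node and Python clients serialize to the same wire shape: a `predicate` string (or null) plus `updates` as an array of `[column, sqlExpression]` pairs. A hypothetical illustration of the body POSTed to `/v1/table/<name>/update/` for `update({ where: 'id = 2', values: { price: 100, name: 'Michael' } })`:

```ts
// Hypothetical request body, shown after toSQL() has converted each literal.
const body = {
  predicate: 'id = 2',
  updates: [
    ['price', '100'],      // toSQL(100)
    ['name', "'Michael'"]  // toSQL('Michael')
  ]
}
console.log(JSON.stringify(body))
```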
```diff
@@ -78,12 +78,31 @@ describe('LanceDB client', function () {
   })

   it('limits # of results', async function () {
-    const uri = await createTestDB()
+    const uri = await createTestDB(2, 100)
     const con = await lancedb.connect(uri)
     const table = await con.openTable('vectors')
-    const results = await table.search([0.1, 0.3]).limit(1).execute()
+    let results = await table.search([0.1, 0.3]).limit(1).execute()
     assert.equal(results.length, 1)
     assert.equal(results[0].id, 1)
+
+    // there is a default limit if unspecified
+    results = await table.search([0.1, 0.3]).execute()
+    assert.equal(results.length, 10)
+  })
+
+  it('uses a filter / where clause without vector search', async function () {
+    // eslint-disable-next-line @typescript-eslint/explicit-function-return-type
+    const assertResults = (results: Array<Record<string, unknown>>) => {
+      assert.equal(results.length, 50)
+    }
+
+    const uri = await createTestDB(2, 100)
+    const con = await lancedb.connect(uri)
+    const table = (await con.openTable('vectors')) as LocalTable
+    let results = await table.filter('id % 2 = 0').execute()
+    assertResults(results)
+    results = await table.where('id % 2 = 0').execute()
+    assertResults(results)
   })

   it('uses a filter / where clause', async function () {
@@ -260,6 +279,46 @@ describe('LanceDB client', function () {
     assert.equal(await table.countRows(), 2)
   })

+  it('can update records in the table', async function () {
+    const uri = await createTestDB()
+    const con = await lancedb.connect(uri)
+
+    const table = await con.openTable('vectors')
+    assert.equal(await table.countRows(), 2)
+
+    await table.update({ where: 'price = 10', valuesSql: { price: '100' } })
+    const results = await table.search([0.1, 0.2]).execute()
+    assert.equal(results[0].price, 100)
+    assert.equal(results[1].price, 11)
+  })
+
+  it('can update the records using a literal value', async function () {
+    const uri = await createTestDB()
+    const con = await lancedb.connect(uri)
+
+    const table = await con.openTable('vectors')
+    assert.equal(await table.countRows(), 2)
+
+    await table.update({ where: 'price = 10', values: { price: 100 } })
+    const results = await table.search([0.1, 0.2]).execute()
+    assert.equal(results[0].price, 100)
+    assert.equal(results[1].price, 11)
+  })
+
+  it('can update every record in the table', async function () {
+    const uri = await createTestDB()
+    const con = await lancedb.connect(uri)
+
+    const table = await con.openTable('vectors')
+    assert.equal(await table.countRows(), 2)
+
+    await table.update({ valuesSql: { price: '100' } })
+    const results = await table.search([0.1, 0.2]).execute()
+
+    assert.equal(results[0].price, 100)
+    assert.equal(results[1].price, 100)
+  })
+
   it('can delete records from a table', async function () {
     const uri = await createTestDB()
     const con = await lancedb.connect(uri)
@@ -542,7 +601,7 @@ describe('Compact and cleanup', function () {

     // should have no effect, but this validates the arguments are parsed.
     await table.compactFiles({
-      targetRowsPerFragment: 1024 * 10,
+      targetRowsPerFragment: 102410,
       maxRowsPerGroup: 1024,
       materializeDeletions: true,
       materializeDeletionsThreshold: 0.5,
```
node/src/test/util.ts (new file, 45 lines)

```diff
@@ -0,0 +1,45 @@
+// Copyright 2023 LanceDB Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import { toSQL } from '../util'
+import * as chai from 'chai'
+
+const expect = chai.expect
+
+describe('toSQL', function () {
+  it('should turn string to SQL expression', function () {
+    expect(toSQL('foo')).to.equal("'foo'")
+  })
+
+  it('should turn number to SQL expression', function () {
+    expect(toSQL(123)).to.equal('123')
+  })
+
+  it('should turn boolean to SQL expression', function () {
+    expect(toSQL(true)).to.equal('TRUE')
+  })
+
+  it('should turn null to SQL expression', function () {
+    expect(toSQL(null)).to.equal('NULL')
+  })
+
+  it('should turn Date to SQL expression', function () {
+    const date = new Date('05 October 2011 14:48 UTC')
+    expect(toSQL(date)).to.equal("'2011-10-05T14:48:00.000Z'")
+  })
+
+  it('should turn array to SQL expression', function () {
+    expect(toSQL(['foo', 'bar', true, 1])).to.equal("['foo', 'bar', TRUE, 1]")
+  })
+})
```
node/src/util.ts (new file, 44 lines)

```diff
@@ -0,0 +1,44 @@
+// Copyright 2023 LanceDB Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+export type Literal = string | number | boolean | null | Date | Literal[]
+
+export function toSQL (value: Literal): string {
+  if (typeof value === 'string') {
+    return `'${value}'`
+  }
+
+  if (typeof value === 'number') {
+    return value.toString()
+  }
+
+  if (typeof value === 'boolean') {
+    return value ? 'TRUE' : 'FALSE'
+  }
+
+  if (value === null) {
+    return 'NULL'
+  }
+
+  if (value instanceof Date) {
+    return `'${value.toISOString()}'`
+  }
+
+  if (Array.isArray(value)) {
+    return `[${value.map(toSQL).join(', ')}]`
+  }
+
+  // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
+  throw new Error(`Unsupported value type: ${typeof value} value: (${value})`)
+}
```
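`toSQL` is the bridge between the two update forms: `LocalTable.update` runs every entry of `UpdateArgs.values` through it, so the native layer only ever sees SQL expression strings. A small sketch of that conversion step in isolation (the import path assumes a file sitting next to `util.ts`):

```ts
import { toSQL, type Literal } from './util'

// Reproduce what LocalTable.update does for the literal `values` form.
const values: Record<string, Literal> = { price: 100, name: 'Michael', inStock: true }
const updates: Record<string, string> = {}
for (const [key, value] of Object.entries(values)) {
  updates[key] = toSQL(value)
}
// updates is now { price: '100', name: "'Michael'", inStock: 'TRUE' }
console.log(updates)
```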
```diff
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.3.4
+current_version = 0.3.6
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True
```
```diff
@@ -27,7 +27,7 @@ def connect(
     uri: URI,
     *,
     api_key: Optional[str] = None,
-    region: str = "us-west-2",
+    region: str = "us-east-1",
     host_override: Optional[str] = None,
 ) -> DBConnection:
     """Connect to a LanceDB database.
@@ -39,7 +39,7 @@ def connect(
     api_key: str, optional
         If presented, connect to LanceDB cloud.
         Otherwise, connect to a database on file system or cloud storage.
-    region: str, default "us-west-2"
+    region: str, default "us-east-1"
         The region to use for LanceDB Cloud.
     host_override: str, optional
         The override url for LanceDB Cloud.
```
```diff
@@ -348,3 +348,20 @@ def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
     if PYDANTIC_VERSION.major >= 2:
         return (field_info.json_schema_extra or {}).get(key)
     return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
+
+
+if PYDANTIC_VERSION.major < 2:
+
+    def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
+        """
+        Convert a Pydantic model to a dictionary.
+        """
+        return model.dict()
+
+else:
+
+    def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
+        """
+        Convert a Pydantic model to a dictionary.
+        """
+        return model.model_dump()
```
```diff
@@ -18,6 +18,8 @@ import attrs
 import pyarrow as pa
 from pydantic import BaseModel

+from lancedb.common import VECTOR_COLUMN_NAME
+
 __all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]


@@ -43,6 +45,8 @@ class VectorQuery(BaseModel):

     refine_factor: Optional[int] = None

+    vector_column: str = VECTOR_COLUMN_NAME
+

 @attrs.define
 class VectorQueryResult:
```
```diff
@@ -56,16 +56,20 @@ class RemoteDBConnection(DBConnection):
         self._loop = asyncio.get_event_loop()

     def __repr__(self) -> str:
-        return f"RemoveConnect(name={self.db_name})"
+        return f"RemoteConnect(name={self.db_name})"

     @override
-    def table_names(self, page_token: Optional[str] = None, limit=10) -> Iterable[str]:
+    def table_names(
+        self, page_token: Optional[str] = None, limit: int = 10
+    ) -> Iterable[str]:
         """List the names of all tables in the database.

         Parameters
         ----------
         page_token: str
             The last token to start the new page.
+        limit: int, default 10
+            The maximum number of tables to return for each page.

         Returns
         -------
```
```diff
@@ -120,6 +124,97 @@ class RemoteDBConnection(DBConnection):
         fill_value: float = 0.0,
         embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
     ) -> Table:
+        """Create a [Table][lancedb.table.Table] in the database.
+
+        Parameters
+        ----------
+        name: str
+            The name of the table.
+        data: The data to initialize the table, *optional*
+            User must provide at least one of `data` or `schema`.
+            Acceptable types are:
+
+            - dict or list-of-dict
+
+            - pandas.DataFrame
+
+            - pyarrow.Table or pyarrow.RecordBatch
+        schema: The schema of the table, *optional*
+            Acceptable types are:
+
+            - pyarrow.Schema
+
+            - [LanceModel][lancedb.pydantic.LanceModel]
+        on_bad_vectors: str, default "error"
+            What to do if any of the vectors are not the same size or contains NaNs.
+            One of "error", "drop", "fill".
+        fill_value: float
+            The value to use when filling vectors. Only used if on_bad_vectors="fill".
+
+        Returns
+        -------
+        LanceTable
+            A reference to the newly created table.
+
+        !!! note
+
+            The vector index won't be created by default.
+            To create the index, call the `create_index` method on the table.
+
+        Examples
+        --------
+
+        Can create with list of tuples or dictionaries:
+
+        >>> import lancedb
+        >>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
+        >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
+        ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
+        >>> db.create_table("my_table", data) # doctest: +SKIP
+        LanceTable(my_table)
+
+        You can also pass a pandas DataFrame:
+
+        >>> import pandas as pd
+        >>> data = pd.DataFrame({
+        ...    "vector": [[1.1, 1.2], [0.2, 1.8]],
+        ...    "lat": [45.5, 40.1],
+        ...    "long": [-122.7, -74.1]
+        ... })
+        >>> db.create_table("table2", data) # doctest: +SKIP
+        LanceTable(table2)
+
+        >>> custom_schema = pa.schema([
+        ...   pa.field("vector", pa.list_(pa.float32(), 2)),
+        ...   pa.field("lat", pa.float32()),
+        ...   pa.field("long", pa.float32())
+        ... ])
+        >>> db.create_table("table3", data, schema = custom_schema) # doctest: +SKIP
+        LanceTable(table3)
+
+        It is also possible to create an table from `[Iterable[pa.RecordBatch]]`:
+
+        >>> import pyarrow as pa
+        >>> def make_batches():
+        ...     for i in range(5):
+        ...         yield pa.RecordBatch.from_arrays(
+        ...             [
+        ...                 pa.array([[3.1, 4.1], [5.9, 26.5]],
+        ...                     pa.list_(pa.float32(), 2)),
+        ...                 pa.array(["foo", "bar"]),
+        ...                 pa.array([10.0, 20.0]),
+        ...             ],
+        ...             ["vector", "item", "price"],
+        ...         )
+        >>> schema=pa.schema([
+        ...     pa.field("vector", pa.list_(pa.float32(), 2)),
+        ...     pa.field("item", pa.utf8()),
+        ...     pa.field("price", pa.float32()),
+        ... ])
+        >>> db.create_table("table4", make_batches(), schema=schema) # doctest: +SKIP
+        LanceTable(table4)
+
+        """
         if data is None and schema is None:
             raise ValueError("Either data or schema must be provided.")
         if embedding_functions is not None:
```
@@ -13,7 +13,7 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from lance import json_to_schema
|
from lance import json_to_schema
|
||||||
@@ -22,6 +22,7 @@ from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
|||||||
|
|
||||||
from ..query import LanceVectorQueryBuilder
|
from ..query import LanceVectorQueryBuilder
|
||||||
from ..table import Query, Table, _sanitize_data
|
from ..table import Query, Table, _sanitize_data
|
||||||
|
from ..util import value_to_sql
|
||||||
from .arrow import to_ipc_binary
|
from .arrow import to_ipc_binary
|
||||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
from .client import ARROW_STREAM_CONTENT_TYPE
|
||||||
from .db import RemoteDBConnection
|
from .db import RemoteDBConnection
|
||||||
@@ -37,7 +38,10 @@ class RemoteTable(Table):
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def schema(self) -> pa.Schema:
|
def schema(self) -> pa.Schema:
|
||||||
"""Return the schema of the table."""
|
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||||
|
of this Table
|
||||||
|
|
||||||
|
"""
|
||||||
resp = self._conn._loop.run_until_complete(
|
resp = self._conn._loop.run_until_complete(
|
||||||
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||||
)
|
)
|
||||||
@@ -53,24 +57,17 @@ class RemoteTable(Table):
|
|||||||
return resp["version"]
|
return resp["version"]
|
||||||
|
|
||||||
def to_arrow(self) -> pa.Table:
|
def to_arrow(self) -> pa.Table:
|
||||||
"""Return the table as an Arrow table."""
|
"""to_arrow() is not supported on the LanceDB cloud"""
|
||||||
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
|
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
|
||||||
|
|
||||||
def to_pandas(self):
|
def to_pandas(self):
|
||||||
"""Return the table as a Pandas DataFrame.
|
"""to_pandas() is not supported on the LanceDB cloud"""
|
||||||
|
|
||||||
Intercept `to_arrow()` for better error message.
|
|
||||||
"""
|
|
||||||
return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
|
return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
|
||||||
|
|
||||||
def create_index(
|
def create_index(
|
||||||
self,
|
self,
|
||||||
metric="L2",
|
metric="L2",
|
||||||
num_partitions=256,
|
|
||||||
num_sub_vectors=96,
|
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
replace: bool = True,
|
|
||||||
accelerator: Optional[str] = None,
|
|
||||||
index_cache_size: Optional[int] = None,
|
index_cache_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""Create an index on the table.
|
"""Create an index on the table.
|
||||||
@@ -81,39 +78,28 @@ class RemoteTable(Table):
|
|||||||
----------
|
----------
|
||||||
metric : str
|
metric : str
|
||||||
The metric to use for the index. Default is "L2".
|
The metric to use for the index. Default is "L2".
|
||||||
num_partitions : int
|
|
||||||
The number of partitions to use for the index. Default is 256.
|
|
||||||
num_sub_vectors : int
|
|
||||||
The number of sub-vectors to use for the index. Default is 96.
|
|
||||||
vector_column_name : str
|
vector_column_name : str
|
||||||
The name of the vector column. Default is "vector".
|
The name of the vector column. Default is "vector".
|
||||||
replace : bool
|
|
||||||
Whether to replace the existing index. Default is True.
|
|
||||||
accelerator : str, optional
|
|
||||||
If set, use the given accelerator to create the index.
|
|
||||||
Default is None. Currently not supported.
|
|
||||||
index_cache_size : int, optional
|
|
||||||
The size of the index cache in number of entries. Default value is 256.
|
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
import lancedb
|
>>> import lancedb
|
||||||
import uuid
|
>>> import uuid
|
||||||
from lancedb.schema import vector
|
>>> from lancedb.schema import vector
|
||||||
conn = lancedb.connect("db://...", api_key="...", region="...")
|
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
|
||||||
table_name = uuid.uuid4().hex
|
>>> table_name = uuid.uuid4().hex
|
||||||
schema = pa.schema(
|
>>> schema = pa.schema(
|
||||||
[
|
... [
|
||||||
pa.field("id", pa.uint32(), False),
|
... pa.field("id", pa.uint32(), False),
|
||||||
pa.field("vector", vector(128), False),
|
... pa.field("vector", vector(128), False),
|
||||||
pa.field("s", pa.string(), False),
|
... pa.field("s", pa.string(), False),
|
||||||
]
|
... ]
|
||||||
)
|
... )
|
||||||
table = conn.create_table(
|
>>> table = db.create_table( # doctest: +SKIP
|
||||||
table_name,
|
... table_name, # doctest: +SKIP
|
||||||
schema=schema,
|
... schema=schema, # doctest: +SKIP
|
||||||
)
|
... )
|
||||||
table.create_index()
|
>>> table.create_index("L2", "vector") # doctest: +SKIP
|
||||||
"""
|
"""
|
||||||
index_type = "vector"
|
index_type = "vector"
|
||||||
|
|
||||||
@@ -135,6 +121,28 @@ class RemoteTable(Table):
|
|||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
) -> int:
|
) -> int:
|
||||||
|
"""Add more data to the [Table](Table). It has the same API signature as the OSS version.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data: DATA
|
||||||
|
The data to insert into the table. Acceptable types are:
|
||||||
|
|
||||||
|
- dict or list-of-dict
|
||||||
|
|
||||||
|
- pandas.DataFrame
|
||||||
|
|
||||||
|
- pyarrow.Table or pyarrow.RecordBatch
|
||||||
|
mode: str
|
||||||
|
The mode to use when writing the data. Valid values are
|
||||||
|
"append" and "overwrite".
|
||||||
|
on_bad_vectors: str, default "error"
|
||||||
|
What to do if any of the vectors are not the same size or contains NaNs.
|
||||||
|
One of "error", "drop", "fill".
|
||||||
|
fill_value: float, default 0.
|
||||||
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
|
|
||||||
|
"""
|
||||||
data = _sanitize_data(
|
data = _sanitize_data(
|
||||||
data,
|
data,
|
||||||
self.schema,
|
self.schema,
|
||||||
@@ -158,6 +166,58 @@ class RemoteTable(Table):
|
|||||||
def search(
|
def search(
|
||||||
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
|
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
|
||||||
) -> LanceVectorQueryBuilder:
|
) -> LanceVectorQueryBuilder:
|
||||||
|
"""Create a search query to find the nearest neighbors
|
||||||
|
of the given query vector. We currently support [vector search][search]
|
||||||
|
|
||||||
|
All query options are defined in [Query][lancedb.query.Query].
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
|
||||||
|
>>> data = [
|
||||||
|
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
|
||||||
|
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
|
||||||
|
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
|
||||||
|
... ]
|
||||||
|
>>> table = db.create_table("my_table", data) # doctest: +SKIP
|
||||||
|
>>> query = [0.4, 1.4, 2.4]
|
||||||
|
>>> (table.search(query, vector_column_name="vector") # doctest: +SKIP
|
||||||
|
... .where("original_width > 1000", prefilter=True) # doctest: +SKIP
|
||||||
|
... .select(["caption", "original_width"]) # doctest: +SKIP
|
||||||
|
... .limit(2) # doctest: +SKIP
|
||||||
|
... .to_pandas()) # doctest: +SKIP
|
||||||
|
caption original_width vector _distance # doctest: +SKIP
|
||||||
|
0 foo 2000 [0.5, 3.4, 1.3] 5.220000 # doctest: +SKIP
|
||||||
|
1 test 3000 [0.3, 6.2, 2.6] 23.089996 # doctest: +SKIP
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query: list/np.ndarray/str/PIL.Image.Image, default None
|
||||||
|
The targetted vector to search for.
|
||||||
|
|
||||||
|
- *default None*.
|
||||||
|
Acceptable types are: list, np.ndarray, PIL.Image.Image
|
||||||
|
|
||||||
|
- If None then the select/where/limit clauses are applied to filter
|
||||||
|
the table
|
||||||
|
vector_column_name: str
|
||||||
|
The name of the vector column to search.
|
||||||
|
*default "vector"*
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
LanceQueryBuilder
|
||||||
|
A query builder object representing the query.
|
||||||
|
Once executed, the query returns
|
||||||
|
|
||||||
|
- selected columns
|
||||||
|
|
||||||
|
- the vector
|
||||||
|
|
||||||
|
- and also the "_distance" column which is the distance between the query
|
||||||
|
vector and the returned vector.
|
||||||
|
"""
|
||||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
@@ -165,8 +225,114 @@ class RemoteTable(Table):
         return self._conn._loop.run_until_complete(result).to_arrow()
 
     def delete(self, predicate: str):
-        """Delete rows from the table."""
+        """Delete rows from the table.
+
+        This can be used to delete a single row, many rows, all rows, or
+        sometimes no rows (if your predicate matches nothing).
+
+        Parameters
+        ----------
+        predicate: str
+            The SQL where clause to use when deleting rows.
+
+            - For example, 'x = 2' or 'x IN (1, 2, 3)'.
+
+            The filter must not be empty, or it will error.
+
+        Examples
+        --------
+        >>> import lancedb
+        >>> data = [
+        ...    {"x": 1, "vector": [1, 2]},
+        ...    {"x": 2, "vector": [3, 4]},
+        ...    {"x": 3, "vector": [5, 6]}
+        ... ]
+        >>> db = lancedb.connect("db://...", api_key="...", region="...")  # doctest: +SKIP
+        >>> table = db.create_table("my_table", data)  # doctest: +SKIP
+        >>> table.search([10,10]).to_pandas()  # doctest: +SKIP
+           x      vector  _distance  # doctest: +SKIP
+        0  3  [5.0, 6.0]       41.0  # doctest: +SKIP
+        1  2  [3.0, 4.0]       85.0  # doctest: +SKIP
+        2  1  [1.0, 2.0]      145.0  # doctest: +SKIP
+        >>> table.delete("x = 2")  # doctest: +SKIP
+        >>> table.search([10,10]).to_pandas()  # doctest: +SKIP
+           x      vector  _distance  # doctest: +SKIP
+        0  3  [5.0, 6.0]       41.0  # doctest: +SKIP
+        1  1  [1.0, 2.0]      145.0  # doctest: +SKIP
+
+        If you have a list of values to delete, you can combine them into a
+        stringified list and use the `IN` operator:
+
+        >>> to_remove = [1, 3]  # doctest: +SKIP
+        >>> to_remove = ", ".join([str(v) for v in to_remove])  # doctest: +SKIP
+        >>> table.delete(f"x IN ({to_remove})")  # doctest: +SKIP
+        >>> table.search([10,10]).to_pandas()  # doctest: +SKIP
+           x      vector  _distance  # doctest: +SKIP
+        0  2  [3.0, 4.0]       85.0  # doctest: +SKIP
+        """
         payload = {"predicate": predicate}
         self._conn._loop.run_until_complete(
             self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
         )
+
+    def update(
+        self,
+        where: Optional[str] = None,
+        values: Optional[dict] = None,
+        *,
+        values_sql: Optional[Dict[str, str]] = None,
+    ):
+        """
+        This can be used to update zero to all rows depending on how many
+        rows match the where clause.
+
+        Parameters
+        ----------
+        where: str, optional
+            The SQL where clause to use when updating rows. For example, 'x = 2'
+            or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
+        values: dict, optional
+            The values to update. The keys are the column names and the values
+            are the values to set.
+        values_sql: dict, optional
+            The values to update, expressed as SQL expression strings. These can
+            reference existing columns. For example, {"x": "x + 1"} will increment
+            the x column by 1.
+
+        Examples
+        --------
+        >>> import lancedb
+        >>> data = [
+        ...    {"x": 1, "vector": [1, 2]},
+        ...    {"x": 2, "vector": [3, 4]},
+        ...    {"x": 3, "vector": [5, 6]}
+        ... ]
+        >>> db = lancedb.connect("db://...", api_key="...", region="...")  # doctest: +SKIP
+        >>> table = db.create_table("my_table", data)  # doctest: +SKIP
+        >>> table.to_pandas()  # doctest: +SKIP
+           x      vector  # doctest: +SKIP
+        0  1  [1.0, 2.0]  # doctest: +SKIP
+        1  2  [3.0, 4.0]  # doctest: +SKIP
+        2  3  [5.0, 6.0]  # doctest: +SKIP
+        >>> table.update(where="x = 2", values={"vector": [10, 10]})  # doctest: +SKIP
+        >>> table.to_pandas()  # doctest: +SKIP
+           x        vector  # doctest: +SKIP
+        0  1    [1.0, 2.0]  # doctest: +SKIP
+        1  3    [5.0, 6.0]  # doctest: +SKIP
+        2  2  [10.0, 10.0]  # doctest: +SKIP
+        """
+        if values is not None and values_sql is not None:
+            raise ValueError("Only one of values or values_sql can be provided")
+        if values is None and values_sql is None:
+            raise ValueError("Either values or values_sql must be provided")
+
+        if values is not None:
+            updates = [[k, value_to_sql(v)] for k, v in values.items()]
+        else:
+            updates = [[k, v] for k, v in values_sql.items()]
+
+        payload = {"predicate": where, "updates": updates}
+        self._conn._loop.run_until_complete(
+            self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
+        )
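A note on the two flavors of update introduced above: `values` takes Python literals, which the client converts to SQL literal strings via `value_to_sql` before posting the request, while `values_sql` passes SQL expression strings through verbatim so they can reference existing columns. A minimal sketch (the `table` handle and the integer column `x` are hypothetical):

    # `values`: Python literals, converted client-side (e.g. 2 -> "2", "foo" -> "'foo'")
    table.update(where="x = 1", values={"x": 2})

    # `values_sql`: raw SQL expressions, evaluated against each matching row,
    # so the right-hand side may reference existing columns
    table.update(where="x > 0", values_sql={"x": "x + 1"})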
@@ -17,7 +17,7 @@ import inspect
 import os
 from abc import ABC, abstractmethod
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
 
 import lance
 import numpy as np
@@ -28,9 +28,9 @@ from lance.vector import vec_to_table
 
 from .common import DATA, VEC, VECTOR_COLUMN_NAME
 from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
-from .pydantic import LanceModel
+from .pydantic import LanceModel, model_to_dict
 from .query import LanceQueryBuilder, Query
-from .util import fs_from_uri, safe_import_pandas
+from .util import fs_from_uri, safe_import_pandas, value_to_sql
 from .utils.events import register_event
 
 if TYPE_CHECKING:
@@ -53,8 +53,10 @@ def _sanitize_data(
         # convert to list of dict if data is a bunch of LanceModels
         if isinstance(data[0], LanceModel):
             schema = data[0].__class__.to_arrow_schema()
-            data = [dict(d) for d in data]
-            data = pa.Table.from_pylist(data)
+            data = [model_to_dict(d) for d in data]
+            data = pa.Table.from_pylist(data, schema=schema)
+        else:
+            data = pa.Table.from_pylist(data)
     elif isinstance(data, dict):
         data = vec_to_table(data)
     elif pd is not None and isinstance(data, pd.DataFrame):
@@ -785,7 +787,7 @@ class LanceTable(Table):
         and also the "_distance" column which is the distance between the query
         vector and the returned vector.
         """
-        register_event("search")
+        register_event("search_table")
         return LanceQueryBuilder.create(
             self, query, query_type, vector_column_name=vector_column_name
         )
@@ -906,35 +908,42 @@ class LanceTable(Table):
                 f"Table {name} does not exist."
                 f"Please first call db.create_table({name}, data)"
             )
+        register_event("open_table")
+
         return tbl
 
     def delete(self, where: str):
         self._dataset.delete(where)
 
-    def update(self, where: str, values: dict):
+    def update(
+        self,
+        where: Optional[str] = None,
+        values: Optional[dict] = None,
+        *,
+        values_sql: Optional[Dict[str, str]] = None,
+    ):
         """
-        EXPERIMENTAL: Update rows in the table (not threadsafe).
-
         This can be used to update zero to all rows depending on how many
         rows match the where clause.
 
         Parameters
        ----------
-        where: str
+        where: str, optional
            The SQL where clause to use when updating rows. For example, 'x = 2'
            or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
-        values: dict
+        values: dict, optional
            The values to update. The keys are the column names and the values
            are the values to set.
+        values_sql: dict, optional
+           The values to update, expressed as SQL expression strings. These can
+           reference existing columns. For example, {"x": "x + 1"} will increment
+           the x column by 1.
 
         Examples
         --------
         >>> import lancedb
-        >>> data = [
-        ...    {"x": 1, "vector": [1, 2]},
-        ...    {"x": 2, "vector": [3, 4]},
-        ...    {"x": 3, "vector": [5, 6]}
-        ... ]
+        >>> import pandas as pd
+        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
         >>> db = lancedb.connect("./.lancedb")
         >>> table = db.create_table("my_table", data)
         >>> table.to_pandas()
@@ -950,18 +959,15 @@ class LanceTable(Table):
        2  2  [10.0, 10.0]
 
        """
-        orig_data = self._dataset.to_table(filter=where).combine_chunks()
-        if len(orig_data) == 0:
-            return
-        for col, val in values.items():
-            i = orig_data.column_names.index(col)
-            if i < 0:
-                raise ValueError(f"Column {col} does not exist")
-            orig_data = orig_data.set_column(
-                i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
-            )
-        self.delete(where)
-        self.add(orig_data, mode="append")
+        if values is not None and values_sql is not None:
+            raise ValueError("Only one of values or values_sql can be provided")
+        if values is None and values_sql is None:
+            raise ValueError("Either values or values_sql must be provided")
+
+        if values is not None:
+            values_sql = {k: value_to_sql(v) for k, v in values.items()}
+
+        self.to_lance().update(values_sql, where)
         self._reset_dataset()
         register_event("update")
 
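Worth noting: the new `LanceTable.update` delegates to the underlying Lance dataset's native update instead of the old delete-then-append round trip, so one update call now produces one new table version rather than two; that is why the version assertions in `test_update` further down drop from 4 to 3. Roughly, for a call like `table.update(where="id = 0", values={"x": 2})`, the delegation looks like this sketch (`to_lance()` returns the underlying lance dataset):

    values_sql = {"x": value_to_sql(2)}      # -> {"x": "2"}
    table.to_lance().update(values_sql, "id = 0")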
@@ -12,9 +12,12 @@
 # limitations under the License.
 
 import os
+from datetime import date, datetime
+from functools import singledispatch
 from typing import Tuple
 from urllib.parse import urlparse
 
+import numpy as np
 import pyarrow.fs as pa_fs
 
 
|||||||
return pd
|
return pd
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@singledispatch
|
||||||
|
def value_to_sql(value):
|
||||||
|
raise NotImplementedError("SQL conversion is not implemented for this type")
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(str)
|
||||||
|
def _(value: str):
|
||||||
|
return f"'{value}'"
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(int)
|
||||||
|
def _(value: int):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(float)
|
||||||
|
def _(value: float):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(bool)
|
||||||
|
def _(value: bool):
|
||||||
|
return str(value).upper()
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(type(None))
|
||||||
|
def _(value: type(None)):
|
||||||
|
return "NULL"
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(datetime)
|
||||||
|
def _(value: datetime):
|
||||||
|
return f"'{value.isoformat()}'"
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(date)
|
||||||
|
def _(value: date):
|
||||||
|
return f"'{value.isoformat()}'"
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(list)
|
||||||
|
def _(value: list):
|
||||||
|
return "[" + ", ".join(map(value_to_sql, value)) + "]"
|
||||||
|
|
||||||
|
|
||||||
|
@value_to_sql.register(np.ndarray)
|
||||||
|
def _(value: np.ndarray):
|
||||||
|
return value_to_sql(value.tolist())
|
||||||
|
|||||||
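As a quick illustration, the `value_to_sql` handlers added above produce SQL fragments like the following (a sketch; note the `str` handler does simple quoting and does not escape embedded quotes, and since `bool` is a subclass of `int`, singledispatch picks the more specific `bool` handler for `True`/`False`):

    from datetime import date

    value_to_sql("foo")             # "'foo'"
    value_to_sql(2)                 # "2"
    value_to_sql(True)              # "TRUE"
    value_to_sql(None)              # "NULL"
    value_to_sql(date(2021, 1, 2))  # "'2021-01-02'"
    value_to_sql([1.0, 2.0])        # "[1.0, 2.0]"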
@@ -64,8 +64,10 @@ class _Events:
         Initializes the Events object with default values for events, rate_limit, and metadata.
         """
         self.events = []  # events list
-        self.max_events = 25  # max events to store in memory
-        self.rate_limit = 60.0  # rate limit (seconds)
+        self.throttled_event_names = ["search_table"]
+        self.throttled_events = set()
+        self.max_events = 5  # max events to store in memory
+        self.rate_limit = 60.0 * 5  # rate limit (seconds)
         self.time = 0.0
 
         if is_git_dir():
@@ -112,18 +114,21 @@ class _Events:
             return
         if (
             len(self.events) < self.max_events
-        ):  # Events list limited to 25 events (drop any events past this)
+        ):  # Events list limited to self.max_events (drop any events past this)
             params.update(self.metadata)
-            self.events.append(
-                {
-                    "event": event_name,
-                    "properties": params,
-                    "timestamp": datetime.datetime.now(
-                        tz=datetime.timezone.utc
-                    ).isoformat(),
-                    "distinct_id": CONFIG["uuid"],
-                }
-            )
+            event = {
+                "event": event_name,
+                "properties": params,
+                "timestamp": datetime.datetime.now(
+                    tz=datetime.timezone.utc
+                ).isoformat(),
+                "distinct_id": CONFIG["uuid"],
+            }
+            if event_name not in self.throttled_event_names:
+                self.events.append(event)
+            elif event_name not in self.throttled_events:
+                self.throttled_events.add(event_name)
+                self.events.append(event)
 
         # Check rate limit
         t = time.time()
@@ -135,7 +140,6 @@ class _Events:
             "distinct_id": CONFIG["uuid"],  # posthog needs this to accepts the event
             "batch": self.events,
         }
-
         # POST equivalent to requests.post(self.url, json=data).
         # threaded request is used to avoid blocking, retries are disabled, and verbose is disabled
         # to avoid any possible disruption in the console.
@@ -150,6 +154,7 @@ class _Events:
 
         # Flush & Reset
         self.events = []
+        self.throttled_events = set()
         self.time = t
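The effect of the telemetry change above: events named in `throttled_event_names` (currently just "search_table") are recorded at most once per flush window, while other events accumulate up to `max_events`. A minimal sketch of the same dedup pattern, independent of the lancedb internals:

    throttled_names = {"search_table"}
    seen_throttled = set()
    events = []

    def record(name, max_events=5):
        if len(events) >= max_events:
            return  # drop events once the buffer is full
        if name not in throttled_names:
            events.append(name)
        elif name not in seen_throttled:  # first occurrence this window only
            seen_throttled.add(name)
            events.append(name)

    # on flush: events.clear(); seen_throttled.clear()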
@@ -1,12 +1,12 @@
 [project]
 name = "lancedb"
-version = "0.3.4"
+version = "0.3.6"
 dependencies = [
     "deprecation",
-    "pylance==0.8.17",
+    "pylance==0.8.21",
     "ratelimiter~=1.0",
     "retry>=0.9.2",
-    "tqdm>=4.1.0",
+    "tqdm>=4.27.0",
     "aiohttp",
     "pydantic>=1.10",
     "attrs>=21.3.0",
@@ -12,7 +12,7 @@
 # limitations under the License.
 
 import functools
-from datetime import timedelta
+from datetime import date, datetime, timedelta
 from pathlib import Path
 from typing import List
 from unittest.mock import PropertyMock, patch
@@ -21,6 +21,7 @@ import lance
 import numpy as np
 import pandas as pd
 import pyarrow as pa
+from pydantic import BaseModel
 import pytest
 
 from lancedb.conftest import MockTextEmbeddingFunction
@@ -141,14 +142,32 @@ def test_add(db):
 
 
 def test_add_pydantic_model(db):
-    class TestModel(LanceModel):
-        vector: Vector(16)
-        li: List[int]
-
-    data = TestModel(vector=list(range(16)), li=[1, 2, 3])
-    table = LanceTable.create(db, "test", data=[data])
-    assert len(table) == 1
-    assert table.schema == TestModel.to_arrow_schema()
+    # https://github.com/lancedb/lancedb/issues/562
+
+    class Document(BaseModel):
+        content: str
+        source: str
+
+    class LanceSchema(LanceModel):
+        id: str
+        vector: Vector(2)
+        li: List[int]
+        payload: Document
+
+    tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite")
+    assert tbl.schema == LanceSchema.to_arrow_schema()
+
+    # add works
+    expected = LanceSchema(
+        id="id",
+        vector=[0.0, 0.0],
+        li=[1, 2, 3],
+        payload=Document(content="foo", source="bar"),
+    )
+    tbl.add([expected])
+
+    result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
+    assert result == expected
 
 
 def _add(table, schema):
@@ -348,14 +367,79 @@ def test_update(db):
     assert len(table) == 2
     assert len(table.list_versions()) == 2
     table.update(where="id=0", values={"vector": [1.1, 1.1]})
-    assert len(table.list_versions()) == 4
-    assert table.version == 4
+    assert len(table.list_versions()) == 3
+    assert table.version == 3
     assert len(table) == 2
     v = table.to_arrow()["vector"].combine_chunks()
     v = v.values.to_numpy().reshape(2, 2)
     assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
+
+
+def test_update_types(db):
+    table = LanceTable.create(
+        db,
+        "my_table",
+        data=[
+            {
+                "id": 0,
+                "str": "foo",
+                "float": 1.1,
+                "timestamp": datetime(2021, 1, 1),
+                "date": date(2021, 1, 1),
+                "vector1": [1.0, 0.0],
+                "vector2": [1.0, 1.0],
+            }
+        ],
+    )
+    # Update with SQL
+    table.update(
+        values_sql=dict(
+            id="1",
+            str="'bar'",
+            float="2.2",
+            timestamp="TIMESTAMP '2021-01-02 00:00:00'",
+            date="DATE '2021-01-02'",
+            vector1="[2.0, 2.0]",
+            vector2="[3.0, 3.0]",
+        )
+    )
+    actual = table.to_arrow().to_pylist()[0]
+    expected = dict(
+        id=1,
+        str="bar",
+        float=2.2,
+        timestamp=datetime(2021, 1, 2),
+        date=date(2021, 1, 2),
+        vector1=[2.0, 2.0],
+        vector2=[3.0, 3.0],
+    )
+    assert actual == expected
+
+    # Update with values
+    table.update(
+        values=dict(
+            id=2,
+            str="baz",
+            float=3.3,
+            timestamp=datetime(2021, 1, 3),
+            date=date(2021, 1, 3),
+            vector1=[3.0, 3.0],
+            vector2=np.array([4.0, 4.0]),
+        )
+    )
+    actual = table.to_arrow().to_pylist()[0]
+    expected = dict(
+        id=2,
+        str="baz",
+        float=3.3,
+        timestamp=datetime(2021, 1, 3),
+        date=date(2021, 1, 3),
+        vector1=[3.0, 3.0],
+        vector2=[4.0, 4.0],
+    )
+    assert actual == expected
 
 
 def test_create_with_embedding_function(db):
     class MyTable(LanceModel):
         text: str
@@ -1,6 +1,6 @@
 [package]
 name = "vectordb-node"
-version = "0.3.9"
+version = "0.3.11"
 description = "Serverless, low-latency vector database for AI applications"
 license = "Apache-2.0"
 edition = "2018"
@@ -237,6 +237,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
     cx.export_function("tableAdd", JsTable::js_add)?;
     cx.export_function("tableCountRows", JsTable::js_count_rows)?;
     cx.export_function("tableDelete", JsTable::js_delete)?;
+    cx.export_function("tableUpdate", JsTable::js_update)?;
     cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
     cx.export_function("tableCompactFiles", JsTable::js_compact)?;
     cx.export_function("tableListIndices", JsTable::js_list_indices)?;
@@ -23,8 +23,14 @@ impl JsQuery {
         let query_obj = cx.argument::<JsObject>(0)?;
 
         let limit = query_obj
-            .get::<JsNumber, _, _>(&mut cx, "_limit")?
-            .value(&mut cx);
+            .get_opt::<JsNumber, _, _>(&mut cx, "_limit")?
+            .map(|value| {
+                let limit = value.value(&mut cx) as u64;
+                if limit <= 0 {
+                    panic!("Limit must be a positive integer");
+                }
+                limit
+            });
         let select = query_obj
             .get_opt::<JsArray, _, _>(&mut cx, "_select")?
             .map(|arr| {
@@ -48,7 +54,9 @@ impl JsQuery {
             .map(|s| s.value(&mut cx))
             .map(|s| MetricType::try_from(s.as_str()).unwrap());
 
-        let prefilter = query_obj.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?.value(&mut cx);
+        let prefilter = query_obj
+            .get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
+            .value(&mut cx);
 
         let is_electron = cx
             .argument::<JsBoolean>(1)
@@ -59,20 +67,23 @@ impl JsQuery {
 
         let (deferred, promise) = cx.promise();
         let channel = cx.channel();
-        let query_vector = query_obj.get::<JsArray, _, _>(&mut cx, "_queryVector")?;
-        let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);
+        let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?;
         let table = js_table.table.clone();
+        let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx));
 
         rt.spawn(async move {
-            let builder = table
-                .search(Float32Array::from(query))
-                .limit(limit as usize)
+            let mut builder = table
+                .search(query.map(|q| Float32Array::from(q)))
                 .refine_factor(refine_factor)
                 .nprobes(nprobes)
                 .filter(filter)
                 .metric_type(metric_type)
                 .select(select)
                 .prefilter(prefilter);
+            if let Some(limit) = limit {
+                builder = builder.limit(limit as usize);
+            };
 
             let record_batch_stream = builder.execute();
             let results = record_batch_stream
                 .and_then(|stream| {
@@ -165,6 +165,69 @@ impl JsTable {
         Ok(promise)
     }
 
+    pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
+        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let mut table = js_table.table.clone();
+
+        let rt = runtime(&mut cx)?;
+        let (deferred, promise) = cx.promise();
+        let channel = cx.channel();
+
+        // create a vector of updates from the passed map
+        let updates_arg = cx.argument::<JsObject>(1)?;
+        let properties = updates_arg.get_own_property_names(&mut cx)?;
+        let mut updates: Vec<(String, String)> =
+            Vec::with_capacity(properties.len(&mut cx) as usize);
+
+        let len_properties = properties.len(&mut cx);
+        for i in 0..len_properties {
+            let property = properties
+                .get_value(&mut cx, i)?
+                .downcast_or_throw::<JsString, _>(&mut cx)?;
+
+            let value = updates_arg
+                .get_value(&mut cx, property.clone())?
+                .downcast_or_throw::<JsString, _>(&mut cx)?;
+
+            let property = property.value(&mut cx);
+            let value = value.value(&mut cx);
+            updates.push((property, value));
+        }
+
+        // get the filter/predicate if the user passed one
+        let predicate = cx.argument_opt(0);
+        let predicate = predicate.unwrap().downcast::<JsString, _>(&mut cx);
+        let predicate = match predicate {
+            Ok(_) => {
+                let val = predicate.map(|s| s.value(&mut cx)).unwrap();
+                Some(val)
+            }
+            Err(_) => {
+                // if the predicate is not string, check it's null otherwise an invalid
+                // type was passed
+                cx.argument::<JsNull>(0)?;
+                None
+            }
+        };
+
+        rt.spawn(async move {
+            let updates_arg = updates
+                .iter()
+                .map(|(k, v)| (k.as_str(), v.as_str()))
+                .collect::<Vec<_>>();
+
+            let predicate = predicate.as_ref().map(|s| s.as_str());
+
+            let update_result = table.update(predicate, updates_arg).await;
+            deferred.settle_with(&channel, move |mut cx| {
+                update_result.or_throw(&mut cx)?;
+                Ok(cx.boxed(JsTable::from(table)))
+            })
+        });
+
+        Ok(promise)
+    }
+
     pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
         let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
         let rt = runtime(&mut cx)?;
@@ -1,6 +1,6 @@
 [package]
 name = "vectordb"
-version = "0.3.9"
+version = "0.3.11"
 edition = "2021"
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
 license = "Apache-2.0"
@@ -359,7 +359,9 @@ mod test {
         assert_eq!(t.count_rows().await.unwrap(), 100);
 
         let q = t
-            .search(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1]))
+            .search(Some(PrimitiveArray::from_iter_values(vec![
+                0.1, 0.1, 0.1, 0.1,
+            ])))
             .limit(10)
             .execute()
             .await
@@ -24,8 +24,9 @@ use crate::error::Result;
 /// A builder for nearest neighbor queries for LanceDB.
 pub struct Query {
     pub dataset: Arc<Dataset>,
-    pub query_vector: Float32Array,
-    pub limit: usize,
+    pub query_vector: Option<Float32Array>,
+    pub column: String,
+    pub limit: Option<usize>,
     pub filter: Option<String>,
     pub select: Option<Vec<String>>,
     pub nprobes: usize,
@@ -46,11 +47,12 @@ impl Query {
     /// # Returns
     ///
     /// * A [Query] object.
-    pub(crate) fn new(dataset: Arc<Dataset>, vector: Float32Array) -> Self {
+    pub(crate) fn new(dataset: Arc<Dataset>, vector: Option<Float32Array>) -> Self {
         Query {
             dataset,
             query_vector: vector,
-            limit: 10,
+            column: crate::table::VECTOR_COLUMN_NAME.to_string(),
+            limit: None,
             nprobes: 20,
             refine_factor: None,
             metric_type: None,
@@ -69,11 +71,13 @@ impl Query {
     pub async fn execute(&self) -> Result<DatasetRecordBatchStream> {
         let mut scanner: Scanner = self.dataset.scan();
 
-        scanner.nearest(
-            crate::table::VECTOR_COLUMN_NAME,
-            &self.query_vector,
-            self.limit,
-        )?;
+        if let Some(query) = self.query_vector.as_ref() {
+            // If there is a vector query, default to limit=10 if unspecified
+            scanner.nearest(&self.column, query, self.limit.unwrap_or(10))?;
+        } else {
+            // If there is no vector query, it's ok to not have a limit
+            scanner.limit(self.limit.map(|limit| limit as i64), None)?;
+        }
         scanner.nprobs(self.nprobes);
         scanner.use_index(self.use_index);
         scanner.prefilter(self.prefilter);
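With `query_vector` now optional, `execute` only calls `scanner.nearest` when a vector is present; otherwise the query degenerates to a plain filtered scan with an optional limit. The same dataset-level behavior can be sketched from Python with a plain Lance scan (the path here is hypothetical; `to_table(filter=...)` is the same pylance call the old `LanceTable.update` used):

    import lance

    ds = lance.dataset("./my_table.lance")   # hypothetical dataset path
    # no query vector: just a filter, and no implicit top-k limit
    tbl = ds.to_table(filter="id % 2 == 0")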
@@ -85,13 +89,23 @@ impl Query {
         Ok(scanner.try_into_stream().await?)
     }
 
+    /// Set the column to query
+    ///
+    /// # Arguments
+    ///
+    /// * `column` - The column name
+    pub fn column(mut self, column: &str) -> Query {
+        self.column = column.into();
+        self
+    }
+
     /// Set the maximum number of results to return.
     ///
     /// # Arguments
     ///
     /// * `limit` - The maximum number of results to return.
     pub fn limit(mut self, limit: usize) -> Query {
-        self.limit = limit;
+        self.limit = Some(limit);
         self
     }
 
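The new `column` setter is the core-side counterpart of the `vector_column_name` argument that the Python bindings thread through `search` (see the `LanceVectorQueryBuilder(self, query, vector_column_name)` call near the top of this diff). A hedged Python-level sketch, assuming a table that also has a second vector column named "embeddings":

    # search against a non-default vector column instead of "vector"
    results = table.search([0.1, 0.2], vector_column_name="embeddings").to_pandas()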
@@ -101,7 +115,7 @@ impl Query {
     ///
     /// * `vector` - The vector that will be used for search.
     pub fn query_vector(mut self, query_vector: Float32Array) -> Query {
-        self.query_vector = query_vector;
+        self.query_vector = Some(query_vector);
         self
     }
 
@@ -174,7 +188,10 @@ mod tests {
     use std::sync::Arc;
 
     use super::*;
-    use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
+    use arrow_array::{
+        cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
+        RecordBatchReader,
+    };
     use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
     use futures::StreamExt;
     use lance::dataset::Dataset;
@@ -187,7 +204,7 @@ mod tests {
         let batches = make_test_batches();
         let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
 
-        let vector = Float32Array::from_iter_values([0.1, 0.2]);
+        let vector = Some(Float32Array::from_iter_values([0.1, 0.2]));
         let query = Query::new(Arc::new(ds), vector.clone());
         assert_eq!(query.query_vector, vector);
 
@@ -201,8 +218,8 @@ mod tests {
             .metric_type(Some(MetricType::Cosine))
             .refine_factor(Some(999));
 
-        assert_eq!(query.query_vector, new_vector);
-        assert_eq!(query.limit, 100);
+        assert_eq!(query.query_vector.unwrap(), new_vector);
+        assert_eq!(query.limit.unwrap(), 100);
         assert_eq!(query.nprobes, 1000);
         assert_eq!(query.use_index, true);
         assert_eq!(query.metric_type, Some(MetricType::Cosine));
@@ -214,7 +231,7 @@ mod tests {
         let batches = make_non_empty_batches();
         let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
 
-        let vector = Float32Array::from_iter_values([0.1; 4]);
+        let vector = Some(Float32Array::from_iter_values([0.1; 4]));
 
         let query = Query::new(ds.clone(), vector.clone());
         let result = query
@@ -244,6 +261,27 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn test_execute_no_vector() {
+        // test that it's ok to not specify a query vector (just filter / limit)
+        let batches = make_non_empty_batches();
+        let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
+
+        let query = Query::new(ds.clone(), None);
+        let result = query
+            .filter(Some("id % 2 == 0".to_string()))
+            .execute()
+            .await;
+        let mut stream = result.expect("should have result");
+        // should only have one batch
+        while let Some(batch) = stream.next().await {
+            let b = batch.expect("should be Ok");
+            // cast arr into Int32Array
+            let arr: &Int32Array = b["id"].as_primitive();
+            assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));
+        }
+    }
+
     fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {
         let vec = Box::new(RandomVector::new().named("vector".to_string()));
         let id = Box::new(IncrementingInt32::new().named("id".to_string()));
@@ -23,7 +23,7 @@ use lance::dataset::cleanup::RemovalStats;
 use lance::dataset::optimize::{
     compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
 };
-use lance::dataset::{Dataset, WriteParams};
+use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
 use lance::index::DatasetIndexExt;
 use lance::io::object_store::WrappingObjectStore;
 use std::path::Path;
@@ -308,10 +308,14 @@ impl Table {
     /// # Returns
     ///
     /// * A [Query] object.
-    pub fn search(&self, query_vector: Float32Array) -> Query {
+    pub fn search(&self, query_vector: Option<Float32Array>) -> Query {
         Query::new(self.dataset.clone(), query_vector)
     }
 
+    pub fn filter(&self, expr: String) -> Query {
+        Query::new(self.dataset.clone(), None).filter(Some(expr))
+    }
+
     /// Returns the number of rows in this Table
     pub async fn count_rows(&self) -> Result<usize> {
         Ok(self.dataset.count_rows().await?)
@@ -338,6 +342,27 @@ impl Table {
         Ok(())
     }
 
+    pub async fn update(
+        &mut self,
+        predicate: Option<&str>,
+        updates: Vec<(&str, &str)>,
+    ) -> Result<()> {
+        let mut builder = UpdateBuilder::new(self.dataset.clone());
+        if let Some(predicate) = predicate {
+            builder = builder.update_where(predicate)?;
+        }
+
+        for (column, value) in updates {
+            builder = builder.set(column, value)?;
+        }
+
+        let operation = builder.build()?;
+        let new_ds = operation.execute().await?;
+        self.dataset = new_ds;
+
+        Ok(())
+    }
+
     /// Remove old versions of the dataset from disk.
     ///
     /// # Arguments
@@ -413,11 +438,14 @@ mod tests {
     use std::sync::Arc;
 
     use arrow_array::{
-        Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
-        RecordBatchReader,
+        Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array,
+        Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator,
+        RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
+        UInt32Array,
     };
     use arrow_data::ArrayDataBuilder;
-    use arrow_schema::{DataType, Field, Schema};
+    use arrow_schema::{DataType, Field, Schema, TimeUnit};
+    use futures::TryStreamExt;
     use lance::dataset::{Dataset, WriteMode};
     use lance::index::vector::pq::PQBuildParams;
     use lance::io::object_store::{ObjectStoreParams, WrappingObjectStore};
@@ -540,6 +568,272 @@ mod tests {
         assert_eq!(table.name, "test");
     }
 
+    #[tokio::test]
+    async fn test_update_with_predicate() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path().join("test.lance");
+        let uri = dataset_path.to_str().unwrap();
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]));
+
+        let record_batch_iter = RecordBatchIterator::new(
+            vec![RecordBatch::try_new(
+                schema.clone(),
+                vec![
+                    Arc::new(Int32Array::from_iter_values(0..10)),
+                    Arc::new(StringArray::from_iter_values(vec![
+                        "a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+                    ])),
+                ],
+            )
+            .unwrap()]
+            .into_iter()
+            .map(Ok),
+            schema.clone(),
+        );
+
+        Dataset::write(record_batch_iter, uri, None).await.unwrap();
+        let mut table = Table::open(uri).await.unwrap();
+
+        table
+            .update(Some("id > 5"), vec![("name", "'foo'")])
+            .await
+            .unwrap();
+
+        let ds_after = Dataset::open(uri).await.unwrap();
+        let mut batches = ds_after
+            .scan()
+            .project(&["id", "name"])
+            .unwrap()
+            .try_into_stream()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        while let Some(batch) = batches.pop() {
+            let ids = batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+                .iter()
+                .collect::<Vec<_>>();
+            let names = batch
+                .column(1)
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap()
+                .iter()
+                .collect::<Vec<_>>();
+            for (i, name) in names.iter().enumerate() {
+                let id = ids[i].unwrap();
+                let name = name.unwrap();
+                if id > 5 {
+                    assert_eq!(name, "foo");
+                } else {
+                    assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
+                }
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_update_all_types() {
+        let tmp_dir = tempdir().unwrap();
+        let dataset_path = tmp_dir.path().join("test.lance");
+        let uri = dataset_path.to_str().unwrap();
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("int32", DataType::Int32, false),
+            Field::new("int64", DataType::Int64, false),
+            Field::new("uint32", DataType::UInt32, false),
+            Field::new("string", DataType::Utf8, false),
+            Field::new("large_string", DataType::LargeUtf8, false),
+            Field::new("float32", DataType::Float32, false),
+            Field::new("float64", DataType::Float64, false),
+            Field::new("bool", DataType::Boolean, false),
+            Field::new("date32", DataType::Date32, false),
+            Field::new(
+                "timestamp_ns",
+                DataType::Timestamp(TimeUnit::Nanosecond, None),
+                false,
+            ),
+            Field::new(
+                "timestamp_ms",
+                DataType::Timestamp(TimeUnit::Millisecond, None),
+                false,
+            ),
+            Field::new(
+                "vec_f32",
+                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
+                false,
+            ),
+            Field::new(
+                "vec_f64",
+                DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
+                false,
+            ),
+        ]));
+
+        let record_batch_iter = RecordBatchIterator::new(
+            vec![RecordBatch::try_new(
+                schema.clone(),
+                vec![
+                    Arc::new(Int32Array::from_iter_values(0..10)),
+                    Arc::new(Int64Array::from_iter_values(0..10)),
+                    Arc::new(UInt32Array::from_iter_values(0..10)),
+                    Arc::new(StringArray::from_iter_values(vec![
+                        "a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+                    ])),
+                    Arc::new(LargeStringArray::from_iter_values(vec![
+                        "a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
+                    ])),
+                    Arc::new(Float32Array::from_iter_values(
+                        (0..10).into_iter().map(|i| i as f32),
+                    )),
+                    Arc::new(Float64Array::from_iter_values(
+                        (0..10).into_iter().map(|i| i as f64),
+                    )),
+                    Arc::new(Into::<BooleanArray>::into(vec![
+                        true, false, true, false, true, false, true, false, true, false,
+                    ])),
+                    Arc::new(Date32Array::from_iter_values(0..10)),
+                    Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
+                    Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
+                    Arc::new(
+                        create_fixed_size_list(
+                            Float32Array::from_iter_values((0..20).into_iter().map(|i| i as f32)),
+                            2,
+                        )
+                        .unwrap(),
+                    ),
+                    Arc::new(
+                        create_fixed_size_list(
+                            Float64Array::from_iter_values((0..20).into_iter().map(|i| i as f64)),
+                            2,
+                        )
+                        .unwrap(),
+                    ),
+                ],
+            )
+            .unwrap()]
+            .into_iter()
+            .map(Ok),
+            schema.clone(),
+        );
+
+        Dataset::write(record_batch_iter, uri, None).await.unwrap();
+        let mut table = Table::open(uri).await.unwrap();
+
+        // check it can do update for each type
+        let updates: Vec<(&str, &str)> = vec![
+            ("string", "'foo'"),
+            ("large_string", "'large_foo'"),
+            ("int32", "1"),
+            ("int64", "1"),
+            ("uint32", "1"),
+            ("float32", "1.0"),
+            ("float64", "1.0"),
+            ("bool", "true"),
+            ("date32", "1"),
+            ("timestamp_ns", "1"),
+            ("timestamp_ms", "1"),
+            ("vec_f32", "[1.0, 1.0]"),
+            ("vec_f64", "[1.0, 1.0]"),
+        ];
+
+        // for (column, value) in test_cases {
+        table.update(None, updates).await.unwrap();
+
+        let ds_after = Dataset::open(uri).await.unwrap();
+        let mut batches = ds_after
+            .scan()
+            .project(&[
+                "string",
+                "large_string",
+                "int32",
+                "int64",
+                "uint32",
+                "float32",
+                "float64",
+                "bool",
+                "date32",
+                "timestamp_ns",
+                "timestamp_ms",
+                "vec_f32",
+                "vec_f64",
+            ])
+            .unwrap()
+            .try_into_stream()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+        let batch = batches.pop().unwrap();
+
+        macro_rules! assert_column {
+            ($column:expr, $array_type:ty, $expected:expr) => {
+                let array = $column
+                    .as_any()
+                    .downcast_ref::<$array_type>()
+                    .unwrap()
+                    .iter()
+                    .collect::<Vec<_>>();
+                for v in array {
+                    assert_eq!(v, Some($expected));
+                }
+            };
+        }
+
+        assert_column!(batch.column(0), StringArray, "foo");
+        assert_column!(batch.column(1), LargeStringArray, "large_foo");
+        assert_column!(batch.column(2), Int32Array, 1);
+        assert_column!(batch.column(3), Int64Array, 1);
+        assert_column!(batch.column(4), UInt32Array, 1);
+        assert_column!(batch.column(5), Float32Array, 1.0);
+        assert_column!(batch.column(6), Float64Array, 1.0);
+        assert_column!(batch.column(7), BooleanArray, true);
+        assert_column!(batch.column(8), Date32Array, 1);
+        assert_column!(batch.column(9), TimestampNanosecondArray, 1);
+        assert_column!(batch.column(10), TimestampMillisecondArray, 1);
+
+        let array = batch
+            .column(11)
+            .as_any()
+            .downcast_ref::<FixedSizeListArray>()
+            .unwrap()
+            .iter()
+            .collect::<Vec<_>>();
+        for v in array {
+            let v = v.unwrap();
+            let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
+            for v in f32array {
+                assert_eq!(v, Some(1.0));
+            }
+        }
+
+        let array = batch
+            .column(12)
+            .as_any()
+            .downcast_ref::<FixedSizeListArray>()
+            .unwrap()
+            .iter()
+            .collect::<Vec<_>>();
+        for v in array {
+            let v = v.unwrap();
+            let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
+            for v in f64array {
+                assert_eq!(v, Some(1.0));
+            }
+        }
+    }
+
     #[tokio::test]
     async fn test_search() {
         let tmp_dir = tempdir().unwrap();
@@ -554,8 +848,8 @@ mod tests {
         let table = Table::open(uri).await.unwrap();
 
         let vector = Float32Array::from_iter_values([0.1, 0.2]);
-        let query = table.search(vector.clone());
-        assert_eq!(vector, query.query_vector);
+        let query = table.search(Some(vector.clone()));
+        assert_eq!(vector, query.query_vector.unwrap());
     }
 
     #[derive(Default, Debug)]