Compare commits

..

2 Commits

Author SHA1 Message Date
qzhu
8e25e0c7f0 reformatted 2023-12-07 12:08:05 -08:00
qzhu
5f989e86d2 SaaS python SDK doc 2023-12-07 12:01:03 -08:00
51 changed files with 360 additions and 1832 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.4.0 current_version = 0.3.9
commit = True commit = True
message = Bump version: {current_version} → {new_version} message = Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -1,33 +0,0 @@
name: Bug Report - Node / Typescript
description: File a bug report
title: "bug(node): "
labels: [bug, typescript]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
- type: input
id: version
attributes:
label: LanceDB version
description: What version of LanceDB are you using? `npm list | grep vectordb`.
placeholder: v0.3.2
validations:
required: false
- type: textarea
id: what-happened
attributes:
label: What happened?
description: Also tell us, what did you expect to happen?
validations:
required: true
- type: textarea
id: reproduction
attributes:
label: Are there known steps to reproduce?
description: |
Let us know how to reproduce the bug and we may be able to fix it more
quickly. This is not required, but it is helpful.
validations:
required: false

View File

@@ -1,33 +0,0 @@
name: Bug Report - Python
description: File a bug report
title: "bug(python): "
labels: [bug, python]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
- type: input
id: version
attributes:
label: LanceDB version
description: What version of LanceDB are you using? `python -c "import lancedb; print(lancedb.__version__)"`.
placeholder: v0.3.2
validations:
required: false
- type: textarea
id: what-happened
attributes:
label: What happened?
description: Also tell us, what did you expect to happen?
validations:
required: true
- type: textarea
id: reproduction
attributes:
label: Are there known steps to reproduce?
description: |
Let us know how to reproduce the bug and we may be able to fix it more
quickly. This is not required, but it is helpful.
validations:
required: false

View File

@@ -1,5 +0,0 @@
blank_issues_enabled: true
contact_links:
- name: Discord Community Support
url: https://discord.com/invite/zMM32dvNtd
about: Please ask and answer questions here.

View File

@@ -1,23 +0,0 @@
name: 'Documentation improvement'
description: Report an issue with the documentation.
labels: [documentation]
body:
- type: textarea
id: description
attributes:
label: Description
description: >
Describe the issue with the documentation and how it can be fixed or improved.
validations:
required: true
- type: input
id: link
attributes:
label: Link
description: >
Provide a link to the existing documentation, if applicable.
placeholder: ex. https://lancedb.github.io/lancedb/guides/tables/...
validations:
required: false

View File

@@ -1,31 +0,0 @@
name: Feature suggestion
description: Suggestion a new feature for LanceDB
title: "Feature: "
labels: [enhancement]
body:
- type: markdown
attributes:
value: |
Share a new idea for a feature or improvement. Be sure to search existing
issues first to avoid duplicates.
- type: dropdown
id: sdk
attributes:
label: SDK
description: Which SDK are you using? This helps us prioritize.
options:
- Python
- Node
- Rust
default: 0
validations:
required: false
- type: textarea
id: description
attributes:
label: Description
description: |
Describe the feature and why it would be useful. If applicable, consider
providing a code example of what it might be like to use the feature.
validations:
required: true

View File

@@ -38,17 +38,13 @@ jobs:
node/vectordb-*.tgz node/vectordb-*.tgz
node-macos: node-macos:
strategy: runs-on: macos-13
matrix:
config:
- arch: x86_64-apple-darwin
runner: macos-13
- arch: aarch64-apple-darwin
# xlarge is implicitly arm64.
runner: macos-13-xlarge
runs-on: ${{ matrix.config.runner }}
# Only runs on tags that matches the make-release action # Only runs on tags that matches the make-release action
if: startsWith(github.ref, 'refs/tags/v') if: startsWith(github.ref, 'refs/tags/v')
strategy:
fail-fast: false
matrix:
target: [x86_64-apple-darwin, aarch64-apple-darwin]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3
@@ -58,8 +54,11 @@ jobs:
run: | run: |
cd node cd node
npm ci npm ci
- name: Install rustup target
if: ${{ matrix.target == 'aarch64-apple-darwin' }}
run: rustup target add aarch64-apple-darwin
- name: Build MacOS native node modules - name: Build MacOS native node modules
run: bash ci/build_macos_artifacts.sh ${{ matrix.config.arch }} run: bash ci/build_macos_artifacts.sh ${{ matrix.target }}
- name: Upload Darwin Artifacts - name: Upload Darwin Artifacts
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
@@ -67,7 +66,6 @@ jobs:
path: | path: |
node/dist/lancedb-vectordb-darwin*.tgz node/dist/lancedb-vectordb-darwin*.tgz
node-linux: node-linux:
name: node-linux (${{ matrix.config.arch}}-unknown-linux-gnu name: node-linux (${{ matrix.config.arch}}-unknown-linux-gnu
runs-on: ${{ matrix.config.runner }} runs-on: ${{ matrix.config.runner }}

View File

@@ -44,19 +44,12 @@ jobs:
run: pytest -m "not slow" -x -v --durations=30 tests run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest - name: doctest
run: pytest --doctest-modules lancedb run: pytest --doctest-modules lancedb
platform: mac:
name: "Platform: ${{ matrix.config.name }}"
timeout-minutes: 30 timeout-minutes: 30
strategy: strategy:
matrix: matrix:
config: mac-runner: [ "macos-13", "macos-13-xlarge" ]
- name: x86 Mac runs-on: "${{ matrix.mac-runner }}"
runner: macos-13
- name: Arm Mac
runner: macos-13-xlarge
- name: x86 Windows
runner: windows-latest
runs-on: "${{ matrix.config.runner }}"
defaults: defaults:
run: run:
shell: bash shell: bash
@@ -98,7 +91,11 @@ jobs:
pip install "pydantic<2" pip install "pydantic<2"
pip install -e .[tests] pip install -e .[tests]
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985 pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
pip install pytest pytest-mock pip install pytest pytest-mock black isort
- name: Black
run: black --check --diff --no-color --quiet .
- name: isort
run: isort --check --diff --quiet .
- name: Run tests - name: Run tests
run: pytest -m "not slow" -x -v --durations=30 tests run: pytest -m "not slow" -x -v --durations=30 tests
- name: doctest - name: doctest

View File

@@ -24,29 +24,6 @@ env:
RUST_BACKTRACE: "1" RUST_BACKTRACE: "1"
jobs: jobs:
lint:
timeout-minutes: 30
runs-on: ubuntu-22.04
defaults:
run:
shell: bash
working-directory: rust
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
lfs: true
- uses: Swatinem/rust-cache@v2
with:
workspaces: rust
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y protobuf-compiler libssl-dev
- name: Run format
run: cargo fmt --all -- --check
- name: Run clippy
run: cargo clippy --all --all-features -- -D warnings
linux: linux:
timeout-minutes: 30 timeout-minutes: 30
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04

View File

@@ -5,24 +5,24 @@ exclude = ["python"]
resolver = "2" resolver = "2"
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=0.9.1", "features" = ["dynamodb"] } lance = { "version" = "=0.8.17", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.1" } lance-index = { "version" = "=0.8.17" }
lance-linalg = { "version" = "=0.9.1" } lance-linalg = { "version" = "=0.8.17" }
lance-testing = { "version" = "=0.9.1" } lance-testing = { "version" = "=0.8.17" }
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "49.0.0", optional = false } arrow = { version = "47.0.0", optional = false }
arrow-array = "49.0" arrow-array = "47.0"
arrow-data = "49.0" arrow-data = "47.0"
arrow-ipc = "49.0" arrow-ipc = "47.0"
arrow-ord = "49.0" arrow-ord = "47.0"
arrow-schema = "49.0" arrow-schema = "47.0"
arrow-arith = "49.0" arrow-arith = "47.0"
arrow-cast = "49.0" arrow-cast = "47.0"
chrono = "0.4.23" chrono = "0.4.23"
half = { "version" = "=2.3.1", default-features = false, features = [ half = { "version" = "=2.3.1", default-features = false, features = [
"num-traits", "num-traits",
] } ] }
log = "0.4" log = "0.4"
object_store = "0.8.0" object_store = "0.7.1"
snafu = "0.7.4" snafu = "0.7.4"
url = "2" url = "2"

View File

@@ -5,11 +5,10 @@
**Developer-friendly, serverless vector database for AI applications** **Developer-friendly, serverless vector database for AI applications**
<a href='https://github.com/lancedb/vectordb-recipes/tree/main' target="_blank"><img alt='LanceDB' src='https://img.shields.io/badge/VectorDB_Recipes-100000?style=for-the-badge&logo=LanceDB&logoColor=white&labelColor=645cfb&color=645cfb'/></a> <a href="https://lancedb.github.io/lancedb/">Documentation</a>
<a href='https://lancedb.github.io/lancedb/' target="_blank"><img alt='lancdb' src='https://img.shields.io/badge/DOCS-100000?style=for-the-badge&logo=lancdb&logoColor=white&labelColor=645cfb&color=645cfb'/></a> <a href="https://blog.lancedb.com/">Blog</a>
[![Medium](https://img.shields.io/badge/Medium-12100E?style=for-the-badge&logo=medium&logoColor=white)](https://blog.lancedb.com/) <a href="https://discord.gg/zMM32dvNtd">Discord</a>
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/zMM32dvNtd) <a href="https://twitter.com/lancedb">Twitter</a>
[![Twitter](https://img.shields.io/badge/Twitter-%231DA1F2.svg?style=for-the-badge&logo=Twitter&logoColor=white)](https://twitter.com/lancedb)
</p> </p>

View File

@@ -80,6 +80,7 @@ nav:
- Ingest Embedding Functions: embeddings/embedding_functions.md - Ingest Embedding Functions: embeddings/embedding_functions.md
- Available Functions: embeddings/default_embedding_functions.md - Available Functions: embeddings/default_embedding_functions.md
- Create Custom Embedding Functions: embeddings/api.md - Create Custom Embedding Functions: embeddings/api.md
- Example - Calculate CLIP Embeddings with Roboflow Inference: examples/image_embeddings_roboflow.md
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb - Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb - Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
- 🔍 Python full-text search: fts.md - 🔍 Python full-text search: fts.md
@@ -98,7 +99,6 @@ nav:
- YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb - YouTube Transcript Search: notebooks/youtube_transcript_search.ipynb
- Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb - Documentation QA Bot using LangChain: notebooks/code_qa_bot.ipynb
- Multimodal search using CLIP: notebooks/multimodal_search.ipynb - Multimodal search using CLIP: notebooks/multimodal_search.ipynb
- Example - Calculate CLIP Embeddings with Roboflow Inference: examples/image_embeddings_roboflow.md
- Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md - Serverless QA Bot with S3 and Lambda: examples/serverless_lancedb_with_s3_and_lambda.md
- Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md - Serverless QA Bot with Modal: examples/serverless_qa_bot_with_modal_and_langchain.md
- 🌐 Javascript examples: - 🌐 Javascript examples:

View File

@@ -2,4 +2,3 @@ mkdocs==1.4.2
mkdocs-jupyter==0.24.1 mkdocs-jupyter==0.24.1
mkdocs-material==9.1.3 mkdocs-material==9.1.3
mkdocstrings[python]==0.20.0 mkdocstrings[python]==0.20.0
pydantic

View File

@@ -64,26 +64,18 @@ We'll cover the basics of using LanceDB on your local machine in this section.
tbl = db.create_table("table_from_df", data=df) tbl = db.create_table("table_from_df", data=df)
``` ```
!!! warning
If the table already exists, LanceDB will raise an error by default.
If you want to overwrite the table, you can pass in `mode="overwrite"`
to the `createTable` function.
=== "Javascript" === "Javascript"
```javascript ```javascript
const tb = await db.createTable( const tb = await db.createTable("my_table",
"myTable", data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
``` ```
!!! warning !!! warning
If the table already exists, LanceDB will raise an error by default.
If you want to overwrite the table, you can pass in `"overwrite"`
to the `createTable` function like this: `await con.createTable(tableName, data, { writeMode: WriteMode.Overwrite })`
If the table already exists, LanceDB will raise an error by default.
If you want to overwrite the table, you can pass in `mode="overwrite"`
to the `createTable` function.
??? info "Under the hood, LanceDB is converting the input data into an Apache Arrow table and persisting it to disk in [Lance format](https://www.github.com/lancedb/lance)." ??? info "Under the hood, LanceDB is converting the input data into an Apache Arrow table and persisting it to disk in [Lance format](https://www.github.com/lancedb/lance)."
@@ -116,7 +108,7 @@ Once created, you can open a table using the following code:
=== "Javascript" === "Javascript"
```javascript ```javascript
const tbl = await db.openTable("myTable"); const tbl = await db.openTable("my_table");
``` ```
If you forget the name of your table, you can always get a listing of all table names: If you forget the name of your table, you can always get a listing of all table names:
@@ -202,17 +194,10 @@ Use the `drop_table()` method on the database to remove a table.
db.drop_table("my_table") db.drop_table("my_table")
``` ```
This permanently removes the table and is not recoverable, unlike deleting rows. This permanently removes the table and is not recoverable, unlike deleting rows.
By default, if the table does not exist an exception is raised. To suppress this, By default, if the table does not exist an exception is raised. To suppress this,
you can pass in `ignore_missing=True`. you can pass in `ignore_missing=True`.
=== "JavaScript"
```javascript
await db.dropTable('myTable')
```
This permanently removes the table and is not recoverable, unlike deleting rows.
If the table does not exist an exception is raised.
## What's next ## What's next

View File

@@ -201,8 +201,8 @@ A Table is a collection of Records in a LanceDB Database. You can follow along o
```javascript ```javascript
data data
const tb = await db.createTable("my_table", const tb = await db.createTable("my_table",
[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}]) {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
``` ```
!!! info "Note" !!! info "Note"

View File

@@ -119,100 +119,3 @@ This is why it is often called **Approximate Nearest Neighbors (ANN)** search, w
always returns 100% recall. always returns 100% recall.
See [ANN Index](ann_indexes.md) for more details. See [ANN Index](ann_indexes.md) for more details.
### Output formats
LanceDB returns results in many different formats commonly used in python.
Let's create a LanceDB table with a nested schema:
```python
from datetime import datetime
import lancedb
from lancedb.pydantic import LanceModel, Vector
import numpy as np
from pydantic import BaseModel
uri = "data/sample-lancedb-nested"
class Metadata(BaseModel):
source: str
timestamp: datetime
class Document(BaseModel):
content: str
meta: Metadata
class LanceSchema(LanceModel):
id: str
vector: Vector(1536)
payload: Document
# Let's add 100 sample rows to our dataset
data = [LanceSchema(
id=f"id{i}",
vector=np.random.randn(1536),
payload=Document(
content=f"document{i}", meta=Metadata(source=f"source{i%10}", timestamp=datetime.now())
),
) for i in range(100)]
tbl = db.create_table("documents", data=data)
```
#### As a pyarrow table
Using `to_arrow()` we can get the results back as a pyarrow Table.
This result table has the same columns as the LanceDB table, with
the addition of an `_distance` column for vector search or a `score`
column for full text search.
```python
tbl.search(np.random.randn(1536)).to_arrow()
```
#### As a pandas dataframe
You can also get the results as a pandas dataframe.
```python
tbl.search(np.random.randn(1536)).to_pandas()
```
While other formats like Arrow/Pydantic/Python dicts have a natural
way to handle nested schemas, pandas can only store nested data as a
python dict column, which makes it difficult to support nested references.
So for convenience, you can also tell LanceDB to flatten a nested schema
when creating the pandas dataframe.
```python
tbl.search(np.random.randn(1536)).to_pandas(flatten=True)
```
If your table has a deeply nested struct, you can control how many levels
of nesting to flatten by passing in a positive integer.
```python
tbl.search(np.random.randn(1536)).to_pandas(flatten=1)
```
#### As a list of python dicts
You can of course return results as a list of python dicts.
```python
tbl.search(np.random.randn(1536)).to_list()
```
#### As a list of pydantic models
We can add data using pydantic models, and we can certainly
retrieve results as pydantic models
```python
tbl.search(np.random.randn(1536)).to_pydantic(LanceSchema)
```
Note that in this case the extra `_distance` field is discarded since
it's not part of the LanceSchema.

View File

@@ -22,7 +22,7 @@ import numpy as np
uri = "data/sample-lancedb" uri = "data/sample-lancedb"
db = lancedb.connect(uri) db = lancedb.connect(uri)
data = [{"vector": row, "item": f"item {i}", "id": i} data = [{"vector": row, "item": f"item {i}"}
for i, row in enumerate(np.random.random((10_000, 2)).astype('int'))] for i, row in enumerate(np.random.random((10_000, 2)).astype('int'))]
tbl = db.create_table("my_vectors", data=data) tbl = db.create_table("my_vectors", data=data)
@@ -35,25 +35,33 @@ const db = await vectordb.connect('data/sample-lancedb')
let data = [] let data = []
for (let i = 0; i < 10_000; i++) { for (let i = 0; i < 10_000; i++) {
data.push({vector: Array(1536).fill(i), id: i, item: `item ${i}`, strId: `${i}`}) data.push({vector: Array(1536).fill(i), id: `${i}`, content: "", longId: `${i}`},)
} }
const tbl = await db.createTable('myVectors', data) const tbl = await db.createTable('my_vectors', data)
``` ```
--> -->
=== "Python" === "Python"
```python ```python
tbl.search([100, 102]) \ tbl.search([100, 102]) \
.where("(item IN ('item 0', 'item 2')) AND (id > 10)") \ .where("""(
.to_arrow() (label IN [10, 20])
``` AND
(note.email IS NOT NULL)
) OR NOT note.created
""")
```
=== "Javascript" === "Javascript"
```javascript ```javascript
await tbl.search(Array(1536).fill(0)) tbl.search([100, 102])
.where("(item IN ('item 0', 'item 2')) AND (id > 10)") .where(`(
.execute() (label IN [10, 20])
AND
(note.email IS NOT NULL)
) OR NOT note.created
`)
``` ```
@@ -110,22 +118,3 @@ The mapping from SQL types to Arrow types is:
[^1]: See precision mapping in previous table. [^1]: See precision mapping in previous table.
## Filtering without Vector Search
You can also filter your data without search.
=== "Python"
```python
tbl.search().where("id=10").limit(10).to_arrow()
```
=== "JavaScript"
```javascript
await tbl.where('id=10').limit(10).execute()
```
!!! warning
If your table is large, this could potentially return a very large
amount of data. Please be sure to use a `limit` clause unless
you're sure you want to return the whole result set.

80
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.4.0", "version": "0.3.9",
"lockfileVersion": 2, "lockfileVersion": 2,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.4.0", "version": "0.3.9",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -53,11 +53,11 @@
"uuid": "^9.0.0" "uuid": "^9.0.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.0", "@lancedb/vectordb-darwin-arm64": "0.3.9",
"@lancedb/vectordb-darwin-x64": "0.4.0", "@lancedb/vectordb-darwin-x64": "0.3.9",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.0", "@lancedb/vectordb-linux-arm64-gnu": "0.3.9",
"@lancedb/vectordb-linux-x64-gnu": "0.4.0", "@lancedb/vectordb-linux-x64-gnu": "0.3.9",
"@lancedb/vectordb-win32-x64-msvc": "0.4.0" "@lancedb/vectordb-win32-x64-msvc": "0.3.9"
} }
}, },
"node_modules/@apache-arrow/ts": { "node_modules/@apache-arrow/ts": {
@@ -316,22 +316,10 @@
"@jridgewell/sourcemap-codec": "^1.4.10" "@jridgewell/sourcemap-codec": "^1.4.10"
} }
}, },
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.4.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.0.tgz",
"integrity": "sha512-cP6zGtBWXEcJHCI4uLNIP5ILtRvexvwmL8Uri1dnHG8dT8g12Ykug3BHO6Wt6wp/xASd2jJRIF/VAJsN9IeP1A==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": { "node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.4.0", "version": "0.3.9",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.0.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.9.tgz",
"integrity": "sha512-ig0gV5ol1sFe2lb1HOatK0rizyj9I91WbnH79i7OdUl3nAQIcWm70CnxrPLtx0DS2NTGh2kFJbYCWcaUlu6YfA==", "integrity": "sha512-4xXQoPheyIl1P5kRoKmZtaAHFrYdL9pw5yq+r6ewIx0TCemN4LSvzSUTqM5nZl3QPU8FeL0CGD8Gt2gMU0HQ2A==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -341,9 +329,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-arm64-gnu": { "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.4.0", "version": "0.3.9",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.0.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.9.tgz",
"integrity": "sha512-gMXIDT2kriAPDwWIRKXdaTCNdOeFGEok1S9Y30AOruHXddW1vCIo4JNJIYbBqHnwAeI4wI3ae6GRCFaf1UxO3g==", "integrity": "sha512-WIxCZKnLeSlz0PGURtKSX6hJ4CYE2o5P+IFmmuWOWB1uNapQu6zOpea6rNxcRFHUA0IJdO02lVxVfn2hDX4SMg==",
"cpu": [ "cpu": [
"arm64" "arm64"
], ],
@@ -353,9 +341,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-linux-x64-gnu": { "node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.4.0", "version": "0.3.9",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.0.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.9.tgz",
"integrity": "sha512-ZQ3lDrDSz1IKdx/mS9Lz08agFO+OD5oSFrrcFNCoT1+H93eS1mCLdmCoEARu3jKbx0tMs38l5J9yXZ2QmJye3w==", "integrity": "sha512-bQbcV9adKzYbJLNzDjk9OYsMnT2IjmieLfb4IQ1hj5IUoWfbg80Bd0+gZUnrmrhG6fe56TIriFZYQR9i7TSE9Q==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -365,9 +353,9 @@
] ]
}, },
"node_modules/@lancedb/vectordb-win32-x64-msvc": { "node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.4.0", "version": "0.3.9",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.0.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.9.tgz",
"integrity": "sha512-toNcNwBRE1sdsSf5hr7W8QiqZ33csc/knVEek4CyvYkZHJGh4Z6WI+DJUIASo5wzUez4TX7qUPpRPL9HuaPMCg==", "integrity": "sha512-7EXI7P1QvAfgJNPWWBMDOkoJ696gSBAClcyEJNYg0JV21jVFZRwJVI3bZXflesWduFi/mTuzPkFFA68us1u19A==",
"cpu": [ "cpu": [
"x64" "x64"
], ],
@@ -4868,34 +4856,28 @@
"@jridgewell/sourcemap-codec": "^1.4.10" "@jridgewell/sourcemap-codec": "^1.4.10"
} }
}, },
"@lancedb/vectordb-darwin-arm64": {
"version": "0.4.0",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.0.tgz",
"integrity": "sha512-cP6zGtBWXEcJHCI4uLNIP5ILtRvexvwmL8Uri1dnHG8dT8g12Ykug3BHO6Wt6wp/xASd2jJRIF/VAJsN9IeP1A==",
"optional": true
},
"@lancedb/vectordb-darwin-x64": { "@lancedb/vectordb-darwin-x64": {
"version": "0.4.0", "version": "0.3.9",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.0.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.9.tgz",
"integrity": "sha512-ig0gV5ol1sFe2lb1HOatK0rizyj9I91WbnH79i7OdUl3nAQIcWm70CnxrPLtx0DS2NTGh2kFJbYCWcaUlu6YfA==", "integrity": "sha512-4xXQoPheyIl1P5kRoKmZtaAHFrYdL9pw5yq+r6ewIx0TCemN4LSvzSUTqM5nZl3QPU8FeL0CGD8Gt2gMU0HQ2A==",
"optional": true "optional": true
}, },
"@lancedb/vectordb-linux-arm64-gnu": { "@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.4.0", "version": "0.3.9",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.0.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.9.tgz",
"integrity": "sha512-gMXIDT2kriAPDwWIRKXdaTCNdOeFGEok1S9Y30AOruHXddW1vCIo4JNJIYbBqHnwAeI4wI3ae6GRCFaf1UxO3g==", "integrity": "sha512-WIxCZKnLeSlz0PGURtKSX6hJ4CYE2o5P+IFmmuWOWB1uNapQu6zOpea6rNxcRFHUA0IJdO02lVxVfn2hDX4SMg==",
"optional": true "optional": true
}, },
"@lancedb/vectordb-linux-x64-gnu": { "@lancedb/vectordb-linux-x64-gnu": {
"version": "0.4.0", "version": "0.3.9",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.0.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.9.tgz",
"integrity": "sha512-ZQ3lDrDSz1IKdx/mS9Lz08agFO+OD5oSFrrcFNCoT1+H93eS1mCLdmCoEARu3jKbx0tMs38l5J9yXZ2QmJye3w==", "integrity": "sha512-bQbcV9adKzYbJLNzDjk9OYsMnT2IjmieLfb4IQ1hj5IUoWfbg80Bd0+gZUnrmrhG6fe56TIriFZYQR9i7TSE9Q==",
"optional": true "optional": true
}, },
"@lancedb/vectordb-win32-x64-msvc": { "@lancedb/vectordb-win32-x64-msvc": {
"version": "0.4.0", "version": "0.3.9",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.0.tgz", "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.9.tgz",
"integrity": "sha512-toNcNwBRE1sdsSf5hr7W8QiqZ33csc/knVEek4CyvYkZHJGh4Z6WI+DJUIASo5wzUez4TX7qUPpRPL9HuaPMCg==", "integrity": "sha512-7EXI7P1QvAfgJNPWWBMDOkoJ696gSBAClcyEJNYg0JV21jVFZRwJVI3bZXflesWduFi/mTuzPkFFA68us1u19A==",
"optional": true "optional": true
}, },
"@neon-rs/cli": { "@neon-rs/cli": {

View File

@@ -1,6 +1,6 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.4.0", "version": "0.3.9",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js", "main": "dist/index.js",
"types": "dist/index.d.ts", "types": "dist/index.d.ts",
@@ -81,10 +81,10 @@
} }
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.0", "@lancedb/vectordb-darwin-arm64": "0.3.9",
"@lancedb/vectordb-darwin-x64": "0.4.0", "@lancedb/vectordb-darwin-x64": "0.3.9",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.0", "@lancedb/vectordb-linux-arm64-gnu": "0.3.9",
"@lancedb/vectordb-linux-x64-gnu": "0.4.0", "@lancedb/vectordb-linux-x64-gnu": "0.3.9",
"@lancedb/vectordb-win32-x64-msvc": "0.4.0" "@lancedb/vectordb-win32-x64-msvc": "0.3.9"
} }
} }

View File

@@ -21,10 +21,9 @@ import type { EmbeddingFunction } from './embedding/embedding_function'
import { RemoteConnection } from './remote' import { RemoteConnection } from './remote'
import { Query } from './query' import { Query } from './query'
import { isEmbeddingFunction } from './embedding/embedding_function' import { isEmbeddingFunction } from './embedding/embedding_function'
import { type Literal, toSQL } from './util'
// eslint-disable-next-line @typescript-eslint/no-var-requires // eslint-disable-next-line @typescript-eslint/no-var-requires
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableUpdate, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js') const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
export { Query } export { Query }
export type { EmbeddingFunction } export type { EmbeddingFunction }
@@ -262,39 +261,6 @@ export interface Table<T = number[]> {
*/ */
delete: (filter: string) => Promise<void> delete: (filter: string) => Promise<void>
/**
* Update rows in this table.
*
* This can be used to update a single row, many rows, all rows, or
* sometimes no rows (if your predicate matches nothing).
*
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
*
* @examples
*
* ```ts
* const con = await lancedb.connect("./.lancedb")
* const data = [
* {id: 1, vector: [3, 3], name: 'Ye'},
* {id: 2, vector: [4, 4], name: 'Mike'},
* ];
* const tbl = await con.createTable("my_table", data)
*
* await tbl.update({
* filter: "id = 2",
* updates: { vector: [2, 2], name: "Michael" },
* })
*
* let results = await tbl.search([1, 1]).execute();
* // Returns [
* // {id: 2, vector: [2, 2], name: 'Michael'}
* // {id: 1, vector: [3, 3], name: 'Ye'}
* // ]
* ```
*
*/
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
/** /**
* List the indicies on this table. * List the indicies on this table.
*/ */
@@ -306,34 +272,6 @@ export interface Table<T = number[]> {
indexStats: (indexUuid: string) => Promise<IndexStats> indexStats: (indexUuid: string) => Promise<IndexStats>
} }
export interface UpdateArgs {
/**
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
* in which case all rows will be updated.
*/
where?: string
/**
* A key-value map of updates. The keys are the column names, and the values are the
* new values to set
*/
values: Record<string, Literal>
}
export interface UpdateSqlArgs {
/**
* A filter in the same format used by a sql WHERE clause. The filter may be empty,
* in which case all rows will be updated.
*/
where?: string
/**
* A key-value map of updates. The keys are the column names, and the values are the
* new values to set as SQL expressions.
*/
valuesSql: Record<string, string>
}
export interface VectorIndex { export interface VectorIndex {
columns: string[] columns: string[]
name: string name: string
@@ -488,16 +426,6 @@ export class LocalTable<T = number[]> implements Table<T> {
return new Query(query, this._tbl, this._embeddings) return new Query(query, this._tbl, this._embeddings)
} }
/**
* Creates a filter query to find all rows matching the specified criteria
* @param value The filter criteria (like SQL where clause syntax)
*/
filter (value: string): Query<T> {
return new Query(undefined, this._tbl, this._embeddings).filter(value)
}
where = this.filter
/** /**
* Insert records into this Table. * Insert records into this Table.
* *
@@ -553,31 +481,6 @@ export class LocalTable<T = number[]> implements Table<T> {
return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable }) return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable })
} }
/**
* Update rows in this table.
*
* @param args see {@link UpdateArgs} and {@link UpdateSqlArgs} for more details
*
* @returns
*/
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
let filter: string | null
let updates: Record<string, string>
if ('valuesSql' in args) {
filter = args.where ?? null
updates = args.valuesSql
} else {
filter = args.where ?? null
updates = {}
for (const [key, value] of Object.entries(args.values)) {
updates[key] = toSQL(value)
}
}
return tableUpdate.call(this._tbl, filter, updates).then((newTable: any) => { this._tbl = newTable })
}
/** /**
* Clean up old versions of the table, freeing disk space. * Clean up old versions of the table, freeing disk space.
* *
@@ -744,11 +647,6 @@ export interface IvfPQIndexConfig {
*/ */
replace?: boolean replace?: boolean
/**
* Cache size of the index
*/
index_cache_size?: number
type: 'ivf_pq' type: 'ivf_pq'
} }

View File

@@ -23,10 +23,10 @@ const { tableSearch } = require('../native.js')
* A builder for nearest neighbor queries for LanceDB. * A builder for nearest neighbor queries for LanceDB.
*/ */
export class Query<T = number[]> { export class Query<T = number[]> {
private readonly _query?: T private readonly _query: T
private readonly _tbl?: any private readonly _tbl?: any
private _queryVector?: number[] private _queryVector?: number[]
private _limit?: number private _limit: number
private _refineFactor?: number private _refineFactor?: number
private _nprobes: number private _nprobes: number
private _select?: string[] private _select?: string[]
@@ -35,10 +35,10 @@ export class Query<T = number[]> {
private _prefilter: boolean private _prefilter: boolean
protected readonly _embeddings?: EmbeddingFunction<T> protected readonly _embeddings?: EmbeddingFunction<T>
constructor (query?: T, tbl?: any, embeddings?: EmbeddingFunction<T>) { constructor (query: T, tbl?: any, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl this._tbl = tbl
this._query = query this._query = query
this._limit = undefined this._limit = 10
this._nprobes = 20 this._nprobes = 20
this._refineFactor = undefined this._refineFactor = undefined
this._select = undefined this._select = undefined
@@ -113,12 +113,10 @@ export class Query<T = number[]> {
* Execute the query and return the results as an Array of Objects * Execute the query and return the results as an Array of Objects
*/ */
async execute<T = Record<string, unknown>> (): Promise<T[]> { async execute<T = Record<string, unknown>> (): Promise<T[]> {
if (this._query !== undefined) { if (this._embeddings !== undefined) {
if (this._embeddings !== undefined) { this._queryVector = (await this._embeddings.embed([this._query]))[0]
this._queryVector = (await this._embeddings.embed([this._query]))[0] } else {
} else { this._queryVector = this._query as number[]
this._queryVector = this._query as number[]
}
} }
const isElectron = this.isElectron() const isElectron = this.isElectron()

View File

@@ -16,8 +16,7 @@ import {
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection, type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
type ConnectionOptions, type CreateTableOptions, type VectorIndex, type ConnectionOptions, type CreateTableOptions, type VectorIndex,
type WriteOptions, type WriteOptions,
type IndexStats, type IndexStats
type UpdateArgs, type UpdateSqlArgs
} from '../index' } from '../index'
import { Query } from '../query' import { Query } from '../query'
@@ -25,7 +24,6 @@ import { Vector, Table as ArrowTable } from 'apache-arrow'
import { HttpLancedbClient } from './client' import { HttpLancedbClient } from './client'
import { isEmbeddingFunction } from '../embedding/embedding_function' import { isEmbeddingFunction } from '../embedding/embedding_function'
import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow' import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
import { toSQL } from '../util'
/** /**
* Remote connection. * Remote connection.
@@ -57,8 +55,8 @@ export class RemoteConnection implements Connection {
return 'db://' + this._client.uri return 'db://' + this._client.uri
} }
async tableNames (pageToken: string = '', limit: number = 10): Promise<string[]> { async tableNames (): Promise<string[]> {
const response = await this._client.get('/v1/table/', { limit, page_token: pageToken }) const response = await this._client.get('/v1/table/')
return response.data.tables return response.data.tables
} }
@@ -195,17 +193,6 @@ export class RemoteTable<T = number[]> implements Table<T> {
return this._name return this._name
} }
get schema (): Promise<any> {
return this._client.post(`/v1/table/${this._name}/describe/`).then(res => {
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
}
return res.data?.schema
})
}
search (query: T): Query<T> { search (query: T): Query<T> {
return new RemoteQuery(query, this._client, this._name)//, this._embeddings_new) return new RemoteQuery(query, this._client, this._name)//, this._embeddings_new)
} }
@@ -246,41 +233,8 @@ export class RemoteTable<T = number[]> implements Table<T> {
return data.length return data.length
} }
async createIndex (indexParams: VectorIndexParams): Promise<void> { async createIndex (indexParams: VectorIndexParams): Promise<any> {
const unsupportedParams = [ throw new Error('Not implemented')
'index_name',
'num_partitions',
'max_iters',
'use_opq',
'num_sub_vectors',
'num_bits',
'max_opq_iters',
'replace'
]
for (const param of unsupportedParams) {
// eslint-disable-next-line @typescript-eslint/strict-boolean-expressions
if (indexParams[param as keyof VectorIndexParams]) {
throw new Error(`${param} is not supported for remote connections`)
}
}
const column = indexParams.column ?? 'vector'
const indexType = 'vector' // only vector index is supported for remote connections
const metricType = indexParams.metric_type ?? 'L2'
const indexCacheSize = indexParams ?? null
const data = {
column,
index_type: indexType,
metric_type: metricType,
index_cache_size: indexCacheSize
}
const res = await this._client.post(`/v1/table/${this._name}/create_index/`, data)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
}
} }
async countRows (): Promise<number> { async countRows (): Promise<number> {
@@ -292,26 +246,6 @@ export class RemoteTable<T = number[]> implements Table<T> {
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter }) await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
} }
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
let filter: string | null
let updates: Record<string, string>
if ('valuesSql' in args) {
filter = args.where ?? null
updates = args.valuesSql
} else {
filter = args.where ?? null
updates = {}
for (const [key, value] of Object.entries(args.values)) {
updates[key] = toSQL(value)
}
}
await this._client.post(`/v1/table/${this._name}/update/`, {
predicate: filter,
updates: Object.entries(updates).map(([key, value]) => [key, value])
})
}
async listIndices (): Promise<VectorIndex[]> { async listIndices (): Promise<VectorIndex[]> {
const results = await this._client.post(`/v1/table/${this._name}/index/list/`) const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
return results.data.indexes?.map((index: any) => ({ return results.data.indexes?.map((index: any) => ({

View File

@@ -78,31 +78,12 @@ describe('LanceDB client', function () {
}) })
it('limits # of results', async function () { it('limits # of results', async function () {
const uri = await createTestDB(2, 100) const uri = await createTestDB()
const con = await lancedb.connect(uri) const con = await lancedb.connect(uri)
const table = await con.openTable('vectors') const table = await con.openTable('vectors')
let results = await table.search([0.1, 0.3]).limit(1).execute() const results = await table.search([0.1, 0.3]).limit(1).execute()
assert.equal(results.length, 1) assert.equal(results.length, 1)
assert.equal(results[0].id, 1) assert.equal(results[0].id, 1)
// there is a default limit if unspecified
results = await table.search([0.1, 0.3]).execute()
assert.equal(results.length, 10)
})
it('uses a filter / where clause without vector search', async function () {
// eslint-disable-next-line @typescript-eslint/explicit-function-return-type
const assertResults = (results: Array<Record<string, unknown>>) => {
assert.equal(results.length, 50)
}
const uri = await createTestDB(2, 100)
const con = await lancedb.connect(uri)
const table = (await con.openTable('vectors')) as LocalTable
let results = await table.filter('id % 2 = 0').execute()
assertResults(results)
results = await table.where('id % 2 = 0').execute()
assertResults(results)
}) })
it('uses a filter / where clause', async function () { it('uses a filter / where clause', async function () {
@@ -279,46 +260,6 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2) assert.equal(await table.countRows(), 2)
}) })
it('can update records in the table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ where: 'price = 10', valuesSql: { price: '100' } })
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 11)
})
it('can update the records using a literal value', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ where: 'price = 10', values: { price: 100 } })
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 11)
})
it('can update every record in the table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
assert.equal(await table.countRows(), 2)
await table.update({ valuesSql: { price: '100' } })
const results = await table.search([0.1, 0.2]).execute()
assert.equal(results[0].price, 100)
assert.equal(results[1].price, 100)
})
it('can delete records from a table', async function () { it('can delete records from a table', async function () {
const uri = await createTestDB() const uri = await createTestDB()
const con = await lancedb.connect(uri) const con = await lancedb.connect(uri)
@@ -601,7 +542,7 @@ describe('Compact and cleanup', function () {
// should have no effect, but this validates the arguments are parsed. // should have no effect, but this validates the arguments are parsed.
await table.compactFiles({ await table.compactFiles({
targetRowsPerFragment: 102410, targetRowsPerFragment: 1024 * 10,
maxRowsPerGroup: 1024, maxRowsPerGroup: 1024,
materializeDeletions: true, materializeDeletions: true,
materializeDeletionsThreshold: 0.5, materializeDeletionsThreshold: 0.5,

View File

@@ -1,45 +0,0 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { toSQL } from '../util'
import * as chai from 'chai'
const expect = chai.expect
describe('toSQL', function () {
it('should turn string to SQL expression', function () {
expect(toSQL('foo')).to.equal("'foo'")
})
it('should turn number to SQL expression', function () {
expect(toSQL(123)).to.equal('123')
})
it('should turn boolean to SQL expression', function () {
expect(toSQL(true)).to.equal('TRUE')
})
it('should turn null to SQL expression', function () {
expect(toSQL(null)).to.equal('NULL')
})
it('should turn Date to SQL expression', function () {
const date = new Date('05 October 2011 14:48 UTC')
expect(toSQL(date)).to.equal("'2011-10-05T14:48:00.000Z'")
})
it('should turn array to SQL expression', function () {
expect(toSQL(['foo', 'bar', true, 1])).to.equal("['foo', 'bar', TRUE, 1]")
})
})

View File

@@ -1,44 +0,0 @@
// Copyright 2023 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
export type Literal = string | number | boolean | null | Date | Literal[]
export function toSQL (value: Literal): string {
if (typeof value === 'string') {
return `'${value}'`
}
if (typeof value === 'number') {
return value.toString()
}
if (typeof value === 'boolean') {
return value ? 'TRUE' : 'FALSE'
}
if (value === null) {
return 'NULL'
}
if (value instanceof Date) {
return `'${value.toISOString()}'`
}
if (Array.isArray(value)) {
return `[${value.map(toSQL).join(', ')}]`
}
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw new Error(`Unsupported value type: ${typeof value} value: (${value})`)
}

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.4.0 current_version = 0.3.4
commit = True commit = True
message = [python] Bump version: {current_version} → {new_version} message = [python] Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -23,7 +23,7 @@ from overrides import EnforceOverrides, override
from pyarrow import fs from pyarrow import fs
from .table import LanceTable, Table from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri from .util import fs_from_uri, get_uri_location, get_uri_scheme
if TYPE_CHECKING: if TYPE_CHECKING:
from .common import DATA, URI from .common import DATA, URI
@@ -288,13 +288,14 @@ class LanceDBConnection(DBConnection):
A list of table names. A list of table names.
""" """
try: try:
filesystem = fs_from_uri(self.uri)[0] filesystem, path = fs_from_uri(self.uri)
except pa.ArrowInvalid: except pa.ArrowInvalid:
raise NotImplementedError("Unsupported scheme: " + self.uri) raise NotImplementedError("Unsupported scheme: " + self.uri)
try: try:
loc = get_uri_location(self.uri) paths = filesystem.get_file_info(
paths = filesystem.get_file_info(fs.FileSelector(loc)) fs.FileSelector(get_uri_location(self.uri))
)
except FileNotFoundError: except FileNotFoundError:
# It is ok if the file does not exist since it will be created # It is ok if the file does not exist since it will be created
paths = [] paths = []
@@ -372,7 +373,7 @@ class LanceDBConnection(DBConnection):
""" """
try: try:
filesystem, path = fs_from_uri(self.uri) filesystem, path = fs_from_uri(self.uri)
table_path = join_uri(path, name + ".lance") table_path = os.path.join(path, name + ".lance")
filesystem.delete_dir(table_path) filesystem.delete_dir(table_path)
except FileNotFoundError: except FileNotFoundError:
if not ignore_missing: if not ignore_missing:

View File

@@ -75,14 +75,8 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
The number of rows indexed The number of rows indexed
""" """
# first check the fields exist and are string or large string type # first check the fields exist and are string or large string type
nested = []
for name in fields: for name in fields:
try: f = table.schema.field(name) # raises KeyError if not found
f = table.schema.field(name) # raises KeyError if not found
except KeyError:
f = resolve_path(table.schema, name)
nested.append(name)
if not pa.types.is_string(f.type) and not pa.types.is_large_string(f.type): if not pa.types.is_string(f.type) and not pa.types.is_large_string(f.type):
raise TypeError(f"Field {name} is not a string type") raise TypeError(f"Field {name} is not a string type")
@@ -91,16 +85,7 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
# write data into index # write data into index
dataset = table.to_lance() dataset = table.to_lance()
row_id = 0 row_id = 0
max_nested_level = 0
if len(nested) > 0:
max_nested_level = max([len(name.split(".")) for name in nested])
for b in dataset.to_batches(columns=fields): for b in dataset.to_batches(columns=fields):
if max_nested_level > 0:
b = pa.Table.from_batches([b])
for _ in range(max_nested_level - 1):
b = b.flatten()
for i in range(b.num_rows): for i in range(b.num_rows):
doc = tantivy.Document() doc = tantivy.Document()
doc.add_integer("doc_id", row_id) doc.add_integer("doc_id", row_id)
@@ -113,30 +98,6 @@ def populate_index(index: tantivy.Index, table: LanceTable, fields: List[str]) -
return row_id return row_id
def resolve_path(schema, field_name: str) -> pa.Field:
"""
Resolve a nested field path to a list of field names
Parameters
----------
field_name : str
The field name to resolve
Returns
-------
List[str]
The resolved path
"""
path = field_name.split(".")
field = schema.field(path.pop(0))
for segment in path:
if pa.types.is_struct(field.type):
field = field.type.field(segment)
else:
raise KeyError(f"field {field_name} not found in schema {schema}")
return field
def search_index( def search_index(
index: tantivy.Index, query: str, limit: int = 10 index: tantivy.Index, query: str, limit: int = 10
) -> Tuple[Tuple[int], Tuple[float]]: ) -> Tuple[Tuple[int], Tuple[float]]:

View File

@@ -348,20 +348,3 @@ def get_extras(field_info: pydantic.fields.FieldInfo, key: str) -> Any:
if PYDANTIC_VERSION.major >= 2: if PYDANTIC_VERSION.major >= 2:
return (field_info.json_schema_extra or {}).get(key) return (field_info.json_schema_extra or {}).get(key)
return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key) return (field_info.field_info.extra or {}).get("json_schema_extra", {}).get(key)
if PYDANTIC_VERSION.major < 2:
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.dict()
else:
def model_to_dict(model: pydantic.BaseModel) -> Dict[str, Any]:
"""
Convert a Pydantic model to a dictionary.
"""
return model.model_dump()

View File

@@ -185,40 +185,14 @@ class LanceQueryBuilder(ABC):
""" """
return self.to_pandas() return self.to_pandas()
def to_pandas(self, flatten: Optional[Union[int, bool]] = None) -> "pd.DataFrame": def to_pandas(self) -> "pd.DataFrame":
""" """
Execute the query and return the results as a pandas DataFrame. Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
Parameters
----------
flatten: Optional[Union[int, bool]]
If flatten is True, flatten all nested columns.
If flatten is an integer, flatten the nested columns up to the
specified depth.
If unspecified, do not flatten the nested columns.
""" """
tbl = self.to_arrow() return self.to_arrow().to_pandas()
if flatten is True:
while True:
tbl = tbl.flatten()
has_struct = False
# loop through all columns to check if there is any struct column
if any(pa.types.is_struct(col.type) for col in tbl.schema):
continue
else:
break
elif isinstance(flatten, int):
if flatten <= 0:
raise ValueError(
"Please specify a positive integer for flatten or the boolean value `True`"
)
while flatten > 0:
tbl = tbl.flatten()
flatten -= 1
return tbl.to_pandas()
@abstractmethod @abstractmethod
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:

View File

@@ -18,8 +18,6 @@ import attrs
import pyarrow as pa import pyarrow as pa
from pydantic import BaseModel from pydantic import BaseModel
from lancedb.common import VECTOR_COLUMN_NAME
__all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"] __all__ = ["LanceDBClient", "VectorQuery", "VectorQueryResult"]
@@ -45,8 +43,6 @@ class VectorQuery(BaseModel):
refine_factor: Optional[int] = None refine_factor: Optional[int] = None
vector_column: str = VECTOR_COLUMN_NAME
@attrs.define @attrs.define
class VectorQueryResult: class VectorQueryResult:

View File

@@ -56,7 +56,7 @@ class RemoteDBConnection(DBConnection):
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
def __repr__(self) -> str: def __repr__(self) -> str:
return f"RemoteConnect(name={self.db_name})" return f"RemoveConnect(name={self.db_name})"
@override @override
def table_names( def table_names(
@@ -167,10 +167,10 @@ class RemoteDBConnection(DBConnection):
Can create with list of tuples or dictionaries: Can create with list of tuples or dictionaries:
>>> import lancedb >>> import lancedb
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP >>> db = lancedb.connect("db://test-project-8f45eb")
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7}, >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}] ... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data) # doctest: +SKIP >>> db.create_table("my_table", data)
LanceTable(my_table) LanceTable(my_table)
You can also pass a pandas DataFrame: You can also pass a pandas DataFrame:
@@ -181,7 +181,7 @@ class RemoteDBConnection(DBConnection):
... "lat": [45.5, 40.1], ... "lat": [45.5, 40.1],
... "long": [-122.7, -74.1] ... "long": [-122.7, -74.1]
... }) ... })
>>> db.create_table("table2", data) # doctest: +SKIP >>> db.create_table("table2", data)
LanceTable(table2) LanceTable(table2)
>>> custom_schema = pa.schema([ >>> custom_schema = pa.schema([
@@ -189,7 +189,7 @@ class RemoteDBConnection(DBConnection):
... pa.field("lat", pa.float32()), ... pa.field("lat", pa.float32()),
... pa.field("long", pa.float32()) ... pa.field("long", pa.float32())
... ]) ... ])
>>> db.create_table("table3", data, schema = custom_schema) # doctest: +SKIP >>> db.create_table("table3", data, schema = custom_schema)
LanceTable(table3) LanceTable(table3)
It is also possible to create an table from `[Iterable[pa.RecordBatch]]`: It is also possible to create an table from `[Iterable[pa.RecordBatch]]`:
@@ -211,7 +211,7 @@ class RemoteDBConnection(DBConnection):
... pa.field("item", pa.utf8()), ... pa.field("item", pa.utf8()),
... pa.field("price", pa.float32()), ... pa.field("price", pa.float32()),
... ]) ... ])
>>> db.create_table("table4", make_batches(), schema=schema) # doctest: +SKIP >>> db.create_table("table4", make_batches(), schema=schema)
LanceTable(table4) LanceTable(table4)
""" """

View File

@@ -13,7 +13,7 @@
import uuid import uuid
from functools import cached_property from functools import cached_property
from typing import Dict, Optional, Union from typing import Optional, Union
import pyarrow as pa import pyarrow as pa
from lance import json_to_schema from lance import json_to_schema
@@ -22,7 +22,6 @@ from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from ..query import LanceVectorQueryBuilder from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
from ..util import value_to_sql
from .arrow import to_ipc_binary from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection from .db import RemoteDBConnection
@@ -86,7 +85,7 @@ class RemoteTable(Table):
>>> import lancedb >>> import lancedb
>>> import uuid >>> import uuid
>>> from lancedb.schema import vector >>> from lancedb.schema import vector
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP >>> conn = lancedb.connect("db://...", api_key="...", region="...")
>>> table_name = uuid.uuid4().hex >>> table_name = uuid.uuid4().hex
>>> schema = pa.schema( >>> schema = pa.schema(
... [ ... [
@@ -95,11 +94,11 @@ class RemoteTable(Table):
... pa.field("s", pa.string(), False), ... pa.field("s", pa.string(), False),
... ] ... ]
... ) ... )
>>> table = db.create_table( # doctest: +SKIP >>> table = conn.create_table(
... table_name, # doctest: +SKIP >>> table_name,
... schema=schema, # doctest: +SKIP >>> schema=schema,
... ) >>> )
>>> table.create_index("L2", "vector") # doctest: +SKIP >>> table.create_index("L2", "vector")
""" """
index_type = "vector" index_type = "vector"
@@ -174,22 +173,22 @@ class RemoteTable(Table):
Examples Examples
-------- --------
>>> import lancedb >>> import lancedb
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP >>> db = lancedb.connect("db://...", api_key="...", region="...")
>>> data = [ >>> data = [
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]}, ... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]}, ... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]} ... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
... ] ... ]
>>> table = db.create_table("my_table", data) # doctest: +SKIP >>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4] >>> query = [0.4, 1.4, 2.4]
>>> (table.search(query, vector_column_name="vector") # doctest: +SKIP >>> (table.search(query, vector_column_name="vector")
... .where("original_width > 1000", prefilter=True) # doctest: +SKIP ... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"]) # doctest: +SKIP ... .select(["caption", "original_width"])
... .limit(2) # doctest: +SKIP ... .limit(2)
... .to_pandas()) # doctest: +SKIP ... .to_pandas())
caption original_width vector _distance # doctest: +SKIP caption original_width vector _distance
0 foo 2000 [0.5, 3.4, 1.3] 5.220000 # doctest: +SKIP 0 foo 2000 [0.5, 3.4, 1.3] 5.220000
1 test 3000 [0.3, 6.2, 2.6] 23.089996 # doctest: +SKIP 1 test 3000 [0.3, 6.2, 2.6] 23.089996
Parameters Parameters
---------- ----------
@@ -247,92 +246,32 @@ class RemoteTable(Table):
... {"x": 2, "vector": [3, 4]}, ... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]} ... {"x": 3, "vector": [5, 6]}
... ] ... ]
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP >>> db = lancedb.connect("db://...", api_key="...", region="...")
>>> table = db.create_table("my_table", data) # doctest: +SKIP >>> table = db.create_table("my_table", data)
>>> table.search([10,10]).to_pandas() # doctest: +SKIP >>> table.search([10,10]).to_pandas()
x vector _distance # doctest: +SKIP x vector _distance
0 3 [5.0, 6.0] 41.0 # doctest: +SKIP 0 3 [5.0, 6.0] 41.0
1 2 [3.0, 4.0] 85.0 # doctest: +SKIP 1 2 [3.0, 4.0] 85.0
2 1 [1.0, 2.0] 145.0 # doctest: +SKIP 2 1 [1.0, 2.0] 145.0
>>> table.delete("x = 2") # doctest: +SKIP >>> table.delete("x = 2")
>>> table.search([10,10]).to_pandas() # doctest: +SKIP >>> table.search([10,10]).to_pandas()
x vector _distance # doctest: +SKIP x vector _distance
0 3 [5.0, 6.0] 41.0 # doctest: +SKIP 0 3 [5.0, 6.0] 41.0
1 1 [1.0, 2.0] 145.0 # doctest: +SKIP 1 1 [1.0, 2.0] 145.0
If you have a list of values to delete, you can combine them into a If you have a list of values to delete, you can combine them into a
stringified list and use the `IN` operator: stringified list and use the `IN` operator:
>>> to_remove = [1, 3] # doctest: +SKIP >>> to_remove = [1, 3]
>>> to_remove = ", ".join([str(v) for v in to_remove]) # doctest: +SKIP >>> to_remove = ", ".join([str(v) for v in to_remove])
>>> table.delete(f"x IN ({to_remove})") # doctest: +SKIP >>> to_remove
>>> table.search([10,10]).to_pandas() # doctest: +SKIP '1, 3'
x vector _distance # doctest: +SKIP >>> table.delete(f"x IN ({to_remove})")
0 2 [3.0, 4.0] 85.0 # doctest: +SKIP >>> table.search([10,10]).to_pandas()
x vector _distance
0 2 [3.0, 4.0] 85.0
""" """
payload = {"predicate": predicate} payload = {"predicate": predicate}
self._conn._loop.run_until_complete( self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload) self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
) )
def update(
self,
where: Optional[str] = None,
values: Optional[dict] = None,
*,
values_sql: Optional[Dict[str, str]] = None,
):
"""
This can be used to update zero to all rows depending on how many
rows match the where clause.
Parameters
----------
where: str, optional
The SQL where clause to use when updating rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
values: dict, optional
The values to update. The keys are the column names and the values
are the values to set.
values_sql: dict, optional
The values to update, expressed as SQL expression strings. These can
reference existing columns. For example, {"x": "x + 1"} will increment
the x column by 1.
Examples
--------
>>> import lancedb
>>> data = [
... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("db://...", api_key="...", region="...") # doctest: +SKIP
>>> table = db.create_table("my_table", data) # doctest: +SKIP
>>> table.to_pandas() # doctest: +SKIP
x vector # doctest: +SKIP
0 1 [1.0, 2.0] # doctest: +SKIP
1 2 [3.0, 4.0] # doctest: +SKIP
2 3 [5.0, 6.0] # doctest: +SKIP
>>> table.update(where="x = 2", values={"vector": [10, 10]}) # doctest: +SKIP
>>> table.to_pandas() # doctest: +SKIP
x vector # doctest: +SKIP
0 1 [1.0, 2.0] # doctest: +SKIP
1 3 [5.0, 6.0] # doctest: +SKIP
2 2 [10.0, 10.0] # doctest: +SKIP
"""
if values is not None and values_sql is not None:
raise ValueError("Only one of values or values_sql can be provided")
if values is None and values_sql is None:
raise ValueError("Either values or values_sql must be provided")
if values is not None:
updates = [[k, value_to_sql(v)] for k, v in values.items()]
else:
updates = [[k, v] for k, v in values_sql.items()]
payload = {"predicate": where, "updates": updates}
self._conn._loop.run_until_complete(
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
)

View File

@@ -17,21 +17,20 @@ import inspect
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
import lance import lance
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
import pyarrow.compute as pc import pyarrow.compute as pc
import pyarrow.fs as pa_fs
from lance import LanceDataset from lance import LanceDataset
from lance.vector import vec_to_table from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .pydantic import LanceModel, model_to_dict from .pydantic import LanceModel
from .query import LanceQueryBuilder, Query from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas, value_to_sql, join_uri from .util import fs_from_uri, safe_import_pandas
from .utils.events import register_event from .utils.events import register_event
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -54,10 +53,8 @@ def _sanitize_data(
# convert to list of dict if data is a bunch of LanceModels # convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel): if isinstance(data[0], LanceModel):
schema = data[0].__class__.to_arrow_schema() schema = data[0].__class__.to_arrow_schema()
data = [model_to_dict(d) for d in data] data = [dict(d) for d in data]
data = pa.Table.from_pylist(data, schema=schema) data = pa.Table.from_pylist(data)
else:
data = pa.Table.from_pylist(data)
elif isinstance(data, dict): elif isinstance(data, dict):
data = vec_to_table(data) data = vec_to_table(data)
elif pd is not None and isinstance(data, pd.DataFrame): elif pd is not None and isinstance(data, pd.DataFrame):
@@ -397,6 +394,14 @@ class LanceTable(Table):
self.name = name self.name = name
self._version = version self._version = version
def _reset_dataset(self, version=None):
try:
if "_dataset" in self.__dict__:
del self.__dict__["_dataset"]
self._version = version
except AttributeError:
pass
@property @property
def schema(self) -> pa.Schema: def schema(self) -> pa.Schema:
"""Return the schema of the table. """Return the schema of the table.
@@ -405,16 +410,16 @@ class LanceTable(Table):
------- -------
pa.Schema pa.Schema
A PyArrow schema object.""" A PyArrow schema object."""
return self.to_lance().schema return self._dataset.schema
def list_versions(self): def list_versions(self):
"""List all versions of the table""" """List all versions of the table"""
return self.to_lance().versions() return self._dataset.versions()
@property @property
def version(self) -> int: def version(self) -> int:
"""Get the current version of the table""" """Get the current version of the table"""
return self.to_lance().version return self._dataset.version
def checkout(self, version: int): def checkout(self, version: int):
"""Checkout a version of the table. This is an in-place operation. """Checkout a version of the table. This is an in-place operation.
@@ -447,12 +452,14 @@ class LanceTable(Table):
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
""" """
max_ver = max([v["version"] for v in self.to_lance().versions()]) max_ver = max([v["version"] for v in self._dataset.versions()])
if version < 1 or version > max_ver: if version < 1 or version > max_ver:
raise ValueError(f"Invalid version {version}") raise ValueError(f"Invalid version {version}")
self._reset_dataset(version=version)
try: try:
self.to_lance().checkout(version) # Accessing the property updates the cached value
_ = self._dataset
except Exception as e: except Exception as e:
if "not found" in str(e): if "not found" in str(e):
raise ValueError( raise ValueError(
@@ -495,7 +502,7 @@ class LanceTable(Table):
>>> len(table.list_versions()) >>> len(table.list_versions())
4 4
""" """
max_ver = max([v["version"] for v in self.to_lance().versions()]) max_ver = max([v["version"] for v in self._dataset.versions()])
if version is None: if version is None:
version = self.version version = self.version
elif version < 1 or version > max_ver: elif version < 1 or version > max_ver:
@@ -507,10 +514,11 @@ class LanceTable(Table):
# no-op if restoring the latest version # no-op if restoring the latest version
return return
self.to_lance().restore() self._dataset.restore()
self._reset_dataset()
def __len__(self): def __len__(self):
return self.to_lance().count_rows() return self._dataset.count_rows()
def __repr__(self) -> str: def __repr__(self) -> str:
return f"LanceTable({self.name})" return f"LanceTable({self.name})"
@@ -520,7 +528,7 @@ class LanceTable(Table):
def head(self, n=5) -> pa.Table: def head(self, n=5) -> pa.Table:
"""Return the first n rows of the table.""" """Return the first n rows of the table."""
return self.to_lance().head(n) return self._dataset.head(n)
def to_pandas(self) -> "pd.DataFrame": def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame. """Return the table as a pandas DataFrame.
@@ -537,11 +545,11 @@ class LanceTable(Table):
Returns Returns
------- -------
pa.Table""" pa.Table"""
return self.to_lance().to_table() return self._dataset.to_table()
@property @property
def _dataset_uri(self) -> str: def _dataset_uri(self) -> str:
return join_uri(self._conn.uri, f"{self.name}.lance") return os.path.join(self._conn.uri, f"{self.name}.lance")
def create_index( def create_index(
self, self,
@@ -564,11 +572,10 @@ class LanceTable(Table):
accelerator=accelerator, accelerator=accelerator,
index_cache_size=index_cache_size, index_cache_size=index_cache_size,
) )
self._reset_dataset()
register_event("create_index") register_event("create_index")
def create_fts_index( def create_fts_index(self, field_names: Union[str, List[str]]):
self, field_names: Union[str, List[str]], *, replace: bool = False
):
"""Create a full-text search index on the table. """Create a full-text search index on the table.
Warning - this API is highly experimental and is highly likely to change Warning - this API is highly experimental and is highly likely to change
@@ -578,35 +585,17 @@ class LanceTable(Table):
---------- ----------
field_names: str or list of str field_names: str or list of str
The name(s) of the field to index. The name(s) of the field to index.
replace: bool, default False
If True, replace the existing index if it exists. Note that this is
not yet an atomic operation; the index will be temporarily
unavailable while the new index is being created.
""" """
from .fts import create_index, populate_index from .fts import create_index, populate_index
if isinstance(field_names, str): if isinstance(field_names, str):
field_names = [field_names] field_names = [field_names]
fs, path = fs_from_uri(self._get_fts_index_path())
index_exists = fs.get_file_info(path).type != pa_fs.FileType.NotFound
if index_exists:
if not replace:
raise ValueError(
f"Index already exists. Use replace=True to overwrite."
)
try:
fs.delete_dir(path)
except FileNotFoundError as e:
if "Cannot get information for path" in str(e):
pass
index = create_index(self._get_fts_index_path(), field_names) index = create_index(self._get_fts_index_path(), field_names)
populate_index(index, self, field_names) populate_index(index, self, field_names)
register_event("create_fts_index") register_event("create_fts_index")
def _get_fts_index_path(self): def _get_fts_index_path(self):
return join_uri(self._dataset_uri, "_indices", "tantivy") return os.path.join(self._dataset_uri, "_indices", "tantivy")
@cached_property @cached_property
def _dataset(self) -> LanceDataset: def _dataset(self) -> LanceDataset:
@@ -654,7 +643,8 @@ class LanceTable(Table):
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
) )
self.to_lance().write(data, mode=mode) lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
self._reset_dataset()
register_event("add") register_event("add")
def merge( def merge(
@@ -715,9 +705,10 @@ class LanceTable(Table):
other_table = other_table.to_lance() other_table = other_table.to_lance()
if isinstance(other_table, LanceDataset): if isinstance(other_table, LanceDataset):
other_table = other_table.to_table() other_table = other_table.to_table()
self.to_lance().merge( self._dataset.merge(
other_table, left_on=left_on, right_on=right_on, schema=schema other_table, left_on=left_on, right_on=right_on, schema=schema
) )
self._reset_dataset()
register_event("merge") register_event("merge")
@cached_property @cached_property
@@ -794,7 +785,7 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
register_event("search_table") register_event("search")
return LanceQueryBuilder.create( return LanceQueryBuilder.create(
self, query, query_type, vector_column_name=vector_column_name self, query, query_type, vector_column_name=vector_column_name
) )
@@ -915,42 +906,35 @@ class LanceTable(Table):
f"Table {name} does not exist." f"Table {name} does not exist."
f"Please first call db.create_table({name}, data)" f"Please first call db.create_table({name}, data)"
) )
register_event("open_table")
return tbl return tbl
def delete(self, where: str): def delete(self, where: str):
self.to_lance().delete(where) self._dataset.delete(where)
def update( def update(self, where: str, values: dict):
self,
where: Optional[str] = None,
values: Optional[dict] = None,
*,
values_sql: Optional[Dict[str, str]] = None,
):
""" """
EXPERIMENTAL: Update rows in the table (not threadsafe).
This can be used to update zero to all rows depending on how many This can be used to update zero to all rows depending on how many
rows match the where clause. rows match the where clause.
Parameters Parameters
---------- ----------
where: str, optional where: str
The SQL where clause to use when updating rows. For example, 'x = 2' The SQL where clause to use when updating rows. For example, 'x = 2'
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error. or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
values: dict, optional values: dict
The values to update. The keys are the column names and the values The values to update. The keys are the column names and the values
are the values to set. are the values to set.
values_sql: dict, optional
The values to update, expressed as SQL expression strings. These can
reference existing columns. For example, {"x": "x + 1"} will increment
the x column by 1.
Examples Examples
-------- --------
>>> import lancedb >>> import lancedb
>>> import pandas as pd >>> data = [
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) ... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... ]
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data) >>> table = db.create_table("my_table", data)
>>> table.to_pandas() >>> table.to_pandas()
@@ -966,15 +950,19 @@ class LanceTable(Table):
2 2 [10.0, 10.0] 2 2 [10.0, 10.0]
""" """
if values is not None and values_sql is not None: orig_data = self._dataset.to_table(filter=where).combine_chunks()
raise ValueError("Only one of values or values_sql can be provided") if len(orig_data) == 0:
if values is None and values_sql is None: return
raise ValueError("Either values or values_sql must be provided") for col, val in values.items():
i = orig_data.column_names.index(col)
if values is not None: if i < 0:
values_sql = {k: value_to_sql(v) for k, v in values.items()} raise ValueError(f"Column {col} does not exist")
orig_data = orig_data.set_column(
self.to_lance().update(values_sql, where) i, col, pa.array([val] * len(orig_data), type=orig_data[col].type)
)
self.delete(where)
self.add(orig_data, mode="append")
self._reset_dataset()
register_event("update") register_event("update")
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:

View File

@@ -12,13 +12,9 @@
# limitations under the License. # limitations under the License.
import os import os
from datetime import date, datetime from typing import Tuple
from functools import singledispatch
import pathlib
from typing import Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np
import pyarrow.fs as pa_fs import pyarrow.fs as pa_fs
@@ -63,12 +59,6 @@ def get_uri_location(uri: str) -> str:
str: Location part of the URL, without scheme str: Location part of the URL, without scheme
""" """
parsed = urlparse(uri) parsed = urlparse(uri)
if len(parsed.scheme) == 1:
# Windows drive names are parsed as the scheme
# e.g. "c:\path" -> ParseResult(scheme="c", netloc="", path="/path", ...)
# So we add special handling here for schemes that are a single character
return uri
if not parsed.netloc: if not parsed.netloc:
return parsed.path return parsed.path
else: else:
@@ -91,29 +81,6 @@ def fs_from_uri(uri: str) -> Tuple[pa_fs.FileSystem, str]:
return pa_fs.FileSystem.from_uri(uri) return pa_fs.FileSystem.from_uri(uri)
def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
"""
Join a URI with multiple parts, handles both local and remote paths
Parameters
----------
base : str
The base URI
parts : str
The parts to join to the base URI, each separated by the
appropriate path separator for the URI scheme and OS
"""
if isinstance(base, pathlib.Path):
return base.joinpath(*parts)
base = str(base)
if get_uri_scheme(base) == "file":
# using pathlib for local paths make this windows compatible
# `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`)
return str(pathlib.Path(base, *parts))
# for remote paths, just use os.path.join
return "/".join([p.rstrip("/") for p in [base, *parts]])
def safe_import_pandas(): def safe_import_pandas():
try: try:
import pandas as pd import pandas as pd
@@ -121,53 +88,3 @@ def safe_import_pandas():
return pd return pd
except ImportError: except ImportError:
return None return None
@singledispatch
def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type")
@value_to_sql.register(str)
def _(value: str):
return f"'{value}'"
@value_to_sql.register(int)
def _(value: int):
return str(value)
@value_to_sql.register(float)
def _(value: float):
return str(value)
@value_to_sql.register(bool)
def _(value: bool):
return str(value).upper()
@value_to_sql.register(type(None))
def _(value: type(None)):
return "NULL"
@value_to_sql.register(datetime)
def _(value: datetime):
return f"'{value.isoformat()}'"
@value_to_sql.register(date)
def _(value: date):
return f"'{value.isoformat()}'"
@value_to_sql.register(list)
def _(value: list):
return "[" + ", ".join(map(value_to_sql, value)) + "]"
@value_to_sql.register(np.ndarray)
def _(value: np.ndarray):
return value_to_sql(value.tolist())

View File

@@ -64,10 +64,8 @@ class _Events:
Initializes the Events object with default values for events, rate_limit, and metadata. Initializes the Events object with default values for events, rate_limit, and metadata.
""" """
self.events = [] # events list self.events = [] # events list
self.throttled_event_names = ["search_table"] self.max_events = 25 # max events to store in memory
self.throttled_events = set() self.rate_limit = 60.0 # rate limit (seconds)
self.max_events = 5 # max events to store in memory
self.rate_limit = 60.0 * 5 # rate limit (seconds)
self.time = 0.0 self.time = 0.0
if is_git_dir(): if is_git_dir():
@@ -114,21 +112,18 @@ class _Events:
return return
if ( if (
len(self.events) < self.max_events len(self.events) < self.max_events
): # Events list limited to self.max_events (drop any events past this) ): # Events list limited to 25 events (drop any events past this)
params.update(self.metadata) params.update(self.metadata)
event = { self.events.append(
"event": event_name, {
"properties": params, "event": event_name,
"timestamp": datetime.datetime.now( "properties": params,
tz=datetime.timezone.utc "timestamp": datetime.datetime.now(
).isoformat(), tz=datetime.timezone.utc
"distinct_id": CONFIG["uuid"], ).isoformat(),
} "distinct_id": CONFIG["uuid"],
if event_name not in self.throttled_event_names: }
self.events.append(event) )
elif event_name not in self.throttled_events:
self.throttled_events.add(event_name)
self.events.append(event)
# Check rate limit # Check rate limit
t = time.time() t = time.time()
@@ -140,6 +135,7 @@ class _Events:
"distinct_id": CONFIG["uuid"], # posthog needs this to accepts the event "distinct_id": CONFIG["uuid"], # posthog needs this to accepts the event
"batch": self.events, "batch": self.events,
} }
# POST equivalent to requests.post(self.url, json=data). # POST equivalent to requests.post(self.url, json=data).
# threaded request is used to avoid blocking, retries are disabled, and verbose is disabled # threaded request is used to avoid blocking, retries are disabled, and verbose is disabled
# to avoid any possible disruption in the console. # to avoid any possible disruption in the console.
@@ -154,7 +150,6 @@ class _Events:
# Flush & Reset # Flush & Reset
self.events = [] self.events = []
self.throttled_events = set()
self.time = t self.time = t

View File

@@ -1,12 +1,12 @@
[project] [project]
name = "lancedb" name = "lancedb"
version = "0.4.0" version = "0.3.4"
dependencies = [ dependencies = [
"deprecation", "deprecation",
"pylance==0.9.1", "pylance==0.8.17",
"ratelimiter~=1.0", "ratelimiter~=1.0",
"retry>=0.9.2", "retry>=0.9.2",
"tqdm>=4.27.0", "tqdm>=4.1.0",
"aiohttp", "aiohttp",
"pydantic>=1.10", "pydantic>=1.10",
"attrs>=21.3.0", "attrs>=21.3.0",

View File

@@ -43,15 +43,7 @@ def table(tmp_path) -> ldb.table.LanceTable:
for _ in range(100) for _ in range(100)
] ]
table = db.create_table( table = db.create_table(
"test", "test", data=pd.DataFrame({"vector": vectors, "text": text, "text2": text})
data=pd.DataFrame(
{
"vector": vectors,
"text": text,
"text2": text,
"nested": [{"text": t} for t in text],
}
),
) )
return table return table
@@ -83,24 +75,6 @@ def test_create_index_from_table(tmp_path, table):
assert len(df) == 10 assert len(df) == 10
assert "text" in df.columns assert "text" in df.columns
# Check whether it can be updated
table.add(
[
{
"vector": np.random.randn(128),
"text": "gorilla",
"text2": "gorilla",
"nested": {"text": "gorilla"},
}
]
)
table.create_fts_index("text", replace=True)
assert len(table.search("gorilla").limit(1).to_pandas()) == 1
with pytest.raises(ValueError, match="already exists"):
table.create_fts_index("text")
def test_create_index_multiple_columns(tmp_path, table): def test_create_index_multiple_columns(tmp_path, table):
table.create_fts_index(["text", "text2"]) table.create_fts_index(["text", "text2"])
@@ -115,9 +89,3 @@ def test_empty_rs(tmp_path, table, mocker):
mocker.patch("lancedb.fts.search_index", return_value=([], [])) mocker.patch("lancedb.fts.search_index", return_value=([], []))
df = table.search("puppy").limit(10).to_pandas() df = table.search("puppy").limit(10).to_pandas()
assert len(df) == 0 assert len(df) == 0
def test_nested_schema(tmp_path, table):
table.create_fts_index("nested.text")
rs = table.search("puppy").limit(10).to_list()
assert len(rs) == 10

View File

@@ -12,7 +12,7 @@
# limitations under the License. # limitations under the License.
import functools import functools
from datetime import date, datetime, timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from unittest.mock import PropertyMock, patch from unittest.mock import PropertyMock, patch
@@ -22,7 +22,6 @@ import numpy as np
import pandas as pd import pandas as pd
import pyarrow as pa import pyarrow as pa
import pytest import pytest
from pydantic import BaseModel
from lancedb.conftest import MockTextEmbeddingFunction from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
@@ -142,44 +141,14 @@ def test_add(db):
def test_add_pydantic_model(db): def test_add_pydantic_model(db):
# https://github.com/lancedb/lancedb/issues/562 class TestModel(LanceModel):
vector: Vector(16)
class Metadata(BaseModel):
source: str
timestamp: datetime
class Document(BaseModel):
content: str
meta: Metadata
class LanceSchema(LanceModel):
id: str
vector: Vector(2)
li: List[int] li: List[int]
payload: Document
tbl = LanceTable.create(db, "mytable", schema=LanceSchema, mode="overwrite") data = TestModel(vector=list(range(16)), li=[1, 2, 3])
assert tbl.schema == LanceSchema.to_arrow_schema() table = LanceTable.create(db, "test", data=[data])
assert len(table) == 1
# add works assert table.schema == TestModel.to_arrow_schema()
expected = LanceSchema(
id="id",
vector=[0.0, 0.0],
li=[1, 2, 3],
payload=Document(
content="foo", meta=Metadata(source="bar", timestamp=datetime.now())
),
)
tbl.add([expected])
result = tbl.search([0.0, 0.0]).limit(1).to_pydantic(LanceSchema)[0]
assert result == expected
flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=1)
assert len(flattened.columns) == 6 # _distance is automatically added
really_flattened = tbl.search([0.0, 0.0]).limit(1).to_pandas(flatten=True)
assert len(really_flattened.columns) == 7
def _add(table, schema): def _add(table, schema):
@@ -226,38 +195,39 @@ def test_versioning(db):
def test_create_index_method(): def test_create_index_method():
with patch.object( with patch.object(LanceTable, "_reset_dataset", return_value=None):
LanceTable, "_dataset", new_callable=PropertyMock with patch.object(
) as mock_dataset: LanceTable, "_dataset", new_callable=PropertyMock
# Setup mock responses ) as mock_dataset:
mock_dataset.return_value.create_index.return_value = None # Setup mock responses
mock_dataset.return_value.create_index.return_value = None
# Create a LanceTable object # Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri") connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table") table = LanceTable(connection, "test_table")
# Call the create_index method # Call the create_index method
table.create_index( table.create_index(
metric="L2", metric="L2",
num_partitions=256, num_partitions=256,
num_sub_vectors=96, num_sub_vectors=96,
vector_column_name="vector", vector_column_name="vector",
replace=True, replace=True,
index_cache_size=256, index_cache_size=256,
) )
# Check that the _dataset.create_index method was called # Check that the _dataset.create_index method was called
# with the right parameters # with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with( mock_dataset.return_value.create_index.assert_called_once_with(
column="vector", column="vector",
index_type="IVF_PQ", index_type="IVF_PQ",
metric="L2", metric="L2",
num_partitions=256, num_partitions=256,
num_sub_vectors=96, num_sub_vectors=96,
replace=True, replace=True,
accelerator=None, accelerator=None,
index_cache_size=256, index_cache_size=256,
) )
def test_add_with_nans(db): def test_add_with_nans(db):
@@ -378,79 +348,14 @@ def test_update(db):
assert len(table) == 2 assert len(table) == 2
assert len(table.list_versions()) == 2 assert len(table.list_versions()) == 2
table.update(where="id=0", values={"vector": [1.1, 1.1]}) table.update(where="id=0", values={"vector": [1.1, 1.1]})
assert len(table.list_versions()) == 3 assert len(table.list_versions()) == 4
assert table.version == 3 assert table.version == 4
assert len(table) == 2 assert len(table) == 2
v = table.to_arrow()["vector"].combine_chunks() v = table.to_arrow()["vector"].combine_chunks()
v = v.values.to_numpy().reshape(2, 2) v = v.values.to_numpy().reshape(2, 2)
assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]])) assert np.allclose(v, np.array([[1.2, 1.9], [1.1, 1.1]]))
def test_update_types(db):
table = LanceTable.create(
db,
"my_table",
data=[
{
"id": 0,
"str": "foo",
"float": 1.1,
"timestamp": datetime(2021, 1, 1),
"date": date(2021, 1, 1),
"vector1": [1.0, 0.0],
"vector2": [1.0, 1.0],
}
],
)
# Update with SQL
table.update(
values_sql=dict(
id="1",
str="'bar'",
float="2.2",
timestamp="TIMESTAMP '2021-01-02 00:00:00'",
date="DATE '2021-01-02'",
vector1="[2.0, 2.0]",
vector2="[3.0, 3.0]",
)
)
actual = table.to_arrow().to_pylist()[0]
expected = dict(
id=1,
str="bar",
float=2.2,
timestamp=datetime(2021, 1, 2),
date=date(2021, 1, 2),
vector1=[2.0, 2.0],
vector2=[3.0, 3.0],
)
assert actual == expected
# Update with values
table.update(
values=dict(
id=2,
str="baz",
float=3.3,
timestamp=datetime(2021, 1, 3),
date=date(2021, 1, 3),
vector1=[3.0, 3.0],
vector2=np.array([4.0, 4.0]),
)
)
actual = table.to_arrow().to_pylist()[0]
expected = dict(
id=2,
str="baz",
float=3.3,
timestamp=datetime(2021, 1, 3),
date=date(2021, 1, 3),
vector1=[3.0, 3.0],
vector2=[4.0, 4.0],
)
assert actual == expected
def test_create_with_embedding_function(db): def test_create_with_embedding_function(db):
class MyTable(LanceModel): class MyTable(LanceModel):
text: str text: str

View File

@@ -11,12 +11,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os from lancedb.util import get_uri_scheme
import pathlib
import pytest
from lancedb.util import get_uri_scheme, join_uri
def test_normalize_uri(): def test_normalize_uri():
@@ -33,55 +28,3 @@ def test_normalize_uri():
for uri, expected_scheme in zip(uris, schemes): for uri, expected_scheme in zip(uris, schemes):
parsed_scheme = get_uri_scheme(uri) parsed_scheme = get_uri_scheme(uri)
assert parsed_scheme == expected_scheme assert parsed_scheme == expected_scheme
def test_join_uri_remote():
schemes = ["s3", "az", "gs"]
for scheme in schemes:
expected = f"{scheme}://bucket/path/to/table.lance"
base_uri = f"{scheme}://bucket/path/to/"
parts = ["table.lance"]
assert join_uri(base_uri, *parts) == expected
base_uri = f"{scheme}://bucket"
parts = ["path", "to", "table.lance"]
assert join_uri(base_uri, *parts) == expected
# skip this test if on windows
@pytest.mark.skipif(os.name == "nt", reason="Windows paths are not POSIX")
def test_join_uri_posix():
for base in [
# relative path
"relative/path",
"relative/path/",
# an absolute path
"/absolute/path",
"/absolute/path/",
# a file URI
"file:///absolute/path",
"file:///absolute/path/",
]:
joined = join_uri(base, "table.lance")
assert joined == str(pathlib.Path(base) / "table.lance")
joined = join_uri(pathlib.Path(base), "table.lance")
assert joined == pathlib.Path(base) / "table.lance"
# skip this test if not on windows
@pytest.mark.skipif(os.name != "nt", reason="Windows paths are not POSIX")
def test_local_join_uri_windows():
# https://learn.microsoft.com/en-us/dotnet/standard/io/file-path-formats
for base in [
# windows relative path
"relative\\path",
"relative\\path\\",
# windows absolute path from current drive
"c:\\absolute\\path",
# relative path from root of current drive
"\\relative\\path",
]:
joined = join_uri(base, "table.lance")
assert joined == str(pathlib.Path(base) / "table.lance")
joined = join_uri(pathlib.Path(base), "table.lance")
assert joined == pathlib.Path(base) / "table.lance"

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "vectordb-node" name = "vectordb-node"
version = "0.4.0" version = "0.3.9"
description = "Serverless, low-latency vector database for AI applications" description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0" license = "Apache-2.0"
edition = "2018" edition = "2018"

View File

@@ -23,7 +23,7 @@ pub enum Error {
#[snafu(display("column '{name}' is missing"))] #[snafu(display("column '{name}' is missing"))]
MissingColumn { name: String }, MissingColumn { name: String },
#[snafu(display("{name}: {message}"))] #[snafu(display("{name}: {message}"))]
OutOfRange { name: String, message: String }, RangeError { name: String, message: String },
#[snafu(display("{index_type} is not a valid index type"))] #[snafu(display("{index_type} is not a valid index type"))]
InvalidIndexType { index_type: String }, InvalidIndexType { index_type: String },

View File

@@ -65,10 +65,12 @@ fn get_index_params_builder(
obj.get_opt::<JsString, _, _>(cx, "index_name")? obj.get_opt::<JsString, _, _>(cx, "index_name")?
.map(|s| index_builder.index_name(s.value(cx))); .map(|s| index_builder.index_name(s.value(cx)));
if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? { obj.get_opt::<JsString, _, _>(cx, "metric_type")?
let metric_type = MetricType::try_from(metric_type.value(cx).as_str()).unwrap(); .map(|s| MetricType::try_from(s.value(cx).as_str()))
index_builder.metric_type(metric_type); .map(|mt| {
} let metric_type = mt.unwrap();
index_builder.metric_type(metric_type);
});
let num_partitions = obj.get_opt_usize(cx, "num_partitions")?; let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;
let max_iters = obj.get_opt_usize(cx, "max_iters")?; let max_iters = obj.get_opt_usize(cx, "max_iters")?;
@@ -83,29 +85,23 @@ fn get_index_params_builder(
index_builder.ivf_params(ivf_params) index_builder.ivf_params(ivf_params)
}); });
if let Some(use_opq) = obj.get_opt::<JsBoolean, _, _>(cx, "use_opq")? { obj.get_opt::<JsBoolean, _, _>(cx, "use_opq")?
pq_params.use_opq = use_opq.value(cx); .map(|s| pq_params.use_opq = s.value(cx));
}
if let Some(num_sub_vectors) = obj.get_opt_usize(cx, "num_sub_vectors")? { obj.get_opt_usize(cx, "num_sub_vectors")?
pq_params.num_sub_vectors = num_sub_vectors; .map(|s| pq_params.num_sub_vectors = s);
}
if let Some(num_bits) = obj.get_opt_usize(cx, "num_bits")? { obj.get_opt_usize(cx, "num_bits")?
pq_params.num_bits = num_bits; .map(|s| pq_params.num_bits = s);
}
if let Some(max_iters) = obj.get_opt_usize(cx, "max_iters")? { obj.get_opt_usize(cx, "max_iters")?
pq_params.max_iters = max_iters; .map(|s| pq_params.max_iters = s);
}
if let Some(max_opq_iters) = obj.get_opt_usize(cx, "max_opq_iters")? { obj.get_opt_usize(cx, "max_opq_iters")?
pq_params.max_opq_iters = max_opq_iters; .map(|s| pq_params.max_opq_iters = s);
}
if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")? { obj.get_opt::<JsBoolean, _, _>(cx, "replace")?
index_builder.replace(replace.value(cx)); .map(|s| index_builder.replace(s.value(cx)));
}
Ok(index_builder) Ok(index_builder)
} }

View File

@@ -237,7 +237,6 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("tableAdd", JsTable::js_add)?; cx.export_function("tableAdd", JsTable::js_add)?;
cx.export_function("tableCountRows", JsTable::js_count_rows)?; cx.export_function("tableCountRows", JsTable::js_count_rows)?;
cx.export_function("tableDelete", JsTable::js_delete)?; cx.export_function("tableDelete", JsTable::js_delete)?;
cx.export_function("tableUpdate", JsTable::js_update)?;
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?; cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
cx.export_function("tableCompactFiles", JsTable::js_compact)?; cx.export_function("tableCompactFiles", JsTable::js_compact)?;
cx.export_function("tableListIndices", JsTable::js_list_indices)?; cx.export_function("tableListIndices", JsTable::js_list_indices)?;

View File

@@ -47,15 +47,15 @@ fn f64_to_u32_safe(n: f64, key: &str) -> Result<u32> {
use conv::*; use conv::*;
n.approx_as::<u32>().map_err(|e| match e { n.approx_as::<u32>().map_err(|e| match e {
FloatError::NegOverflow(_) => Error::OutOfRange { FloatError::NegOverflow(_) => Error::RangeError {
name: key.into(), name: key.into(),
message: "must be > 0".to_string(), message: "must be > 0".to_string(),
}, },
FloatError::PosOverflow(_) => Error::OutOfRange { FloatError::PosOverflow(_) => Error::RangeError {
name: key.into(), name: key.into(),
message: format!("must be < {}", u32::MAX), message: format!("must be < {}", u32::MAX),
}, },
FloatError::NotANumber(_) => Error::OutOfRange { FloatError::NotANumber(_) => Error::RangeError {
name: key.into(), name: key.into(),
message: "not a valid number".to_string(), message: "not a valid number".to_string(),
}, },
@@ -66,15 +66,15 @@ fn f64_to_usize_safe(n: f64, key: &str) -> Result<usize> {
use conv::*; use conv::*;
n.approx_as::<usize>().map_err(|e| match e { n.approx_as::<usize>().map_err(|e| match e {
FloatError::NegOverflow(_) => Error::OutOfRange { FloatError::NegOverflow(_) => Error::RangeError {
name: key.into(), name: key.into(),
message: "must be > 0".to_string(), message: "must be > 0".to_string(),
}, },
FloatError::PosOverflow(_) => Error::OutOfRange { FloatError::PosOverflow(_) => Error::RangeError {
name: key.into(), name: key.into(),
message: format!("must be < {}", usize::MAX), message: format!("must be < {}", usize::MAX),
}, },
FloatError::NotANumber(_) => Error::OutOfRange { FloatError::NotANumber(_) => Error::RangeError {
name: key.into(), name: key.into(),
message: "not a valid number".to_string(), message: "not a valid number".to_string(),
}, },

View File

@@ -23,14 +23,8 @@ impl JsQuery {
let query_obj = cx.argument::<JsObject>(0)?; let query_obj = cx.argument::<JsObject>(0)?;
let limit = query_obj let limit = query_obj
.get_opt::<JsNumber, _, _>(&mut cx, "_limit")? .get::<JsNumber, _, _>(&mut cx, "_limit")?
.map(|value| { .value(&mut cx);
let limit = value.value(&mut cx);
if limit <= 0.0 {
panic!("Limit must be a positive integer");
}
limit as u64
});
let select = query_obj let select = query_obj
.get_opt::<JsArray, _, _>(&mut cx, "_select")? .get_opt::<JsArray, _, _>(&mut cx, "_select")?
.map(|arr| { .map(|arr| {
@@ -54,9 +48,7 @@ impl JsQuery {
.map(|s| s.value(&mut cx)) .map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap()); .map(|s| MetricType::try_from(s.as_str()).unwrap());
let prefilter = query_obj let prefilter = query_obj.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?.value(&mut cx);
.get::<JsBoolean, _, _>(&mut cx, "_prefilter")?
.value(&mut cx);
let is_electron = cx let is_electron = cx
.argument::<JsBoolean>(1) .argument::<JsBoolean>(1)
@@ -67,23 +59,20 @@ impl JsQuery {
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let channel = cx.channel(); let channel = cx.channel();
let query_vector = query_obj.get_opt::<JsArray, _, _>(&mut cx, "_queryVector")?; let query_vector = query_obj.get::<JsArray, _, _>(&mut cx, "_queryVector")?;
let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);
let table = js_table.table.clone(); let table = js_table.table.clone();
let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx));
rt.spawn(async move { rt.spawn(async move {
let mut builder = table let builder = table
.search(query.map(Float32Array::from)) .search(Float32Array::from(query))
.limit(limit as usize)
.refine_factor(refine_factor) .refine_factor(refine_factor)
.nprobes(nprobes) .nprobes(nprobes)
.filter(filter) .filter(filter)
.metric_type(metric_type) .metric_type(metric_type)
.select(select) .select(select)
.prefilter(prefilter); .prefilter(prefilter);
if let Some(limit) = limit {
builder = builder.limit(limit as usize);
};
let record_batch_stream = builder.execute(); let record_batch_stream = builder.execute();
let results = record_batch_stream let results = record_batch_stream
.and_then(|stream| { .and_then(|stream| {

View File

@@ -45,7 +45,7 @@ impl JsTable {
let table_name = cx.argument::<JsString>(0)?.value(&mut cx); let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let buffer = cx.argument::<JsBuffer>(1)?; let buffer = cx.argument::<JsBuffer>(1)?;
let (batches, schema) = let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?; arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
// Write mode // Write mode
let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() { let mode = match cx.argument::<JsString>(2)?.value(&mut cx).as_str() {
@@ -93,7 +93,7 @@ impl JsTable {
let buffer = cx.argument::<JsBuffer>(0)?; let buffer = cx.argument::<JsBuffer>(0)?;
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx); let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
let (batches, schema) = let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?; arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let channel = cx.channel(); let channel = cx.channel();
let mut table = js_table.table.clone(); let mut table = js_table.table.clone();
@@ -165,69 +165,6 @@ impl JsTable {
Ok(promise) Ok(promise)
} }
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let mut table = js_table.table.clone();
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
// create a vector of updates from the passed map
let updates_arg = cx.argument::<JsObject>(1)?;
let properties = updates_arg.get_own_property_names(&mut cx)?;
let mut updates: Vec<(String, String)> =
Vec::with_capacity(properties.len(&mut cx) as usize);
let len_properties = properties.len(&mut cx);
for i in 0..len_properties {
let property = properties
.get_value(&mut cx, i)?
.downcast_or_throw::<JsString, _>(&mut cx)?;
let value = updates_arg
.get_value(&mut cx, property)?
.downcast_or_throw::<JsString, _>(&mut cx)?;
let property = property.value(&mut cx);
let value = value.value(&mut cx);
updates.push((property, value));
}
// get the filter/predicate if the user passed one
let predicate = cx.argument_opt(0);
let predicate = predicate.unwrap().downcast::<JsString, _>(&mut cx);
let predicate = match predicate {
Ok(_) => {
let val = predicate.map(|s| s.value(&mut cx)).unwrap();
Some(val)
}
Err(_) => {
// if the predicate is not string, check it's null otherwise an invalid
// type was passed
cx.argument::<JsNull>(0)?;
None
}
};
rt.spawn(async move {
let updates_arg = updates
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect::<Vec<_>>();
let predicate = predicate.as_deref();
let update_result = table.update(predicate, updates_arg).await;
deferred.settle_with(&channel, move |mut cx| {
update_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
})
});
Ok(promise)
}
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "vectordb" name = "vectordb"
version = "0.4.0" version = "0.3.9"
edition = "2021" edition = "2021"
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license = "Apache-2.0" license = "Apache-2.0"

View File

@@ -26,7 +26,7 @@ use futures::{stream::BoxStream, FutureExt, StreamExt};
use lance::io::object_store::WrappingObjectStore; use lance::io::object_store::WrappingObjectStore;
use object_store::{ use object_store::{
path::Path, Error, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, path::Path, Error, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore,
PutOptions, PutResult, Result, Result,
}; };
use async_trait::async_trait; use async_trait::async_trait;
@@ -72,28 +72,13 @@ impl PrimaryOnly for Path {
/// Note: this object store does not mirror writes to *.manifest files /// Note: this object store does not mirror writes to *.manifest files
#[async_trait] #[async_trait]
impl ObjectStore for MirroringObjectStore { impl ObjectStore for MirroringObjectStore {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<PutResult> { async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
if location.primary_only() { if location.primary_only() {
self.primary.put(location, bytes).await self.primary.put(location, bytes).await
} else { } else {
self.secondary.put(location, bytes.clone()).await?; self.secondary.put(location, bytes.clone()).await?;
self.primary.put(location, bytes).await self.primary.put(location, bytes).await?;
} Ok(())
}
async fn put_opts(
&self,
location: &Path,
bytes: Bytes,
options: PutOptions,
) -> Result<PutResult> {
if location.primary_only() {
self.primary.put_opts(location, bytes, options).await
} else {
self.secondary
.put_opts(location, bytes.clone(), options.clone())
.await?;
self.primary.put_opts(location, bytes, options).await
} }
} }
@@ -144,8 +129,8 @@ impl ObjectStore for MirroringObjectStore {
self.primary.delete(location).await self.primary.delete(location).await
} }
fn list(&self, prefix: Option<&Path>) -> BoxStream<'_, Result<ObjectMeta>> { async fn list(&self, prefix: Option<&Path>) -> Result<BoxStream<'_, Result<ObjectMeta>>> {
self.primary.list(prefix) self.primary.list(prefix).await
} }
async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> { async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
@@ -374,9 +359,7 @@ mod test {
assert_eq!(t.count_rows().await.unwrap(), 100); assert_eq!(t.count_rows().await.unwrap(), 100);
let q = t let q = t
.search(Some(PrimitiveArray::from_iter_values(vec![ .search(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1]))
0.1, 0.1, 0.1, 0.1,
])))
.limit(10) .limit(10)
.execute() .execute()
.await .await

View File

@@ -24,9 +24,8 @@ use crate::error::Result;
/// A builder for nearest neighbor queries for LanceDB. /// A builder for nearest neighbor queries for LanceDB.
pub struct Query { pub struct Query {
pub dataset: Arc<Dataset>, pub dataset: Arc<Dataset>,
pub query_vector: Option<Float32Array>, pub query_vector: Float32Array,
pub column: String, pub limit: usize,
pub limit: Option<usize>,
pub filter: Option<String>, pub filter: Option<String>,
pub select: Option<Vec<String>>, pub select: Option<Vec<String>>,
pub nprobes: usize, pub nprobes: usize,
@@ -47,12 +46,11 @@ impl Query {
/// # Returns /// # Returns
/// ///
/// * A [Query] object. /// * A [Query] object.
pub(crate) fn new(dataset: Arc<Dataset>, vector: Option<Float32Array>) -> Self { pub(crate) fn new(dataset: Arc<Dataset>, vector: Float32Array) -> Self {
Query { Query {
dataset, dataset,
query_vector: vector, query_vector: vector,
column: crate::table::VECTOR_COLUMN_NAME.to_string(), limit: 10,
limit: None,
nprobes: 20, nprobes: 20,
refine_factor: None, refine_factor: None,
metric_type: None, metric_type: None,
@@ -71,13 +69,11 @@ impl Query {
pub async fn execute(&self) -> Result<DatasetRecordBatchStream> { pub async fn execute(&self) -> Result<DatasetRecordBatchStream> {
let mut scanner: Scanner = self.dataset.scan(); let mut scanner: Scanner = self.dataset.scan();
if let Some(query) = self.query_vector.as_ref() { scanner.nearest(
// If there is a vector query, default to limit=10 if unspecified crate::table::VECTOR_COLUMN_NAME,
scanner.nearest(&self.column, query, self.limit.unwrap_or(10))?; &self.query_vector,
} else { self.limit,
// If there is no vector query, it's ok to not have a limit )?;
scanner.limit(self.limit.map(|limit| limit as i64), None)?;
}
scanner.nprobs(self.nprobes); scanner.nprobs(self.nprobes);
scanner.use_index(self.use_index); scanner.use_index(self.use_index);
scanner.prefilter(self.prefilter); scanner.prefilter(self.prefilter);
@@ -89,23 +85,13 @@ impl Query {
Ok(scanner.try_into_stream().await?) Ok(scanner.try_into_stream().await?)
} }
/// Set the column to query
///
/// # Arguments
///
/// * `column` - The column name
pub fn column(mut self, column: &str) -> Query {
self.column = column.into();
self
}
/// Set the maximum number of results to return. /// Set the maximum number of results to return.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `limit` - The maximum number of results to return. /// * `limit` - The maximum number of results to return.
pub fn limit(mut self, limit: usize) -> Query { pub fn limit(mut self, limit: usize) -> Query {
self.limit = Some(limit); self.limit = limit;
self self
} }
@@ -115,7 +101,7 @@ impl Query {
/// ///
/// * `vector` - The vector that will be used for search. /// * `vector` - The vector that will be used for search.
pub fn query_vector(mut self, query_vector: Float32Array) -> Query { pub fn query_vector(mut self, query_vector: Float32Array) -> Query {
self.query_vector = Some(query_vector); self.query_vector = query_vector;
self self
} }
@@ -188,10 +174,7 @@ mod tests {
use std::sync::Arc; use std::sync::Arc;
use super::*; use super::*;
use arrow_array::{ use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
RecordBatchReader,
};
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use futures::StreamExt; use futures::StreamExt;
use lance::dataset::Dataset; use lance::dataset::Dataset;
@@ -204,7 +187,7 @@ mod tests {
let batches = make_test_batches(); let batches = make_test_batches();
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
let vector = Some(Float32Array::from_iter_values([0.1, 0.2])); let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = Query::new(Arc::new(ds), vector.clone()); let query = Query::new(Arc::new(ds), vector.clone());
assert_eq!(query.query_vector, vector); assert_eq!(query.query_vector, vector);
@@ -218,8 +201,8 @@ mod tests {
.metric_type(Some(MetricType::Cosine)) .metric_type(Some(MetricType::Cosine))
.refine_factor(Some(999)); .refine_factor(Some(999));
assert_eq!(query.query_vector.unwrap(), new_vector); assert_eq!(query.query_vector, new_vector);
assert_eq!(query.limit.unwrap(), 100); assert_eq!(query.limit, 100);
assert_eq!(query.nprobes, 1000); assert_eq!(query.nprobes, 1000);
assert_eq!(query.use_index, true); assert_eq!(query.use_index, true);
assert_eq!(query.metric_type, Some(MetricType::Cosine)); assert_eq!(query.metric_type, Some(MetricType::Cosine));
@@ -231,7 +214,7 @@ mod tests {
let batches = make_non_empty_batches(); let batches = make_non_empty_batches();
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap()); let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
let vector = Some(Float32Array::from_iter_values([0.1; 4])); let vector = Float32Array::from_iter_values([0.1; 4]);
let query = Query::new(ds.clone(), vector.clone()); let query = Query::new(ds.clone(), vector.clone());
let result = query let result = query
@@ -261,27 +244,6 @@ mod tests {
} }
} }
#[tokio::test]
async fn test_execute_no_vector() {
// test that it's ok to not specify a query vector (just filter / limit)
let batches = make_non_empty_batches();
let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap());
let query = Query::new(ds.clone(), None);
let result = query
.filter(Some("id % 2 == 0".to_string()))
.execute()
.await;
let mut stream = result.expect("should have result");
// should only have one batch
while let Some(batch) = stream.next().await {
let b = batch.expect("should be Ok");
// cast arr into Int32Array
let arr: &Int32Array = b["id"].as_primitive();
assert!(arr.iter().all(|x| x.unwrap() % 2 == 0));
}
}
fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static { fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {
let vec = Box::new(RandomVector::new().named("vector".to_string())); let vec = Box::new(RandomVector::new().named("vector".to_string()));
let id = Box::new(IncrementingInt32::new().named("id".to_string())); let id = Box::new(IncrementingInt32::new().named("id".to_string()));

View File

@@ -23,7 +23,7 @@ use lance::dataset::cleanup::RemovalStats;
use lance::dataset::optimize::{ use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions, compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
}; };
use lance::dataset::{Dataset, UpdateBuilder, WriteParams}; use lance::dataset::{Dataset, WriteParams};
use lance::index::DatasetIndexExt; use lance::index::DatasetIndexExt;
use lance::io::object_store::WrappingObjectStore; use lance::io::object_store::WrappingObjectStore;
use std::path::Path; use std::path::Path;
@@ -308,14 +308,10 @@ impl Table {
/// # Returns /// # Returns
/// ///
/// * A [Query] object. /// * A [Query] object.
pub fn search(&self, query_vector: Option<Float32Array>) -> Query { pub fn search(&self, query_vector: Float32Array) -> Query {
Query::new(self.dataset.clone(), query_vector) Query::new(self.dataset.clone(), query_vector)
} }
pub fn filter(&self, expr: String) -> Query {
Query::new(self.dataset.clone(), None).filter(Some(expr))
}
/// Returns the number of rows in this Table /// Returns the number of rows in this Table
pub async fn count_rows(&self) -> Result<usize> { pub async fn count_rows(&self) -> Result<usize> {
Ok(self.dataset.count_rows().await?) Ok(self.dataset.count_rows().await?)
@@ -342,27 +338,6 @@ impl Table {
Ok(()) Ok(())
} }
pub async fn update(
&mut self,
predicate: Option<&str>,
updates: Vec<(&str, &str)>,
) -> Result<()> {
let mut builder = UpdateBuilder::new(self.dataset.clone());
if let Some(predicate) = predicate {
builder = builder.update_where(predicate)?;
}
for (column, value) in updates {
builder = builder.set(column, value)?;
}
let operation = builder.build()?;
let new_ds = operation.execute().await?;
self.dataset = new_ds;
Ok(())
}
/// Remove old versions of the dataset from disk. /// Remove old versions of the dataset from disk.
/// ///
/// # Arguments /// # Arguments
@@ -438,14 +413,11 @@ mod tests {
use std::sync::Arc; use std::sync::Arc;
use arrow_array::{ use arrow_array::{
Array, BooleanArray, Date32Array, FixedSizeListArray, Float32Array, Float64Array, Array, FixedSizeListArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
Int32Array, Int64Array, LargeStringArray, RecordBatch, RecordBatchIterator, RecordBatchReader,
RecordBatchReader, StringArray, TimestampMillisecondArray, TimestampNanosecondArray,
UInt32Array,
}; };
use arrow_data::ArrayDataBuilder; use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field, Schema, TimeUnit}; use arrow_schema::{DataType, Field, Schema};
use futures::TryStreamExt;
use lance::dataset::{Dataset, WriteMode}; use lance::dataset::{Dataset, WriteMode};
use lance::index::vector::pq::PQBuildParams; use lance::index::vector::pq::PQBuildParams;
use lance::io::object_store::{ObjectStoreParams, WrappingObjectStore}; use lance::io::object_store::{ObjectStoreParams, WrappingObjectStore};
@@ -568,272 +540,6 @@ mod tests {
assert_eq!(table.name, "test"); assert_eq!(table.name, "test");
} }
#[tokio::test]
async fn test_update_with_predicate() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
Dataset::write(record_batch_iter, uri, None).await.unwrap();
let mut table = Table::open(uri).await.unwrap();
table
.update(Some("id > 5"), vec![("name", "'foo'")])
.await
.unwrap();
let ds_after = Dataset::open(uri).await.unwrap();
let mut batches = ds_after
.scan()
.project(&["id", "name"])
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
while let Some(batch) = batches.pop() {
let ids = batch
.column(0)
.as_any()
.downcast_ref::<Int32Array>()
.unwrap()
.iter()
.collect::<Vec<_>>();
let names = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for (i, name) in names.iter().enumerate() {
let id = ids[i].unwrap();
let name = name.unwrap();
if id > 5 {
assert_eq!(name, "foo");
} else {
assert_eq!(name, &format!("{}", (b'a' + id as u8) as char));
}
}
}
}
#[tokio::test]
async fn test_update_all_types() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("int32", DataType::Int32, false),
Field::new("int64", DataType::Int64, false),
Field::new("uint32", DataType::UInt32, false),
Field::new("string", DataType::Utf8, false),
Field::new("large_string", DataType::LargeUtf8, false),
Field::new("float32", DataType::Float32, false),
Field::new("float64", DataType::Float64, false),
Field::new("bool", DataType::Boolean, false),
Field::new("date32", DataType::Date32, false),
Field::new(
"timestamp_ns",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
Field::new(
"timestamp_ms",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"vec_f32",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 2),
false,
),
Field::new(
"vec_f64",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float64, true)), 2),
false,
),
]));
let record_batch_iter = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(Int64Array::from_iter_values(0..10)),
Arc::new(UInt32Array::from_iter_values(0..10)),
Arc::new(StringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(LargeStringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])),
Arc::new(Float32Array::from_iter_values(
(0..10).into_iter().map(|i| i as f32),
)),
Arc::new(Float64Array::from_iter_values(
(0..10).into_iter().map(|i| i as f64),
)),
Arc::new(Into::<BooleanArray>::into(vec![
true, false, true, false, true, false, true, false, true, false,
])),
Arc::new(Date32Array::from_iter_values(0..10)),
Arc::new(TimestampNanosecondArray::from_iter_values(0..10)),
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
Arc::new(
create_fixed_size_list(
Float32Array::from_iter_values((0..20).into_iter().map(|i| i as f32)),
2,
)
.unwrap(),
),
Arc::new(
create_fixed_size_list(
Float64Array::from_iter_values((0..20).into_iter().map(|i| i as f64)),
2,
)
.unwrap(),
),
],
)
.unwrap()]
.into_iter()
.map(Ok),
schema.clone(),
);
Dataset::write(record_batch_iter, uri, None).await.unwrap();
let mut table = Table::open(uri).await.unwrap();
// check it can do update for each type
let updates: Vec<(&str, &str)> = vec![
("string", "'foo'"),
("large_string", "'large_foo'"),
("int32", "1"),
("int64", "1"),
("uint32", "1"),
("float32", "1.0"),
("float64", "1.0"),
("bool", "true"),
("date32", "1"),
("timestamp_ns", "1"),
("timestamp_ms", "1"),
("vec_f32", "[1.0, 1.0]"),
("vec_f64", "[1.0, 1.0]"),
];
// for (column, value) in test_cases {
table.update(None, updates).await.unwrap();
let ds_after = Dataset::open(uri).await.unwrap();
let mut batches = ds_after
.scan()
.project(&[
"string",
"large_string",
"int32",
"int64",
"uint32",
"float32",
"float64",
"bool",
"date32",
"timestamp_ns",
"timestamp_ms",
"vec_f32",
"vec_f64",
])
.unwrap()
.try_into_stream()
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
let batch = batches.pop().unwrap();
macro_rules! assert_column {
($column:expr, $array_type:ty, $expected:expr) => {
let array = $column
.as_any()
.downcast_ref::<$array_type>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
assert_eq!(v, Some($expected));
}
};
}
assert_column!(batch.column(0), StringArray, "foo");
assert_column!(batch.column(1), LargeStringArray, "large_foo");
assert_column!(batch.column(2), Int32Array, 1);
assert_column!(batch.column(3), Int64Array, 1);
assert_column!(batch.column(4), UInt32Array, 1);
assert_column!(batch.column(5), Float32Array, 1.0);
assert_column!(batch.column(6), Float64Array, 1.0);
assert_column!(batch.column(7), BooleanArray, true);
assert_column!(batch.column(8), Date32Array, 1);
assert_column!(batch.column(9), TimestampNanosecondArray, 1);
assert_column!(batch.column(10), TimestampMillisecondArray, 1);
let array = batch
.column(11)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f32array = v.as_any().downcast_ref::<Float32Array>().unwrap();
for v in f32array {
assert_eq!(v, Some(1.0));
}
}
let array = batch
.column(12)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.unwrap()
.iter()
.collect::<Vec<_>>();
for v in array {
let v = v.unwrap();
let f64array = v.as_any().downcast_ref::<Float64Array>().unwrap();
for v in f64array {
assert_eq!(v, Some(1.0));
}
}
}
#[tokio::test] #[tokio::test]
async fn test_search() { async fn test_search() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
@@ -848,8 +554,8 @@ mod tests {
let table = Table::open(uri).await.unwrap(); let table = Table::open(uri).await.unwrap();
let vector = Float32Array::from_iter_values([0.1, 0.2]); let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = table.search(Some(vector.clone())); let query = table.search(vector.clone());
assert_eq!(vector, query.query_vector.unwrap()); assert_eq!(vector, query.query_vector);
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]