Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 05:19:58 +00:00)

Compare commits: python-v0. ... python-v0. (11 commits)
Commits:

- 82936c77ef
- dddcddcaf9
- a9727eb318
- 48d55bf952
- d2e71c8b08
- f53aace89c
- d982ee934a
- 57605a2d86
- 738511c5f2
- 0b0f42537e
- e412194008
.cargo/config.toml (new file, 34 lines)

@@ -0,0 +1,34 @@
+[profile.release]
+lto = "fat"
+codegen-units = 1
+
+[profile.release-with-debug]
+inherits = "release"
+debug = true
+# Prioritize compile time over runtime performance
+codegen-units = 16
+lto = "thin"
+
+[target.'cfg(all())']
+rustflags = [
+    "-Wclippy::all",
+    "-Wclippy::style",
+    "-Wclippy::fallible_impl_from",
+    "-Wclippy::manual_let_else",
+    "-Wclippy::redundant_pub_crate",
+    "-Wclippy::string_add_assign",
+    "-Wclippy::string_add",
+    "-Wclippy::string_lit_as_bytes",
+    "-Wclippy::string_to_string",
+    "-Wclippy::use_self",
+    "-Dclippy::cargo",
+    "-Dclippy::dbg_macro",
+    # not too much we can do to avoid multiple crate versions
+    "-Aclippy::multiple-crate-versions",
+]
+
+[target.x86_64-unknown-linux-gnu]
+rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"]
+
+[target.aarch64-apple-darwin]
+rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]
.github/workflows/docs_test.yml (vendored, +9 lines)

@@ -49,6 +49,9 @@ jobs:
   test-node:
     name: Test doc nodejs code
     runs-on: "ubuntu-latest"
+    timeout-minutes: 45
+    strategy:
+      fail-fast: false
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -66,6 +69,12 @@ jobs:
         uses: swatinem/rust-cache@v2
       - name: Install node dependencies
         run: |
+          sudo swapoff -a
+          sudo fallocate -l 8G /swapfile
+          sudo chmod 600 /swapfile
+          sudo mkswap /swapfile
+          sudo swapon /swapfile
+          sudo swapon --show
           cd node
           npm ci
           npm run build-release
Cargo.toml (13 changed lines)

@@ -6,15 +6,18 @@ resolver = "2"

 [workspace.package]
 edition = "2021"
-authors = ["Lance Devs <dev@lancedb.com>"]
+authors = ["LanceDB Devs <dev@lancedb.com>"]
 license = "Apache-2.0"
 repository = "https://github.com/lancedb/lancedb"
+description = "Serverless, low-latency vector database for AI applications"
+keywords = ["lancedb", "lance", "database", "vector", "search"]
+categories = ["database-implementations"]

 [workspace.dependencies]
-lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.9.12" }
-lance-linalg = { "version" = "=0.9.12" }
-lance-testing = { "version" = "=0.9.12" }
+lance = { "version" = "=0.9.15", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.9.15" }
+lance-linalg = { "version" = "=0.9.15" }
+lance-testing = { "version" = "=0.9.15" }
 # Note that this one does not include pyarrow
 arrow = { version = "50.0", optional = false }
 arrow-array = "50.0"
@@ -69,3 +69,19 @@ MinIO supports an S3 compatible API. In order to connect to a MinIO instance, yo
 - Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
 - Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential
 - Call `lancedb.connect("s3://minio_bucket_name")`
+
+### Where can I find benchmarks for LanceDB?
+
+Refer to this [post](https://blog.lancedb.com/benchmarking-lancedb-92b01032874a) for recent benchmarks.
+
+### How much data can LanceDB practically manage without affecting performance?
+
+We target good performance on ~10-50 billion rows and ~10-30 TB of data.
+
+### Does LanceDB support concurrent operations?
+
+LanceDB handles concurrent reads very well and can scale horizontally. The main constraint is how well the [storage layer](https://lancedb.github.io/lancedb/concepts/storage/) you've chosen scales. Concurrent writes are also supported, but too many concurrent writers can lead to failed writes, since a writer retries a commit only a limited number of times.
+
+!!! info "Multiprocessing with LanceDB"
+
+    For multiprocessing you should probably avoid `fork`: lance is multi-threaded internally, and `fork` and multi-threading do not work well together. [Refer to this discussion](https://discuss.python.org/t/concerns-regarding-deprecation-of-fork-with-alive-threads/33555)
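To make the multiprocessing note above concrete, here is a minimal, hedged sketch using Python's standard `spawn` start method instead of `fork`; the table name, query vectors, and worker count are illustrative placeholders, not part of this change set.

```python
import multiprocessing as mp

import lancedb


def run_query(query_vector):
    # Each worker opens its own connection; "my_table" is a placeholder name.
    db = lancedb.connect("~/.lancedb")
    table = db.open_table("my_table")
    return table.search(query_vector).limit(5).to_pandas()


if __name__ == "__main__":
    # Prefer "spawn" over "fork", per the FAQ note above.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        results = pool.map(run_query, [[0.1, 0.2], [0.3, 0.4]])
```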
@@ -6,17 +6,24 @@ LanceDB supports both semantic and keyword-based search. In real world applicati
 You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case needs more sophisticated logic.

 ```python
+import os
+
 import lancedb
+import openai
 from lancedb.embeddings import get_registry
-from lancedb.pydanatic import LanceModel, Vector
+from lancedb.pydantic import LanceModel, Vector

 db = lancedb.connect("~/.lancedb")

 # Ingest embedding function in LanceDB table
+# Configuring the environment variable OPENAI_API_KEY
+if "OPENAI_API_KEY" not in os.environ:
+    # OR set the key here as a variable
+    openai.api_key = "sk-..."
 embeddings = get_registry().get("openai").create()

 class Documents(LanceModel):
-    vector: Vector(embeddings.ndims) = embeddings.VectorField()
+    vector: Vector(embeddings.ndims()) = embeddings.VectorField()
     text: str = embeddings.SourceField()

 table = db.create_table("documents", schema=Documents)
@@ -31,17 +38,19 @@ data = [
 # ingest docs with auto-vectorization
 table.add(data)

+# Create a fts index before the hybrid search
+table.create_fts_index("text")
 # hybrid search with default re-ranker
 results = table.search("flower moon", query_type="hybrid").to_pandas()
 ```

-By default, LanceDB uses `LinearCombinationReranker(weights=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
+By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:


 ### `rerank()` arguments
 * `normalize`: `str`, default `"score"`:
     The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
-* `reranker`: `Reranker`, default `LinearCombinationReranker(weights=0.7)`.
+* `reranker`: `Reranker`, default `LinearCombinationReranker(weight=0.7)`.
     The reranker to use. If not specified, the default reranker is used.

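As a quick, hedged illustration of the two `rerank()` arguments just described (reusing the `table` from the example above; the weight and query are arbitrary):

```python
from lancedb.rerankers import LinearCombinationReranker

results = (
    table.search("flower moon", query_type="hybrid")
    .rerank(normalize="rank", reranker=LinearCombinationReranker(weight=0.3))
    .to_pandas()
)
```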
@@ -55,7 +64,7 @@ This is the default re-ranker used by LanceDB. It combines the results of semant
 ```python
 from lancedb.rerankers import LinearCombinationReranker

-reranker = LinearCombinationReranker(weights=0.3) # Use 0.3 as the weight for vector search
+reranker = LinearCombinationReranker(weight=0.3) # Use 0.3 as the weight for vector search

 results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
 ```
@@ -121,6 +130,60 @@ Arguments
     Only returns `_relevance_score`. Does not support `return_score = "all"`.


+### ColBERT Reranker
+This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertReranker()` to the `rerank()` method.
+
+The ColBERT reranker model calculates the relevance of the given docs against the query and doesn't take the existing fts and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for the `text` column to rerank the results, but you can specify the column name to use as input to the model as described below.
+
+```python
+from lancedb.rerankers import ColbertReranker
+
+reranker = ColbertReranker()
+
+results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
+```
+
+Arguments
+----------------
+* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
+    The name of the cross encoder model to use.
+* `column` : `str`, default `"text"`
+    The name of the column to use as input to the cross encoder model.
+* `return_score` : `str`, default `"relevance"`
+    options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
+
+!!! Note
+    Only returns `_relevance_score`. Does not support `return_score = "all"`.
+
+### OpenAI Reranker
+This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
+
+!!! Note
+    This prompts a chat model (not a dedicated reranker model) to rerank results, so it should be treated as experimental.
+
+!!! Tip
+    You might run into the model's token limit, so set the search `limit` accordingly.
+
+```python
+from lancedb.rerankers import OpenaiReranker
+
+reranker = OpenaiReranker()
+
+results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
+```
+
+Arguments
+----------------
+`model_name` : `str`, default `"gpt-3.5-turbo-1106"`
+    The name of the chat model to use.
+`column` : `str`, default `"text"`
+    The name of the column to use as input to the model.
+`return_score` : `str`, default `"relevance"`
+    options are "relevance" or "all". Only "relevance" is supported for now.
+`api_key` : `str`, default `None`
+    The API key to use. If None, will use the OPENAI_API_KEY environment variable.
+
+
 ## Building Custom Rerankers
 You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.

@@ -137,7 +200,7 @@ class MyReranker(Reranker):
         self.param1 = param1
         self.param2 = param2

-    def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
+    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
         # Use the built-in merging function
         combined_result = self.merge_results(vector_results, fts_results)

@@ -159,7 +222,7 @@ import pyarrow as pa
 class MyReranker(Reranker):
     ...

-    def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
+    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
         # Use the built-in merging function
         combined_result = self.merge_results(vector_results, fts_results)

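Pulling the fragments above together, here is a minimal, hedged sketch of a complete custom reranker under the new signature, where `query` is passed in as a plain string; the constructor parameters are illustrative and the body simply returns the merged results, which is the smallest valid implementation rather than LanceDB's built-in logic.

```python
import pyarrow as pa

from lancedb.rerankers import Reranker


class MyReranker(Reranker):
    def __init__(self, param1, param2, return_score: str = "relevance"):
        super().__init__(return_score)
        self.param1 = param1
        self.param2 = param2

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
        # Use the built-in merging function to combine/deduplicate both result sets,
        # then apply any custom scoring before returning a pyarrow.Table.
        combined_result = self.merge_results(vector_results, fts_results)
        return combined_result
```

An instance is then passed just like the built-in rerankers, e.g. `table.search("query", query_type="hybrid").rerank(reranker=MyReranker(1, 2)).to_pandas()`.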
@@ -372,7 +372,7 @@ export interface Table<T = number[]> {
   /**
    * Returns the number of rows in this table.
    */
-  countRows: () => Promise<number>
+  countRows: (filter?: string) => Promise<number>

   /**
    * Delete rows from this table.
@@ -525,8 +525,19 @@ export interface MergeInsertArgs {
    * If there are multiple matches then the behavior is undefined.
    * Currently this causes multiple copies of the row to be created
    * but that behavior is subject to change.
+   *
+   * Optionally, a filter can be specified. This should be an SQL
+   * filter where fields with the prefix "target." refer to fields
+   * in the target table (old data) and fields with the prefix
+   * "source." refer to fields in the source table (new data). For
+   * example, the filter "target.lastUpdated < source.lastUpdated" will
+   * only update matched rows when the incoming `lastUpdated` value is
+   * newer.
+   *
+   * Rows that do not match the filter will not be updated. Rows that
+   * do not match the filter do become "not matched" rows.
    */
-  whenMatchedUpdateAll?: boolean
+  whenMatchedUpdateAll?: string | boolean
  /**
   * If true then rows that exist only in the source table (new data)
   * will be inserted into the target table.
@@ -840,8 +851,8 @@ export class LocalTable<T = number[]> implements Table<T> {
  /**
   * Returns the number of rows in this table.
   */
- async countRows (): Promise<number> {
-   return tableCountRows.call(this._tbl)
+ async countRows (filter?: string): Promise<number> {
+   return tableCountRows.call(this._tbl, filter)
  }

  /**
@@ -885,7 +896,14 @@ export class LocalTable<T = number[]> implements Table<T> {
  }

  async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
-   const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
+   let whenMatchedUpdateAll = false
+   let whenMatchedUpdateAllFilt = null
+   if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
+     whenMatchedUpdateAll = true
+     if (args.whenMatchedUpdateAll !== true) {
+       whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
+     }
+   }
    const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
    let whenNotMatchedBySourceDelete = false
    let whenNotMatchedBySourceDeleteFilt = null
@@ -909,6 +927,7 @@ export class LocalTable<T = number[]> implements Table<T> {
      this._tbl,
      on,
      whenMatchedUpdateAll,
+     whenMatchedUpdateAllFilt,
      whenNotMatchedInsertAll,
      whenNotMatchedBySourceDelete,
      whenNotMatchedBySourceDeleteFilt,
@@ -286,8 +286,11 @@ export class RemoteTable<T = number[]> implements Table<T> {
    const queryParams: any = {
      on
    }
-   if (args.whenMatchedUpdateAll ?? false) {
+   if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
      queryParams.when_matched_update_all = 'true'
+     if (typeof args.whenMatchedUpdateAll === 'string') {
+       queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
+     }
    } else {
      queryParams.when_matched_update_all = 'false'
    }
@@ -294,6 +294,7 @@ describe('LanceDB client', function () {
    })
    assert.equal(table.name, 'vectors')
    assert.equal(await table.countRows(), 10)
+   assert.equal(await table.countRows('vector IS NULL'), 0)
    assert.deepEqual(await con.tableNames(), ['vectors'])
  })

@@ -369,6 +370,7 @@ describe('LanceDB client', function () {
    const table = await con.createTable('f16', data)
    assert.equal(table.name, 'f16')
    assert.equal(await table.countRows(), total)
+   assert.equal(await table.countRows('id < 5'), 5)
    assert.deepEqual(await con.tableNames(), ['f16'])
    assert.deepEqual(await table.schema, schema)

@@ -538,26 +540,36 @@ describe('LanceDB client', function () {
    const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
    const table = await con.createTable('my_table', data)

+   // insert if not exists
    let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
    await table.mergeInsert('id', newData, {
      whenNotMatchedInsertAll: true
    })
    assert.equal(await table.countRows(), 3)
-   assert.equal((await table.filter('age = 2').execute()).length, 1)
+   assert.equal(await table.countRows('age = 2'), 1)

-   newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
+   // conditional update
+   newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
+   await table.mergeInsert('id', newData, {
+     whenMatchedUpdateAll: 'target.age = 1'
+   })
+   assert.equal(await table.countRows(), 3)
+   assert.equal(await table.countRows('age = 1'), 1)
+   assert.equal(await table.countRows('age = 3'), 1)
+
+   newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
    await table.mergeInsert('id', newData, {
      whenNotMatchedInsertAll: true,
      whenMatchedUpdateAll: true
    })
    assert.equal(await table.countRows(), 4)
-   assert.equal((await table.filter('age = 3').execute()).length, 2)
+   assert.equal((await table.filter('age = 4').execute()).length, 2)

-   newData = [{ id: 5, age: 4 }]
+   newData = [{ id: 5, age: 5 }]
    await table.mergeInsert('id', newData, {
      whenNotMatchedInsertAll: true,
      whenMatchedUpdateAll: true,
-     whenNotMatchedBySourceDelete: 'age < 3'
+     whenNotMatchedBySourceDelete: 'age < 4'
    })
    assert.equal(await table.countRows(), 3)

@@ -1,9 +1,12 @@
 [package]
 name = "vectordb-nodejs"
-edition = "2021"
+edition.workspace = true
 version = "0.0.0"
 license.workspace = true
+description.workspace = true
 repository.workspace = true
+keywords.workspace = true
+categories.workspace = true

 [lib]
 crate-type = ["cdylib"]
@@ -14,15 +17,11 @@ futures.workspace = true
 lance-linalg.workspace = true
 lance.workspace = true
 vectordb = { path = "../rust/vectordb" }
-napi = { version = "2.14", default-features = false, features = [
+napi = { version = "2.15", default-features = false, features = [
   "napi7",
   "async"
 ] }
-napi-derive = "2.14"
+napi-derive = "2"

 [build-dependencies]
 napi-build = "2.1"

-[profile.release]
-lto = true
-strip = "symbols"
@@ -2,4 +2,6 @@
 module.exports = {
   preset: 'ts-jest',
   testEnvironment: 'node',
-};
+  moduleDirectories: ["node_modules", "./dist"],
+  moduleFileExtensions: ["js", "ts"],
+};
@@ -57,8 +57,8 @@ impl Table {
     }

     #[napi]
-    pub async fn count_rows(&self) -> napi::Result<usize> {
-        self.table.count_rows().await.map_err(|e| {
+    pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
+        self.table.count_rows(filter).await.map_err(|e| {
             napi::Error::from_reason(format!(
                 "Failed to count rows in table {}: {}",
                 self.table, e
nodejs/vectordb/native.d.ts (vendored, 2 changed lines)

@@ -73,7 +73,7 @@ export class Table {
   /** Return Schema as empty Arrow IPC file. */
   schema(): Buffer
   add(buf: Buffer): Promise<void>
-  countRows(): Promise<bigint>
+  countRows(filter?: string): Promise<bigint>
   delete(predicate: string): Promise<void>
   createIndex(): IndexBuilder
   query(): Query
@@ -50,8 +50,8 @@ export class Table {
   }

   /** Count the total number of rows in the dataset. */
-  async countRows(): Promise<bigint> {
-    return await this.inner.countRows();
+  async countRows(filter?: string): Promise<bigint> {
+    return await this.inner.countRows(filter);
   }

   /** Delete the rows that satisfy the predicate. */
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.5.3
+current_version = 0.5.4
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True
@@ -42,6 +42,12 @@ To run the unit tests:
 pytest
 ```

+To run the doc tests:
+
+```bash
+pytest --doctest-modules lancedb
+```
+
 To run linter and automatically fix all errors:

 ```bash
@@ -13,6 +13,7 @@

 import importlib.metadata
 import os
+from datetime import timedelta
 from typing import Optional

 __version__ = importlib.metadata.version("lancedb")
@@ -30,6 +31,7 @@ def connect(
     api_key: Optional[str] = None,
     region: str = "us-east-1",
     host_override: Optional[str] = None,
+    read_consistency_interval: Optional[timedelta] = None,
 ) -> DBConnection:
     """Connect to a LanceDB database.

@@ -45,6 +47,18 @@ def connect(
         The region to use for LanceDB Cloud.
     host_override: str, optional
         The override url for LanceDB Cloud.
+    read_consistency_interval: timedelta, default None
+        (For LanceDB OSS only)
+        The interval at which to check for updates to the table from other
+        processes. If None, then consistency is not checked. For performance
+        reasons, this is the default. For strong consistency, set this to
+        zero seconds. Then every read will check for updates from other
+        processes. As a compromise, you can set this to a non-zero timedelta
+        for eventual consistency. If more than that interval has passed since
+        the last check, then the table will be checked for updates. Note: this
+        consistency only applies to read operations. Write operations are
+        always consistent.
+

     Examples
     --------
@@ -73,4 +87,4 @@ def connect(
         if api_key is None:
             raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
         return RemoteDBConnection(uri, api_key, region, host_override)
-    return LanceDBConnection(uri)
+    return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
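A short usage sketch of the new `read_consistency_interval` parameter (the URI is a placeholder): zero seconds gives strong consistency, a positive interval gives eventual consistency, and the default `None` skips the check entirely.

```python
from datetime import timedelta

import lancedb

# Strong consistency: every read checks for updates from other processes.
db_strong = lancedb.connect("~/.lancedb", read_consistency_interval=timedelta(seconds=0))

# Eventual consistency: check for updates at most every 5 seconds.
db_eventual = lancedb.connect("~/.lancedb", read_consistency_interval=timedelta(seconds=5))
```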
@@ -26,6 +28,8 @@ from .table import LanceTable, Table
 from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri

 if TYPE_CHECKING:
+    from datetime import timedelta
+
     from .common import DATA, URI
     from .embeddings import EmbeddingFunctionConfig
     from .pydantic import LanceModel
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
        >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
        ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
        >>> db.create_table("my_table", data)
-       LanceTable(my_table)
+       LanceTable(connection=..., name="my_table")
        >>> db["my_table"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
        ...    "long": [-122.7, -74.1]
        ... })
        >>> db.create_table("table2", data)
-       LanceTable(table2)
+       LanceTable(connection=..., name="table2")
        >>> db["table2"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
        ...   pa.field("long", pa.float32())
        ... ])
        >>> db.create_table("table3", data, schema = custom_schema)
-       LanceTable(table3)
+       LanceTable(connection=..., name="table3")
        >>> db["table3"].head()
        pyarrow.Table
        vector: fixed_size_list<item: float>[2]
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
        ...     pa.field("price", pa.float32()),
        ... ])
        >>> db.create_table("table4", make_batches(), schema=schema)
-       LanceTable(table4)
+       LanceTable(connection=..., name="table4")

        """
        raise NotImplementedError
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
    ----------
    uri: str or Path
        The root uri of the database.
+   read_consistency_interval: timedelta, default None
+       The interval at which to check for updates to the table from other
+       processes. If None, then consistency is not checked. For performance
+       reasons, this is the default. For strong consistency, set this to
+       zero seconds. Then every read will check for updates from other
+       processes. As a compromise, you can set this to a non-zero timedelta
+       for eventual consistency. If more than that interval has passed since
+       the last check, then the table will be checked for updates. Note: this
+       consistency only applies to read operations. Write operations are
+       always consistent.

    Examples
    --------
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
    >>> db = lancedb.connect("./.lancedb")
    >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
    ...                                   {"vector": [0.5, 1.3], "b": 4}])
-   LanceTable(my_table)
+   LanceTable(connection=..., name="my_table")
    >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
-   LanceTable(another_table)
+   LanceTable(connection=..., name="another_table")
    >>> sorted(db.table_names())
    ['another_table', 'my_table']
    >>> len(db)
    2
    >>> db["my_table"]
-   LanceTable(my_table)
+   LanceTable(connection=..., name="my_table")
    >>> "my_table" in db
    True
    >>> db.drop_table("my_table")
    >>> db.drop_table("another_table")
    """

-   def __init__(self, uri: URI):
+   def __init__(
+       self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
+   ):
        if not isinstance(uri, Path):
            scheme = get_uri_scheme(uri)
        is_local = isinstance(uri, Path) or scheme == "file"
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
        self._uri = str(uri)

        self._entered = False
+       self.read_consistency_interval = read_consistency_interval
+
+   def __repr__(self) -> str:
+       val = f"{self.__class__.__name__}({self._uri}"
+       if self.read_consistency_interval is not None:
+           val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
+       val += ")"
+       return val

    @property
    def uri(self) -> str:
@@ -12,7 +12,7 @@
 # limitations under the License.
 import os
 from functools import cached_property
-from typing import List, Union
+from typing import List, Optional, Union

 import numpy as np

@@ -30,10 +30,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
     """

     name: str = "text-embedding-ada-002"
+    dim: Optional[int] = None

     def ndims(self):
-        # TODO don't hardcode this
-        return 1536
+        return self._ndims
+
+    @cached_property
+    def _ndims(self):
+        if self.name == "text-embedding-ada-002":
+            return 1536
+        elif self.name == "text-embedding-3-large":
+            return self.dim or 3072
+        elif self.name == "text-embedding-3-small":
+            return self.dim or 1536
+        else:
+            raise ValueError(f"Unknown model name {self.name}")

     def generate_embeddings(
         self, texts: Union[List[str], np.ndarray]
@@ -47,7 +58,12 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
         The texts to embed
         """
         # TODO retry, rate limit, token limit
-        rs = self._openai_client.embeddings.create(input=texts, model=self.name)
+        if self.name == "text-embedding-ada-002":
+            rs = self._openai_client.embeddings.create(input=texts, model=self.name)
+        else:
+            rs = self._openai_client.embeddings.create(
+                input=texts, model=self.name, dimensions=self.ndims()
+            )
         return [v.embedding for v in rs.data]

     @cached_property
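A hedged sketch of how the new `dim` field and model names above could be used; that `create()` forwards `name` and `dim` as keyword arguments follows from the fields shown, but the exact call, the model choice, and the schema are illustrative assumptions.

```python
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# "text-embedding-3-large" supports a reduced output dimension via `dim`.
embeddings = get_registry().get("openai").create(name="text-embedding-3-large", dim=1024)


class Documents(LanceModel):
    vector: Vector(embeddings.ndims()) = embeddings.VectorField()  # 1024 dims here
    text: str = embeddings.SourceField()
```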
@@ -32,11 +32,14 @@ class LanceMergeInsertBuilder(object):
         self._table = table
         self._on = on
         self._when_matched_update_all = False
+        self._when_matched_update_all_condition = None
         self._when_not_matched_insert_all = False
         self._when_not_matched_by_source_delete = False
         self._when_not_matched_by_source_condition = None

-    def when_matched_update_all(self) -> LanceMergeInsertBuilder:
+    def when_matched_update_all(
+        self, *, where: Optional[str] = None
+    ) -> LanceMergeInsertBuilder:
         """
         Rows that exist in both the source table (new data) and
         the target table (old data) will be updated, replacing
@@ -47,6 +50,7 @@ class LanceMergeInsertBuilder(object):
         but that behavior is subject to change.
         """
         self._when_matched_update_all = True
+        self._when_matched_update_all_condition = where
         return self

     def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
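Combining the new `where=` option with the conditional-update semantics documented for the Node `MergeInsertArgs` above, here is a hedged Python sketch; the data, column names, and filter are illustrative, with `target.` referring to the existing row and `source.` to the incoming one.

```python
import lancedb

db = lancedb.connect("~/.lancedb")
table = db.create_table("people", data=[{"id": 1, "age": 1}, {"id": 2, "age": 1}])

new_data = [{"id": 2, "age": 3}, {"id": 3, "age": 3}]

# Update matched rows only where the condition holds, and insert unmatched rows.
(
    table.merge_insert("id")
    .when_matched_update_all(where="target.age = 1")
    .when_not_matched_insert_all()
    .execute(new_data)
)
```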
@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
     ...     name: str
     ...     vector: Vector(2)
     ...
-    >>> db = lancedb.connect("/tmp")
+    >>> db = lancedb.connect("./example")
     >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
     >>> table.add([
     ...     TestModel(name="test", vector=[1.0, 2.0])
@@ -626,7 +626,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
     def __init__(self, table: "Table", query: str, vector_column: str):
         super().__init__(table)
         self._validate_fts_index()
-        self._query = query
         vector_query, fts_query = self._validate_query(query)
         self._fts_query = LanceFtsQueryBuilder(table, fts_query)
         vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -679,12 +678,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         # rerankers might need to preserve this score to support `return_score="all"`
         fts_results = self._normalize_scores(fts_results, "score")

-        results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
+        results = self._reranker.rerank_hybrid(
+            self._fts_query._query, vector_results, fts_results
+        )

         if not isinstance(results, pa.Table):  # Enforce type
             raise TypeError(
                 f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
             )

+        # apply limit after reranking
+        results = results.slice(length=self._limit)
+
         if not self._with_row_id:
             results = results.drop(["_rowid"])
         return results
@@ -776,6 +781,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         """
         self._vector_query.limit(limit)
         self._fts_query.limit(limit)
+        self._limit = limit
+
         return self

     def select(self, columns: list) -> LanceHybridQueryBuilder:
@@ -118,6 +118,7 @@ class RemoteDBConnection(DBConnection):
         schema: Optional[Union[pa.Schema, LanceModel]] = None,
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
+        mode: Optional[str] = None,
         embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
     ) -> Table:
         """Create a [Table][lancedb.table.Table] in the database.
@@ -215,11 +216,13 @@ class RemoteDBConnection(DBConnection):
         if data is None and schema is None:
             raise ValueError("Either data or schema must be provided.")
         if embedding_functions is not None:
-            raise NotImplementedError(
-                "embedding_functions is not supported for remote databases."
+            logging.warning(
+                "embedding_functions is not yet supported on LanceDB Cloud."
                 "Please vote https://github.com/lancedb/lancedb/issues/626 "
                 "for this feature."
             )
+        if mode is not None:
+            logging.warning("mode is not yet supported on LanceDB Cloud.")

         if inspect.isclass(schema) and issubclass(schema, LanceModel):
             # convert LanceModel to pyarrow schema
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 import uuid
 from functools import cached_property
 from typing import Dict, Optional, Union
@@ -37,6 +38,9 @@ class RemoteTable(Table):
     def __repr__(self) -> str:
         return f"RemoteTable({self._conn.db_name}.{self._name})"

+    def __len__(self) -> int:
+        self.count_rows(None)
+
     @cached_property
     def schema(self) -> pa.Schema:
         """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
@@ -54,17 +58,17 @@ class RemoteTable(Table):
         return resp["version"]

     def to_arrow(self) -> pa.Table:
-        """to_arrow() is not supported on the LanceDB cloud"""
-        raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
+        """to_arrow() is not yet supported on LanceDB cloud."""
+        raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.")

     def to_pandas(self):
-        """to_pandas() is not supported on the LanceDB cloud"""
-        return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
+        """to_pandas() is not yet supported on LanceDB cloud."""
+        return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")

     def create_scalar_index(self, *args, **kwargs):
         """Creates a scalar index"""
         return NotImplementedError(
-            "create_scalar_index() is not supported on the LanceDB cloud"
+            "create_scalar_index() is not yet supported on LanceDB cloud."
         )

     def create_index(
@@ -72,6 +76,10 @@ class RemoteTable(Table):
         metric="L2",
         vector_column_name: str = VECTOR_COLUMN_NAME,
         index_cache_size: Optional[int] = None,
+        num_partitions: Optional[int] = None,
+        num_sub_vectors: Optional[int] = None,
+        replace: Optional[bool] = None,
+        accelerator: Optional[str] = None,
     ):
         """Create an index on the table.
         Currently, the only parameters that matter are
@@ -105,6 +113,28 @@ class RemoteTable(Table):
         ...     )
         >>> table.create_index("L2", "vector") # doctest: +SKIP
         """
+
+        if num_partitions is not None:
+            logging.warning(
+                "num_partitions is not supported on LanceDB cloud."
+                "This parameter will be tuned automatically."
+            )
+        if num_sub_vectors is not None:
+            logging.warning(
+                "num_sub_vectors is not supported on LanceDB cloud."
+                "This parameter will be tuned automatically."
+            )
+        if accelerator is not None:
+            logging.warning(
+                "GPU accelerator is not yet supported on LanceDB cloud."
+                "If you have 100M+ vectors to index,"
+                "please contact us at contact@lancedb.com"
+            )
+        if replace is not None:
+            logging.warning(
+                "replace is not supported on LanceDB cloud."
+                "Existing indexes will always be replaced."
+            )
         index_type = "vector"

         data = {
@@ -268,6 +298,10 @@ class RemoteTable(Table):
         )
         params["on"] = merge._on[0]
         params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
+        if merge._when_matched_update_all_condition is not None:
+            params[
+                "when_matched_update_all_filt"
+            ] = merge._when_matched_update_all_condition
         params["when_not_matched_insert_all"] = str(
             merge._when_not_matched_insert_all
         ).lower()
@@ -409,6 +443,13 @@ class RemoteTable(Table):
             "compact_files() is not supported on the LanceDB cloud"
         )

+    def count_rows(self, filter: Optional[str] = None) -> int:
+        # payload = {"filter": filter}
+        # self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
+        return NotImplementedError(
+            "count_rows() is not yet supported on the LanceDB cloud"
+        )
+

 def add_index(tbl: pa.Table, i: int) -> pa.Table:
     return tbl.add_column(
@@ -1,11 +1,15 @@
 from .base import Reranker
 from .cohere import CohereReranker
+from .colbert import ColbertReranker
 from .cross_encoder import CrossEncoderReranker
 from .linear_combination import LinearCombinationReranker
+from .openai import OpenaiReranker

 __all__ = [
     "Reranker",
     "CrossEncoderReranker",
     "CohereReranker",
     "LinearCombinationReranker",
+    "OpenaiReranker",
+    "ColbertReranker",
 ]
@@ -1,12 +1,8 @@
-import typing
 from abc import ABC, abstractmethod

 import numpy as np
 import pyarrow as pa

-if typing.TYPE_CHECKING:
-    import lancedb
-

 class Reranker(ABC):
     def __init__(self, return_score: str = "relevance"):
@@ -30,7 +26,7 @@ class Reranker(ABC):

     @abstractmethod
     def rerank_hybrid(
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
@@ -41,8 +37,8 @@ class Reranker(ABC):

         Parameters
         ----------
-        query_builder : "lancedb.HybridQueryBuilder"
-            The query builder object that was used to generate the results
+        query : str
+            The input query
         vector_results : pa.Table
             The results from the vector search
         fts_results : pa.Table
@@ -50,36 +46,6 @@ class Reranker(ABC):
         """
         pass

-    def rerank_vector(
-        query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
-    ):
-        """
-        Rerank function receives the individual results from the vector search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.VectorQueryBuilder"
-            The query builder object that was used to generate the results
-        vector_results : pa.Table
-            The results from the vector search
-        """
-        raise NotImplementedError("Vector Reranking is not implemented")
-
-    def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
-        """
-        Rerank function receives the individual results from the FTS search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.FTSQueryBuilder"
-            The query builder object that was used to generate the results
-        fts_results : pa.Table
-            The results from the FTS search
-        """
-        raise NotImplementedError("FTS Reranking is not implemented")
-
     def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
         """
         Merge the results from the vector and FTS search. This is a vanilla merging
@@ -1,5 +1,4 @@
 import os
-import typing
 from functools import cached_property
 from typing import Union

@@ -8,9 +7,6 @@ import pyarrow as pa
 from ..util import safe_import
 from .base import Reranker

-if typing.TYPE_CHECKING:
-    import lancedb
-

 class CohereReranker(Reranker):
     """
@@ -55,14 +51,14 @@ class CohereReranker(Reranker):

     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         docs = combined_results[self.column].to_pylist()
         results = self._client.rerank(
-            query=query_builder._query,
+            query=query,
             documents=docs,
             top_n=self.top_n,
             model=self.model_name,
107
python/lancedb/rerankers/colbert.py
Normal file
@@ -0,0 +1,107 @@
from functools import cached_property

import pyarrow as pa

from ..util import safe_import
from .base import Reranker


class ColbertReranker(Reranker):
    """
    Reranks the results using the ColBERT model.

    Parameters
    ----------
    model_name : str, default "colbert-ir/colbertv2.0"
        The name of the cross encoder model to use.
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    return_score : str, default "relevance"
        options are "relevance" or "all". Only "relevance" is supported for now.
    """

    def __init__(
        self,
        model_name: str = "colbert-ir/colbertv2.0",
        column: str = "text",
        return_score="relevance",
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.torch = safe_import("torch")  # import here for faster ops later

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()

        tokenizer, model = self._model

        # Encode the query
        query_encoding = tokenizer(query, return_tensors="pt")
        query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
        scores = []
        # Get score for each document
        for document in docs:
            document_encoding = tokenizer(
                document, return_tensors="pt", truncation=True, max_length=512
            )
            document_embedding = model(**document_encoding).last_hidden_state
            # Calculate MaxSim score
            score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
            scores.append(score.item())

        # replace the self.column column with the docs
        combined_results = combined_results.drop(self.column)
        combined_results = combined_results.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "OpenAI Reranker does not support score='all' yet"
            )

        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    @cached_property
    def _model(self):
        transformers = safe_import("transformers")
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        model = transformers.AutoModel.from_pretrained(self.model_name)

        return tokenizer, model

    def maxsim(self, query_embedding, document_embedding):
        # Expand dimensions for broadcasting
        # Query: [batch, length, size] -> [batch, query, 1, size]
        # Document: [batch, length, size] -> [batch, 1, length, size]
        expanded_query = query_embedding.unsqueeze(2)
        expanded_doc = document_embedding.unsqueeze(1)

        # Compute cosine similarity across the embedding dimension
        sim_matrix = self.torch.nn.functional.cosine_similarity(
            expanded_query, expanded_doc, dim=-1
        )

        # Take the maximum similarity for each query token (across all document tokens)
        # sim_matrix shape: [batch_size, query_length, doc_length]
        max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)

        # Average these maximum scores across all query tokens
        avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
        return avg_max_sim
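Usage note (not part of the diff): the MaxSim scoring in ColbertReranker.maxsim above reduces a token-level similarity matrix to one score per document. A minimal sketch of that math on dummy tensors, assuming torch is installed (shapes are illustrative):

import torch

query_embedding = torch.randn(1, 5, 128)      # [batch, query_tokens, dim]
document_embedding = torch.randn(1, 40, 128)  # [batch, doc_tokens, dim]

# Query -> [batch, query_tokens, 1, dim]; Document -> [batch, 1, doc_tokens, dim]
sim = torch.nn.functional.cosine_similarity(
    query_embedding.unsqueeze(2), document_embedding.unsqueeze(1), dim=-1
)
max_per_query_token, _ = sim.max(dim=2)   # best-matching doc token per query token
print(max_per_query_token.mean(dim=1))    # averaged MaxSim score, shape [batch]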
@@ -1,4 +1,3 @@
-import typing
 from functools import cached_property
 from typing import Union
 
@@ -7,9 +6,6 @@ import pyarrow as pa
 from ..util import safe_import
 from .base import Reranker
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
-
 class CrossEncoderReranker(Reranker):
     """
@@ -52,13 +48,13 @@ class CrossEncoderReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         passages = combined_results[self.column].to_pylist()
-        cross_inp = [[query_builder._query, passage] for passage in passages]
+        cross_inp = [[query, passage] for passage in passages]
         cross_scores = self.model.predict(cross_inp)
         combined_results = combined_results.append_column(
             "_relevance_score", pa.array(cross_scores, type=pa.float32())

@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):
 
     def rerank_hybrid(
        self,
-        query_builder: "lancedb.HybridQueryBuilder",  # noqa: F821
+        query: str,  # noqa: F821
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
102
python/lancedb/rerankers/openai.py
Normal file
@@ -0,0 +1,102 @@
import json
import os
from functools import cached_property
from typing import Optional

import pyarrow as pa

from ..util import safe_import
from .base import Reranker


class OpenaiReranker(Reranker):
    """
    Reranks the results using the OpenAI API.
    WARNING: This is a prompt based reranker that uses chat model that is
    not a dedicated reranker API. This should be treated as experimental.

    Parameters
    ----------
    model_name : str, default "gpt-3.5-turbo-1106"
        The name of the cross encoder model to use.
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    return_score : str, default "relevance"
        options are "relevance" or "all". Only "relevance" is supported for now.
    api_key : str, default None
        The API key to use. If None, will use the OPENAI_API_KEY environment variable.
    """

    def __init__(
        self,
        model_name: str = "gpt-3.5-turbo-1106",
        column: str = "text",
        return_score="relevance",
        api_key: Optional[str] = None,
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.api_key = api_key

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()
        response = self._client.chat.completions.create(
            model=self.model_name,
            response_format={"type": "json_object"},
            temperature=0,
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert relevance ranker. Given a list of\
                        documents and a query, your job is to determine the relevance\
                        each document is for answering the query. Your output is JSON,\
                        which is a list of documents. Each document has two fields,\
                        content and relevance_score. relevance_score is from 0.0 to\
                        1.0 indicating the relevance of the text to the given query.\
                        Make sure to include all documents in the response.",
                },
                {"role": "user", "content": f"Query: {query} Docs: {docs}"},
            ],
        )
        results = json.loads(response.choices[0].message.content)["documents"]
        docs, scores = list(
            zip(*[(result["content"], result["relevance_score"]) for result in results])
        )  # tuples
        # replace the self.column column with the docs
        combined_results = combined_results.drop(self.column)
        combined_results = combined_results.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "OpenAI Reranker does not support score='all' yet"
            )

        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    @cached_property
    def _client(self):
        openai = safe_import("openai")  # TODO: force version or handle versions < 1.0
        if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
            raise ValueError(
                "OPENAI_API_KEY not set. Either set it in your environment or \
                pass it as `api_key` argument to the CohereReranker."
            )
        return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)
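Usage sketch (not part of the diff, mirrors the new reranker tests later in this comparison): the new OpenaiReranker is plugged into a hybrid query via rerank(). Assumes OPENAI_API_KEY is set and the table already has an FTS index; the connection path and table name are illustrative.

import lancedb
from lancedb.rerankers import OpenaiReranker

db = lancedb.connect("/tmp/lancedb")   # illustrative path
table = db.open_table("my_table")      # illustrative table
results = (
    table.search("Our father who art in heaven", query_type="hybrid")
    .rerank(reranker=OpenaiReranker())
    .to_arrow()
)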
@@ -14,7 +14,10 @@
 from __future__ import annotations
 
 import inspect
+import time
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
 
@@ -41,8 +44,6 @@ from .util import (
 from .utils.events import register_event
 
 if TYPE_CHECKING:
-    from datetime import timedelta
-
     import PIL
     from lance.dataset import CleanupStats, ReaderLike
 
@@ -176,6 +177,18 @@ class Table(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def count_rows(self, filter: Optional[str] = None) -> int:
+        """
+        Count the number of rows in the table.
+
+        Parameters
+        ----------
+        filter: str, optional
+            A SQL where clause to filter the rows to count.
+        """
+        raise NotImplementedError
+
     def to_pandas(self) -> "pd.DataFrame":
         """Return the table as a pandas DataFrame.
 
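Sketch of how the count_rows API added above is intended to be called (not part of the diff; path and data are illustrative):

import lancedb

db = lancedb.connect("/tmp/lancedb")                        # illustrative path
table = db.create_table("t", data=[{"id": 0}, {"id": 1}])   # illustrative data
print(table.count_rows())           # 2: all rows
print(table.count_rows("id > 0"))   # 1: rows matching the SQL where clause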
@@ -299,7 +312,7 @@ class Table(ABC):
 
         import lance
 
-        dataset = lance.dataset("/tmp/images.lance")
+        dataset = lance.dataset("./images.lance")
         dataset.create_scalar_index("category")
         """
         raise NotImplementedError
@@ -446,7 +459,7 @@ class Table(ABC):
             *default "vector"*
         query_type: str
             *default "auto"*.
-            Acceptable types are: "vector", "fts", or "auto"
+            Acceptable types are: "vector", "fts", "hybrid", or "auto"
 
             - If "auto" then the query type is inferred from the query;
 
@@ -642,23 +655,145 @@ class Table(ABC):
         """
 
 
+class _LanceDatasetRef(ABC):
+    @property
+    @abstractmethod
+    def dataset(self) -> LanceDataset:
+        pass
+
+    @property
+    @abstractmethod
+    def dataset_mut(self) -> LanceDataset:
+        pass
+
+
+@dataclass
+class _LanceLatestDatasetRef(_LanceDatasetRef):
+    """Reference to the latest version of a LanceDataset."""
+
+    uri: str
+    read_consistency_interval: Optional[timedelta] = None
+    last_consistency_check: Optional[float] = None
+    _dataset: Optional[LanceDataset] = None
+
+    @property
+    def dataset(self) -> LanceDataset:
+        if not self._dataset:
+            self._dataset = lance.dataset(self.uri)
+            self.last_consistency_check = time.monotonic()
+        elif self.read_consistency_interval is not None:
+            now = time.monotonic()
+            diff = timedelta(seconds=now - self.last_consistency_check)
+            if (
+                self.last_consistency_check is None
+                or diff > self.read_consistency_interval
+            ):
+                self._dataset = self._dataset.checkout_version(
+                    self._dataset.latest_version
+                )
+                self.last_consistency_check = time.monotonic()
+        return self._dataset
+
+    @dataset.setter
+    def dataset(self, value: LanceDataset):
+        self._dataset = value
+        self.last_consistency_check = time.monotonic()
+
+    @property
+    def dataset_mut(self) -> LanceDataset:
+        return self.dataset
+
+
+@dataclass
+class _LanceTimeTravelRef(_LanceDatasetRef):
+    uri: str
+    version: int
+    _dataset: Optional[LanceDataset] = None
+
+    @property
+    def dataset(self) -> LanceDataset:
+        if not self._dataset:
+            self._dataset = lance.dataset(self.uri, version=self.version)
+        return self._dataset
+
+    @dataset.setter
+    def dataset(self, value: LanceDataset):
+        self._dataset = value
+        self.version = value.version
+
+    @property
+    def dataset_mut(self) -> LanceDataset:
+        raise ValueError(
+            "Cannot mutate table reference fixed at version "
+            f"{self.version}. Call checkout_latest() to get a mutable "
+            "table reference."
+        )
+
+
 class LanceTable(Table):
     """
     A table in a LanceDB database.
+
+    This can be opened in two modes: standard and time-travel.
+
+    Standard mode is the default. In this mode, the table is mutable and tracks
+    the latest version of the table. The level of read consistency is controlled
+    by the `read_consistency_interval` parameter on the connection.
+
+    Time-travel mode is activated by specifying a version number. In this mode,
+    the table is immutable and fixed to a specific version. This is useful for
+    querying historical versions of the table.
     """
 
-    def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
+    def __init__(
+        self,
+        connection: "LanceDBConnection",
+        name: str,
+        version: Optional[int] = None,
+    ):
         self._conn = connection
         self.name = name
-        self._version = version
-
-    def _reset_dataset(self, version=None):
-        try:
-            if "_dataset" in self.__dict__:
-                del self.__dict__["_dataset"]
-            self._version = version
-        except AttributeError:
-            pass
+
+        if version is not None:
+            self._ref = _LanceTimeTravelRef(
+                uri=self._dataset_uri,
+                version=version,
+            )
+        else:
+            self._ref = _LanceLatestDatasetRef(
+                uri=self._dataset_uri,
+                read_consistency_interval=connection.read_consistency_interval,
+            )
+
+    @classmethod
+    def open(cls, db, name, **kwargs):
+        tbl = cls(db, name, **kwargs)
+        fs, path = fs_from_uri(tbl._dataset_uri)
+        file_info = fs.get_file_info(path)
+        if file_info.type != pa.fs.FileType.Directory:
+            raise FileNotFoundError(
+                f"Table {name} does not exist."
+                f"Please first call db.create_table({name}, data)"
+            )
+        register_event("open_table")
+
+        return tbl
+
+    @property
+    def _dataset_uri(self) -> str:
+        return join_uri(self._conn.uri, f"{self.name}.lance")
+
+    @property
+    def _dataset(self) -> LanceDataset:
+        return self._ref.dataset
+
+    @property
+    def _dataset_mut(self) -> LanceDataset:
+        return self._ref.dataset_mut
+
+    def to_lance(self) -> LanceDataset:
+        """Return the LanceDataset backing this table."""
+        return self._dataset
+
     @property
     def schema(self) -> pa.Schema:
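Sketch of how the read-consistency machinery above is exercised through the public API (not part of the diff; mirrors test_consistency further down in this comparison, path and data are illustrative):

from datetime import timedelta

import lancedb

db = lancedb.connect("/tmp/lancedb")                     # illustrative path
table = db.create_table("my_table", data=[{"id": 0}])

# A second connection that re-checks the latest version on every read.
db2 = lancedb.connect("/tmp/lancedb", read_consistency_interval=timedelta(seconds=0))
table2 = db2.open_table("my_table")

table.add([{"id": 1}])
assert table2.version == table.version   # strongly consistent reads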
@@ -686,6 +821,9 @@ class LanceTable(Table):
         keep writing to the dataset starting from an old version, then use
         the `restore` function.
 
+        Calling this method will set the table into time-travel mode. If you
+        wish to return to standard mode, call `checkout_latest`.
+
         Parameters
         ----------
         version : int
@@ -710,15 +848,13 @@ class LanceTable(Table):
         vector type
         0  [1.1, 0.9]  vector
         """
-        max_ver = max([v["version"] for v in self._dataset.versions()])
+        max_ver = self._dataset.latest_version
         if version < 1 or version > max_ver:
             raise ValueError(f"Invalid version {version}")
-        self._reset_dataset(version=version)
 
         try:
-            # Accessing the property updates the cached value
-            _ = self._dataset
-        except Exception as e:
+            ds = self._dataset.checkout_version(version)
+        except IOError as e:
             if "not found" in str(e):
                 raise ValueError(
                     f"Version {version} no longer exists. Was it cleaned up?"
@@ -726,6 +862,27 @@ class LanceTable(Table):
             else:
                 raise e
 
+        self._ref = _LanceTimeTravelRef(
+            uri=self._dataset_uri,
+            version=version,
+        )
+        # We've already loaded the version so we can populate it directly.
+        self._ref.dataset = ds
+
+    def checkout_latest(self):
+        """Checkout the latest version of the table. This is an in-place operation.
+
+        The table will be set back into standard mode, and will track the latest
+        version of the table.
+        """
+        self.checkout(self._dataset.latest_version)
+        ds = self._ref.dataset
+        self._ref = _LanceLatestDatasetRef(
+            uri=self._dataset_uri,
+            read_consistency_interval=self._conn.read_consistency_interval,
+        )
+        self._ref.dataset = ds
+
     def restore(self, version: int = None):
         """Restore a version of the table. This is an in-place operation.
 
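Time-travel sketch based on the checkout()/checkout_latest() semantics added above (not part of the diff; path and data are illustrative):

import lancedb

db = lancedb.connect("/tmp/lancedb")
table = db.create_table("tt_demo", data=[{"id": 0}])

v = table.version            # remember the current version
table.add([{"id": 1}])       # advances to a new version

table.checkout(v)            # fixed, read-only view of the old version
table.checkout_latest()      # back to standard (mutable) mode
assert table.version > v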
@@ -760,7 +917,7 @@ class LanceTable(Table):
         >>> len(table.list_versions())
         4
         """
-        max_ver = max([v["version"] for v in self._dataset.versions()])
+        max_ver = self._dataset.latest_version
         if version is None:
             version = self.version
         elif version < 1 or version > max_ver:
@@ -768,29 +925,30 @@ class LanceTable(Table):
         else:
             self.checkout(version)
 
-        if version == max_ver:
-            # no-op if restoring the latest version
-            return
-
-        self._dataset.restore()
-        self._reset_dataset()
+        ds = self._dataset
+
+        # no-op if restoring the latest version
+        if version != max_ver:
+            ds.restore()
+
+        self._ref = _LanceLatestDatasetRef(
+            uri=self._dataset_uri,
+            read_consistency_interval=self._conn.read_consistency_interval,
+        )
+        self._ref.dataset = ds
 
     def count_rows(self, filter: Optional[str] = None) -> int:
-        """
-        Count the number of rows in the table.
-
-        Parameters
-        ----------
-        filter: str, optional
-            A SQL where clause to filter the rows to count.
-        """
         return self._dataset.count_rows(filter)
 
     def __len__(self):
         return self.count_rows()
 
     def __repr__(self) -> str:
-        return f"LanceTable({self.name})"
+        val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
+        if isinstance(self._ref, _LanceTimeTravelRef):
+            val += f", version={self._ref.version}"
+        val += ")"
+        return val
 
     def __str__(self) -> str:
         return self.__repr__()
@@ -840,10 +998,6 @@ class LanceTable(Table):
             self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
         )
 
-    @property
-    def _dataset_uri(self) -> str:
-        return join_uri(self._conn.uri, f"{self.name}.lance")
-
     def create_index(
         self,
         metric="L2",
@@ -855,7 +1009,7 @@ class LanceTable(Table):
         index_cache_size: Optional[int] = None,
     ):
         """Create an index on the table."""
-        self._dataset.create_index(
+        self._dataset_mut.create_index(
             column=vector_column_name,
             index_type="IVF_PQ",
             metric=metric,
@@ -865,11 +1019,12 @@ class LanceTable(Table):
             accelerator=accelerator,
             index_cache_size=index_cache_size,
         )
-        self._reset_dataset()
         register_event("create_index")
 
     def create_scalar_index(self, column: str, *, replace: bool = True):
-        self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
+        self._dataset_mut.create_scalar_index(
+            column, index_type="BTREE", replace=replace
+        )
 
     def create_fts_index(
         self,
@@ -912,14 +1067,6 @@ class LanceTable(Table):
     def _get_fts_index_path(self):
         return join_uri(self._dataset_uri, "_indices", "tantivy")
 
-    @cached_property
-    def _dataset(self) -> LanceDataset:
-        return lance.dataset(self._dataset_uri, version=self._version)
-
-    def to_lance(self) -> LanceDataset:
-        """Return the LanceDataset backing this table."""
-        return self._dataset
-
     def add(
         self,
         data: DATA,
@@ -958,8 +1105,11 @@ class LanceTable(Table):
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
-        lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
-        self._reset_dataset()
+        # Access the dataset_mut property to ensure that the dataset is mutable.
+        self._ref.dataset_mut
+        self._ref.dataset = lance.write_dataset(
+            data, self._dataset_uri, schema=self.schema, mode=mode
+        )
         register_event("add")
 
     def merge(
@@ -1020,10 +1170,9 @@ class LanceTable(Table):
             other_table = other_table.to_lance()
         if isinstance(other_table, LanceDataset):
             other_table = other_table.to_table()
-        self._dataset.merge(
+        self._ref.dataset = self._dataset_mut.merge(
             other_table, left_on=left_on, right_on=right_on, schema=schema
         )
-        self._reset_dataset()
         register_event("merge")
 
     @cached_property
@@ -1226,22 +1375,8 @@ class LanceTable(Table):
         register_event("create_table")
         return new_table
 
-    @classmethod
-    def open(cls, db, name):
-        tbl = cls(db, name)
-        fs, path = fs_from_uri(tbl._dataset_uri)
-        file_info = fs.get_file_info(path)
-        if file_info.type != pa.fs.FileType.Directory:
-            raise FileNotFoundError(
-                f"Table {name} does not exist."
-                f"Please first call db.create_table({name}, data)"
-            )
-        register_event("open_table")
-
-        return tbl
-
     def delete(self, where: str):
-        self._dataset.delete(where)
+        self._dataset_mut.delete(where)
 
     def update(
         self,
@@ -1295,8 +1430,7 @@ class LanceTable(Table):
         if values is not None:
             values_sql = {k: value_to_sql(v) for k, v in values.items()}
 
-        self.to_lance().update(values_sql, where)
-        self._reset_dataset()
+        self._dataset_mut.update(values_sql, where)
         register_event("update")
 
     def _execute_query(self, query: Query) -> pa.Table:
@@ -1333,7 +1467,7 @@ class LanceTable(Table):
         ds = self.to_lance()
         builder = ds.merge_insert(merge._on)
         if merge._when_matched_update_all:
-            builder.when_matched_update_all()
+            builder.when_matched_update_all(merge._when_matched_update_all_condition)
         if merge._when_not_matched_insert_all:
             builder.when_not_matched_insert_all()
         if merge._when_not_matched_by_source_delete:
@@ -1,9 +1,9 @@
 [project]
 name = "lancedb"
-version = "0.5.3"
+version = "0.5.4"
 dependencies = [
     "deprecation",
-    "pylance==0.9.12",
+    "pylance==0.9.15",
     "ratelimiter~=1.0",
     "retry>=0.9.2",
     "tqdm>=4.27.0",
@@ -48,7 +48,7 @@ classifiers = [
 repository = "https://github.com/lancedb/lancedb"
 
 [project.optional-dependencies]
-tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
+tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
 dev = ["ruff", "pre-commit"]
 docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
 clip = ["torch", "pillow", "open-clip"]
@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
     assert np.allclose(actual, expected)
 
 
+@pytest.mark.slow
 def test_embedding_function_rate_limit(tmp_path):
     def _get_schema_from_model(model):
         class Schema(LanceModel):
@@ -23,11 +23,6 @@ import lancedb
 from lancedb.embeddings import get_registry
 from lancedb.pydantic import LanceModel, Vector
 
-try:
-    if importlib.util.find_spec("mlx.core") is not None:
-        _mlx = True
-except ImportError:
-    _mlx = None
 # These are integration tests for embedding functions.
 # They are slow because they require downloading models
 # or connection to external api
@@ -210,6 +205,13 @@ def test_gemini_embedding(tmp_path):
     assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
 
 
+try:
+    if importlib.util.find_spec("mlx.core") is not None:
+        _mlx = True
+except ImportError:
+    _mlx = None
+
+
 @pytest.mark.skipif(
     _mlx is None,
     reason="mlx tests only required for apple users.",
@@ -266,3 +268,49 @@ def test_bedrock_embedding(tmp_path):
 
     tbl.add(df)
     assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(
+    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
+)
+def test_openai_embedding(tmp_path):
+    def _get_table(model):
+        class TextModel(LanceModel):
+            text: str = model.SourceField()
+            vector: Vector(model.ndims()) = model.VectorField()
+
+        db = lancedb.connect(tmp_path)
+        tbl = db.create_table("test", schema=TextModel, mode="overwrite")
+
+        return tbl
+
+    model = get_registry().get("openai").create(max_retries=0)
+    tbl = _get_table(model)
+    df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
+
+    tbl.add(df)
+    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
+    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
+
+    model = (
+        get_registry()
+        .get("openai")
+        .create(max_retries=0, name="text-embedding-3-large")
+    )
+    tbl = _get_table(model)
+
+    tbl.add(df)
+    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
+    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
+
+    model = (
+        get_registry()
+        .get("openai")
+        .create(max_retries=0, name="text-embedding-3-large", dim=1024)
+    )
+    tbl = _get_table(model)
+
+    tbl.add(df)
+    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
+    assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@@ -7,7 +7,12 @@ import lancedb
 from lancedb.conftest import MockTextEmbeddingFunction  # noqa
 from lancedb.embeddings import EmbeddingFunctionRegistry
 from lancedb.pydantic import LanceModel, Vector
-from lancedb.rerankers import CohereReranker, CrossEncoderReranker
+from lancedb.rerankers import (
+    CohereReranker,
+    ColbertReranker,
+    CrossEncoderReranker,
+    OpenaiReranker,
+)
 from lancedb.table import LanceTable
 
 
@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
     return table, MyTable
 
 
-## These tests are pretty loose, we should also check for correctness
 def test_linear_combination(tmp_path):
     table, schema = get_test_table(tmp_path)
     # The default reranker
@@ -95,14 +99,19 @@ def test_linear_combination(tmp_path):
 
     assert result1 == result3  # 2 & 3 should be the same as they use score as score
 
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query))
+        .limit(30)
         .rerank(normalize="score")
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="rank", reranker=CohereReranker())
+        .rerank(reranker=CohereReranker())
         .to_pydantic(schema)
     )
     assert result1 == result2
 
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query))
+        .limit(30)
         .rerank(reranker=CohereReranker())
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
        "represents the relevance of the result to the query & should "
        "be descending."
    )
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
     )
     result2 = (
         table.search("Our father who art in heaven", query_type="hybrid")
-        .rerank(normalize="rank", reranker=CrossEncoderReranker())
+        .rerank(reranker=CrossEncoderReranker())
         .to_pydantic(schema)
     )
     assert result1 == result2
 
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
     result = (
-        table.search("Our father who art in heaven", query_type="hybrid")
-        .limit(50)
+        table.search((query_vector, query), query_type="hybrid")
+        .limit(30)
         .rerank(reranker=CrossEncoderReranker())
         .to_arrow()
     )
+
+    assert len(result) == 30
+
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
-        "The _score column of the results returned by the reranker "
+        "The _relevance_score column of the results returned by the reranker "
+        "represents the relevance of the result to the query & should "
+        "be descending."
+    )
+
+
+def test_colbert_reranker(tmp_path):
+    pytest.importorskip("transformers")
+    table, schema = get_test_table(tmp_path)
+    result1 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(normalize="score", reranker=ColbertReranker())
+        .to_pydantic(schema)
+    )
+    result2 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(reranker=ColbertReranker())
+        .to_pydantic(schema)
+    )
+    assert result1 == result2
+
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
+    result = (
+        table.search((query_vector, query))
+        .limit(30)
+        .rerank(reranker=ColbertReranker())
+        .to_arrow()
+    )
+
+    assert len(result) == 30
+
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        "The _relevance_score column of the results returned by the reranker "
+        "represents the relevance of the result to the query & should "
+        "be descending."
+    )
+
+
+@pytest.mark.skipif(
+    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
+)
+def test_openai_reranker(tmp_path):
+    pytest.importorskip("openai")
+    table, schema = get_test_table(tmp_path)
+    result1 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(normalize="score", reranker=OpenaiReranker())
+        .to_pydantic(schema)
+    )
+    result2 = (
+        table.search("Our father who art in heaven", query_type="hybrid")
+        .rerank(reranker=OpenaiReranker())
+        .to_pydantic(schema)
+    )
+    assert result1 == result2
+
+    # test explicit hybrid query
+    query = "Our father who art in heaven"
+    query_vector = table.to_pandas()["vector"][0]
+    result = (
+        table.search((query_vector, query))
+        .limit(30)
+        .rerank(reranker=OpenaiReranker())
+        .to_arrow()
+    )
+
+    assert len(result) == 30
+
+    assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
+        "The _relevance_score column of the results returned by the reranker "
         "represents the relevance of the result to the query & should "
         "be descending."
     )
@@ -12,8 +12,10 @@
 # limitations under the License.
 
 import functools
+from copy import copy
 from datetime import date, datetime, timedelta
 from pathlib import Path
+from time import sleep
 from typing import List
 from unittest.mock import PropertyMock, patch
 
@@ -25,6 +27,7 @@ import pyarrow as pa
 import pytest
 from pydantic import BaseModel
 
+import lancedb
 from lancedb.conftest import MockTextEmbeddingFunction
 from lancedb.db import LanceDBConnection
 from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
 class MockDB:
     def __init__(self, uri: Path):
         self.uri = uri
+        self.read_consistency_interval = None
 
     @functools.cached_property
     def is_managed_remote(self) -> bool:
@@ -267,39 +271,38 @@ def test_versioning(db):
 
 
 def test_create_index_method():
-    with patch.object(LanceTable, "_reset_dataset", return_value=None):
-        with patch.object(
-            LanceTable, "_dataset", new_callable=PropertyMock
-        ) as mock_dataset:
+    with patch.object(
+        LanceTable, "_dataset_mut", new_callable=PropertyMock
+    ) as mock_dataset:
         # Setup mock responses
         mock_dataset.return_value.create_index.return_value = None
 
         # Create a LanceTable object
         connection = LanceDBConnection(uri="mock.uri")
         table = LanceTable(connection, "test_table")
 
         # Call the create_index method
         table.create_index(
             metric="L2",
             num_partitions=256,
             num_sub_vectors=96,
             vector_column_name="vector",
             replace=True,
             index_cache_size=256,
         )
 
         # Check that the _dataset.create_index method was called
         # with the right parameters
         mock_dataset.return_value.create_index.assert_called_once_with(
             column="vector",
             index_type="IVF_PQ",
             metric="L2",
             num_partitions=256,
             num_sub_vectors=96,
             replace=True,
             accelerator=None,
             index_cache_size=256,
         )
 
 
 def test_add_with_nans(db):
@@ -510,8 +513,15 @@ def test_merge_insert(db):
     ).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
 
     expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
-    # These `sort_by` calls can be removed once lance#1892
-    # is merged (it fixes the ordering)
+    assert table.to_arrow().sort_by("a") == expected
+
+    table.restore(version)
+
+    # conditional update
+    table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
+        new_data
+    )
+    expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
     assert table.to_arrow().sort_by("a") == expected
 
     table.restore(version)
@@ -792,3 +802,48 @@ def test_hybrid_search(db):
         "Our father who art in heaven", query_type="hybrid"
     ).to_pydantic(MyTable)
     assert result1 == result3
+
+
+@pytest.mark.parametrize(
+    "consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
+)
+def test_consistency(tmp_path, consistency_interval):
+    db = lancedb.connect(tmp_path)
+    table = LanceTable.create(db, "my_table", data=[{"id": 0}])
+
+    db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
+    table2 = db2.open_table("my_table")
+    assert table2.version == table.version
+
+    table.add([{"id": 1}])
+
+    if consistency_interval is None:
+        assert table2.version == table.version - 1
+        table2.checkout_latest()
+        assert table2.version == table.version
+    elif consistency_interval == timedelta(seconds=0):
+        assert table2.version == table.version
+    else:
+        # (consistency_interval == timedelta(seconds=0.1)
+        assert table2.version == table.version - 1
+        sleep(0.1)
+        assert table2.version == table.version
+
+
+def test_restore_consistency(tmp_path):
+    db = lancedb.connect(tmp_path)
+    table = LanceTable.create(db, "my_table", data=[{"id": 0}])
+
+    db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
+    table2 = db2.open_table("my_table")
+    assert table2.version == table.version
+
+    # If we call checkout, it should lose consistency
+    table_fixed = copy(table2)
+    table_fixed.checkout(table.version)
+    # But if we call checkout_latest, it should be consistent again
+    table_ref_latest = copy(table_fixed)
+    table_ref_latest.checkout_latest()
+    table.add([{"id": 2}])
+    assert table_fixed.version == table.version - 1
+    assert table_ref_latest.version == table.version
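Conditional-upsert sketch based on the merge_insert change exercised in the test above (not part of the diff; table name and data are illustrative):

import lancedb
import pyarrow as pa

db = lancedb.connect("/tmp/lancedb")
table = db.create_table(
    "mi_demo", data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
)
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})

# Only matched rows whose current value of b is 'b' are updated; nothing is inserted.
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(new_data)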
@@ -2,8 +2,11 @@
 name = "vectordb-node"
 version = "0.4.8"
 description = "Serverless, low-latency vector database for AI applications"
-license = "Apache-2.0"
-edition = "2018"
+license.workspace = true
+edition.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
 exclude = ["index.node"]
 
 [lib]
@@ -22,7 +22,7 @@ use arrow_schema::SchemaRef;
 
 use crate::error::Result;
 
-pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
+pub fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
     let mut batches: Vec<RecordBatch> = Vec::new();
     let file_reader = FileReader::try_new(Cursor::new(slice), None)?;
     let schema = file_reader.schema();
@@ -33,7 +33,7 @@ pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBa
     Ok((batches, schema))
 }
 
-pub(crate) fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
+pub fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
     if batches.is_empty() {
         return Ok(Vec::new());
     }
@@ -17,7 +17,7 @@ use neon::types::buffer::TypedArray;
 
 use crate::error::ResultExt;
 
-pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
+pub fn vec_str_to_array<'a, C: Context<'a>>(
     vec: &Vec<String>,
     cx: &mut C,
 ) -> JsResult<'a, JsArray> {
@@ -29,7 +29,7 @@ pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
     Ok(a)
 }
 
-pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
+pub fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
     let mut query_vec: Vec<f32> = Vec::new();
     for i in 0..array.len(cx) {
         let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
@@ -39,7 +39,7 @@ pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<
 }
 
 // Creates a new JsBuffer from a rust buffer with a special logic for electron
-pub(crate) fn new_js_buffer<'a>(
+pub fn new_js_buffer<'a>(
     buffer: Vec<u8>,
     cx: &mut TaskContext<'a>,
     is_electron: bool,
@@ -18,7 +18,6 @@ use neon::prelude::NeonResult;
 use snafu::Snafu;
 
 #[derive(Debug, Snafu)]
-#[snafu(visibility(pub(crate)))]
 pub enum Error {
     #[snafu(display("column '{name}' is missing"))]
     MissingColumn { name: String },
@@ -21,7 +21,7 @@ use neon::{
 use crate::{error::ResultExt, runtime, table::JsTable};
 use vectordb::Table;
 
-pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
+pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
     let column = cx.argument::<JsString>(0)?.value(&mut cx);
     let replace = cx.argument::<JsBoolean>(1)?.value(&mut cx);
@@ -24,7 +24,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
 use crate::runtime;
 use crate::table::JsTable;
 
-pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
+pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
     let index_params = cx.argument::<JsObject>(0)?;
 
@@ -13,7 +13,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
 use crate::table::JsTable;
 use crate::{convert, runtime};
 
-pub(crate) struct JsQuery {}
+pub struct JsQuery {}
 
 impl JsQuery {
     pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
@@ -28,7 +28,7 @@ use vectordb::TableRef;
 use crate::error::ResultExt;
 use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};

-pub(crate) struct JsTable {
+pub struct JsTable {
     pub table: TableRef,
 }

@@ -36,7 +36,7 @@ impl Finalize for JsTable {}

 impl From<TableRef> for JsTable {
     fn from(table: TableRef) -> Self {
-        JsTable { table }
+        Self { table }
     }
 }

@@ -85,14 +85,14 @@ impl JsTable {

             deferred.settle_with(&channel, move |mut cx| {
                 let table = table_rst.or_throw(&mut cx)?;
-                Ok(cx.boxed(JsTable::from(table)))
+                Ok(cx.boxed(Self::from(table)))
             });
         });
         Ok(promise)
     }

     pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
         let buffer = cx.argument::<JsBuffer>(0)?;
         let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
         let (batches, schema) =
@@ -125,21 +125,34 @@ impl JsTable {

             deferred.settle_with(&channel, move |mut cx| {
                 add_result.or_throw(&mut cx)?;
-                Ok(cx.boxed(JsTable::from(table)))
+                Ok(cx.boxed(Self::from(table)))
             });
         });
         Ok(promise)
     }

     pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
+        let filter = cx
+            .argument_opt(0)
+            .and_then(|filt| {
+                if filt.is_a::<JsUndefined, _>(&mut cx) || filt.is_a::<JsNull, _>(&mut cx) {
+                    None
+                } else {
+                    Some(
+                        filt.downcast_or_throw::<JsString, _>(&mut cx)
+                            .map(|js_filt| js_filt.deref().value(&mut cx)),
+                    )
+                }
+            })
+            .transpose()?;
         let rt = runtime(&mut cx)?;
         let (deferred, promise) = cx.promise();
         let channel = cx.channel();
         let table = js_table.table.clone();

         rt.spawn(async move {
-            let num_rows_result = table.count_rows().await;
+            let num_rows_result = table.count_rows(filter).await;

             deferred.settle_with(&channel, move |mut cx| {
                 let num_rows = num_rows_result.or_throw(&mut cx)?;
@@ -150,7 +163,7 @@ impl JsTable {
     }

     pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
         let rt = runtime(&mut cx)?;
         let (deferred, promise) = cx.promise();
         let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -162,14 +175,14 @@ impl JsTable {

             deferred.settle_with(&channel, move |mut cx| {
                 delete_result.or_throw(&mut cx)?;
-                Ok(cx.boxed(JsTable::from(table)))
+                Ok(cx.boxed(Self::from(table)))
             })
         });
         Ok(promise)
     }

     pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
         let rt = runtime(&mut cx)?;
         let (deferred, promise) = cx.promise();
         let channel = cx.channel();
@@ -178,28 +191,34 @@ impl JsTable {
         let key = cx.argument::<JsString>(0)?.value(&mut cx);
         let mut builder = table.merge_insert(&[&key]);
         if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
-            builder.when_matched_update_all();
-        }
-        if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
-            builder.when_not_matched_insert_all();
+            let filter = cx.argument_opt(2).unwrap();
+            if filter.is_a::<JsNull, _>(&mut cx) {
+                builder.when_matched_update_all(None);
+            } else {
+                let filter = filter
+                    .downcast_or_throw::<JsString, _>(&mut cx)?
+                    .deref()
+                    .value(&mut cx);
+                builder.when_matched_update_all(Some(filter));
+            }
         }
         if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
-            if let Some(filter) = cx.argument_opt(4) {
-                if filter.is_a::<JsNull, _>(&mut cx) {
-                    builder.when_not_matched_by_source_delete(None);
-                } else {
-                    let filter = filter
-                        .downcast_or_throw::<JsString, _>(&mut cx)?
-                        .deref()
-                        .value(&mut cx);
-                    builder.when_not_matched_by_source_delete(Some(filter));
-                }
-            } else {
+            builder.when_not_matched_insert_all();
+        }
+        if cx.argument::<JsBoolean>(4)?.value(&mut cx) {
+            let filter = cx.argument_opt(5).unwrap();
+            if filter.is_a::<JsNull, _>(&mut cx) {
                 builder.when_not_matched_by_source_delete(None);
+            } else {
+                let filter = filter
+                    .downcast_or_throw::<JsString, _>(&mut cx)?
+                    .deref()
+                    .value(&mut cx);
+                builder.when_not_matched_by_source_delete(Some(filter));
             }
         }

-        let buffer = cx.argument::<JsBuffer>(5)?;
+        let buffer = cx.argument::<JsBuffer>(6)?;
         let (batches, schema) =
             arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;

@@ -209,14 +228,14 @@ impl JsTable {

             deferred.settle_with(&channel, move |mut cx| {
                 merge_insert_result.or_throw(&mut cx)?;
-                Ok(cx.boxed(JsTable::from(table)))
+                Ok(cx.boxed(Self::from(table)))
             })
         });
         Ok(promise)
     }

     pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
         let table = js_table.table.clone();

         let rt = runtime(&mut cx)?;
@@ -275,7 +294,7 @@ impl JsTable {
                 .await;
             deferred.settle_with(&channel, move |mut cx| {
                 update_result.or_throw(&mut cx)?;
-                Ok(cx.boxed(JsTable::from(table)))
+                Ok(cx.boxed(Self::from(table)))
             })
         });

@@ -283,7 +302,7 @@ impl JsTable {
     }

     pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
         let rt = runtime(&mut cx)?;
         let (deferred, promise) = cx.promise();
         let table = js_table.table.clone();
@@ -321,7 +340,7 @@ impl JsTable {
             let old_versions = cx.number(prune_stats.old_versions as f64);
             output_metrics.set(&mut cx, "oldVersions", old_versions)?;

-            let output_table = cx.boxed(JsTable::from(table));
+            let output_table = cx.boxed(Self::from(table));

             let output = JsObject::new(&mut cx);
             output.set(&mut cx, "metrics", output_metrics)?;
@@ -334,7 +353,7 @@ impl JsTable {
     }

     pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
         let rt = runtime(&mut cx)?;
         let (deferred, promise) = cx.promise();
         let table = js_table.table.clone();
@@ -393,7 +412,7 @@ impl JsTable {
             let files_added = cx.number(stats.files_added as f64);
             output_metrics.set(&mut cx, "filesAdded", files_added)?;

-            let output_table = cx.boxed(JsTable::from(table));
+            let output_table = cx.boxed(Self::from(table));

             let output = JsObject::new(&mut cx);
             output.set(&mut cx, "metrics", output_metrics)?;
@@ -406,7 +425,7 @@ impl JsTable {
     }

     pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
        let rt = runtime(&mut cx)?;
        let (deferred, promise) = cx.promise();
        // let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -445,7 +464,7 @@ impl JsTable {
     }

     pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
         let rt = runtime(&mut cx)?;
         let (deferred, promise) = cx.promise();
         let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -493,7 +512,7 @@ impl JsTable {
     }

     pub(crate) fn js_schema(mut cx: FunctionContext) -> JsResult<JsPromise> {
-        let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+        let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
         let rt = runtime(&mut cx)?;
         let (deferred, promise) = cx.promise();
         let channel = cx.channel();
@@ -1,12 +1,12 @@
 [package]
 name = "vectordb"
 version = "0.4.8"
-edition = "2021"
+edition.workspace = true
 description = "LanceDB: A serverless, low-latency vector database for AI applications"
-license = "Apache-2.0"
-repository = "https://github.com/lancedb/lancedb"
-keywords = ["lancedb", "lance", "database", "search"]
-categories = ["database-implementations"]
+license.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 [dependencies]
@@ -188,12 +188,12 @@ impl Database {
     /// # Returns
     ///
     /// * A [Database] object.
-    pub async fn connect(uri: &str) -> Result<Database> {
+    pub async fn connect(uri: &str) -> Result<Self> {
         let options = ConnectOptions::new(uri);
         Self::connect_with_options(&options).await
     }

-    pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
+    pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> {
         let uri = &options.uri;
         let parse_res = url::Url::parse(uri);

@@ -276,7 +276,7 @@ impl Database {
             None => None,
         };

-        Ok(Database {
+        Ok(Self {
             uri: table_base_uri,
             query_string,
             base_path,
@@ -288,7 +288,7 @@ impl Database {
         }
     }

-    async fn open_path(path: &str) -> Result<Database> {
+    async fn open_path(path: &str) -> Result<Self> {
         let (object_store, base_path) = ObjectStore::from_uri(path).await?;
         if object_store.is_local() {
             Self::try_create_dir(path).context(CreateDirSnafu { path })?;
@@ -69,7 +69,7 @@ pub struct IndexBuilder {

 impl IndexBuilder {
     pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
-        IndexBuilder {
+        Self {
             table,
             columns: columns.iter().map(|c| c.to_string()).collect(),
             name: None,
@@ -197,7 +197,7 @@ impl IndexBuilder {
         let num_partitions = if let Some(n) = self.num_partitions {
             n
         } else {
-            suggested_num_partitions(self.table.count_rows().await?)
+            suggested_num_partitions(self.table.count_rows(None).await?)
         };
         let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
             n
@@ -23,13 +23,13 @@ pub struct VectorIndex {
 }

 impl VectorIndex {
-    pub fn new_from_format(manifest: &Manifest, index: &Index) -> VectorIndex {
+    pub fn new_from_format(manifest: &Manifest, index: &Index) -> Self {
         let fields = index
             .fields
             .iter()
             .map(|i| manifest.schema.fields[*i as usize].name.clone())
             .collect();
-        VectorIndex {
+        Self {
             columns: fields,
             index_name: index.name.clone(),
             index_uuid: index.uuid.to_string(),
@@ -372,7 +372,7 @@ mod test {
         // leave this here for easy debugging
         let t = res.unwrap();

-        assert_eq!(t.count_rows().await.unwrap(), 100);
+        assert_eq!(t.count_rows(None).await.unwrap(), 100);

         let q = t
             .search(&[0.1, 0.1, 0.1, 0.1])
@@ -62,7 +62,7 @@ impl Query {
     /// * `dataset` - Lance dataset.
     ///
     pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
-        Query {
+        Self {
             dataset,
             query_vector: None,
             column: None,
@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
     compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
 };
 pub use lance::dataset::ReadParams;
-use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
+use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
 use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
 use lance::io::WrappingObjectStore;
 use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -102,7 +102,11 @@ pub trait Table: std::fmt::Display + Send + Sync {
     fn schema(&self) -> SchemaRef;

     /// Count the number of rows in this dataset.
-    async fn count_rows(&self) -> Result<usize>;
+    ///
+    /// # Arguments
+    ///
+    /// * `filter` if present, only count rows matching the filter
+    async fn count_rows(&self, filter: Option<String>) -> Result<usize>;

     /// Insert new records into this Table
     ///
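For orientation, a minimal usage sketch of the widened trait method introduced by this hunk; the `table` handle and the filter string are illustrative assumptions, not part of the change:

    // Count every row, as before.
    let total = table.count_rows(None).await?;
    // Count only the rows that match an SQL filter (assumes an `id` column exists).
    let matching = table.count_rows(Some("id > 10".to_string())).await?;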
@@ -234,7 +238,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
     /// schema.clone());
     /// // Perform an upsert operation
     /// let mut merge_insert = tbl.merge_insert(&["id"]);
-    /// merge_insert.when_matched_update_all()
+    /// merge_insert.when_matched_update_all(None)
     ///     .when_not_matched_insert_all();
     /// merge_insert.execute(Box::new(new_data)).await.unwrap();
     /// # });
@@ -385,7 +389,7 @@ impl NativeTable {
                 message: e.to_string(),
             },
         })?;
-        Ok(NativeTable {
+        Ok(Self {
             name: name.to_string(),
             uri: uri.to_string(),
             dataset: Arc::new(Mutex::new(dataset)),
@@ -427,7 +431,7 @@ impl NativeTable {
                 message: e.to_string(),
             },
         })?;
-        Ok(NativeTable {
+        Ok(Self {
             name: name.to_string(),
             uri: uri.to_string(),
             dataset: Arc::new(Mutex::new(dataset)),
@@ -501,7 +505,7 @@ impl NativeTable {
                 message: e.to_string(),
             },
         })?;
-        Ok(NativeTable {
+        Ok(Self {
             name: name.to_string(),
             uri: uri.to_string(),
             dataset: Arc::new(Mutex::new(dataset)),
@@ -673,11 +677,14 @@ impl MergeInsert for NativeTable {
     ) -> Result<()> {
         let dataset = Arc::new(self.clone_inner_dataset());
         let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
-        if params.when_matched_update_all {
-            builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
-        } else {
-            builder.when_matched(lance::dataset::WhenMatched::DoNothing);
-        }
+        match (
+            params.when_matched_update_all,
+            params.when_matched_update_all_filt,
+        ) {
+            (false, _) => builder.when_matched(WhenMatched::DoNothing),
+            (true, None) => builder.when_matched(WhenMatched::UpdateAll),
+            (true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
+        };
         if params.when_not_matched_insert_all {
             builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
         } else {
@@ -719,9 +726,15 @@ impl Table for NativeTable {
         Arc::new(Schema::from(&lance_schema))
     }

-    async fn count_rows(&self) -> Result<usize> {
+    async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
         let dataset = { self.dataset.lock().expect("lock poison").clone() };
-        Ok(dataset.count_rows().await?)
+        if let Some(filter) = filter {
+            let mut scanner = dataset.scan();
+            scanner.filter(&filter)?;
+            Ok(scanner.count_rows().await? as usize)
+        } else {
+            Ok(dataset.count_rows().await?)
+        }
     }

     async fn add(
@@ -814,6 +827,7 @@ impl Table for NativeTable {

 #[cfg(test)]
 mod tests {
+    use std::iter;
     use std::sync::atomic::{AtomicBool, Ordering};
     use std::sync::Arc;

@@ -886,6 +900,23 @@ mod tests {
         ));
     }

+    #[tokio::test]
+    async fn test_count_rows() {
+        let tmp_dir = tempdir().unwrap();
+        let uri = tmp_dir.path().to_str().unwrap();
+
+        let batches = make_test_batches();
+        let table = NativeTable::create(&uri, "test", batches, None, None)
+            .await
+            .unwrap();
+
+        assert_eq!(table.count_rows(None).await.unwrap(), 10);
+        assert_eq!(
+            table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
+            5
+        );
+    }
+
     #[tokio::test]
     async fn test_add() {
         let tmp_dir = tempdir().unwrap();
@@ -896,7 +927,7 @@ mod tests {
         let table = NativeTable::create(&uri, "test", batches, None, None)
             .await
             .unwrap();
-        assert_eq!(table.count_rows().await.unwrap(), 10);
+        assert_eq!(table.count_rows(None).await.unwrap(), 10);

         let new_batches = RecordBatchIterator::new(
             vec![RecordBatch::try_new(
@@ -910,7 +941,7 @@ mod tests {
         );

         table.add(Box::new(new_batches), None).await.unwrap();
-        assert_eq!(table.count_rows().await.unwrap(), 20);
+        assert_eq!(table.count_rows(None).await.unwrap(), 20);
         assert_eq!(table.name, "test");
     }

@@ -920,30 +951,44 @@ mod tests {
         let uri = tmp_dir.path().to_str().unwrap();

         // Create a dataset with i=0..10
-        let batches = make_test_batches_with_offset(0);
+        let batches = merge_insert_test_batches(0, 0);
         let table = NativeTable::create(&uri, "test", batches, None, None)
             .await
             .unwrap();
-        assert_eq!(table.count_rows().await.unwrap(), 10);
+        assert_eq!(table.count_rows(None).await.unwrap(), 10);

         // Create new data with i=5..15
-        let new_batches = Box::new(make_test_batches_with_offset(5));
+        let new_batches = Box::new(merge_insert_test_batches(5, 1));

         // Perform a "insert if not exists"
         let mut merge_insert_builder = table.merge_insert(&["i"]);
         merge_insert_builder.when_not_matched_insert_all();
         merge_insert_builder.execute(new_batches).await.unwrap();
         // Only 5 rows should actually be inserted
-        assert_eq!(table.count_rows().await.unwrap(), 15);
+        assert_eq!(table.count_rows(None).await.unwrap(), 15);

         // Create new data with i=15..25 (no id matches)
-        let new_batches = Box::new(make_test_batches_with_offset(15));
+        let new_batches = Box::new(merge_insert_test_batches(15, 2));
         // Perform a "bulk update" (should not affect anything)
         let mut merge_insert_builder = table.merge_insert(&["i"]);
-        merge_insert_builder.when_matched_update_all();
+        merge_insert_builder.when_matched_update_all(None);
         merge_insert_builder.execute(new_batches).await.unwrap();
         // No new rows should have been inserted
-        assert_eq!(table.count_rows().await.unwrap(), 15);
+        assert_eq!(table.count_rows(None).await.unwrap(), 15);
+        assert_eq!(
+            table.count_rows(Some("age = 2".to_string())).await.unwrap(),
+            0
+        );
+
+        // Conditional update that only replaces the age=0 data
+        let new_batches = Box::new(merge_insert_test_batches(5, 3));
+        let mut merge_insert_builder = table.merge_insert(&["i"]);
+        merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string()));
+        merge_insert_builder.execute(new_batches).await.unwrap();
+        assert_eq!(
+            table.count_rows(Some("age = 3".to_string())).await.unwrap(),
+            5
+        );
     }

     #[tokio::test]
@@ -956,7 +1001,7 @@ mod tests {
         let table = NativeTable::create(uri, "test", batches, None, None)
             .await
             .unwrap();
-        assert_eq!(table.count_rows().await.unwrap(), 10);
+        assert_eq!(table.count_rows(None).await.unwrap(), 10);

         let new_batches = RecordBatchIterator::new(
             vec![RecordBatch::try_new(
@@ -975,7 +1020,7 @@ mod tests {
         };

         table.add(Box::new(new_batches), Some(param)).await.unwrap();
-        assert_eq!(table.count_rows().await.unwrap(), 10);
+        assert_eq!(table.count_rows(None).await.unwrap(), 10);
         assert_eq!(table.name, "test");
     }

@@ -1292,23 +1337,35 @@ mod tests {
         assert!(wrapper.called());
     }

-    fn make_test_batches_with_offset(
+    fn merge_insert_test_batches(
         offset: i32,
+        age: i32,
     ) -> impl RecordBatchReader + Send + Sync + 'static {
-        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("i", DataType::Int32, false),
+            Field::new("age", DataType::Int32, false),
+        ]));
         RecordBatchIterator::new(
             vec![RecordBatch::try_new(
                 schema.clone(),
-                vec![Arc::new(Int32Array::from_iter_values(
-                    offset..(offset + 10),
-                ))],
+                vec![
+                    Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
+                    Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
+                ],
             )],
             schema,
         )
     }

     fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
-        make_test_batches_with_offset(0)
+        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
+        RecordBatchIterator::new(
+            vec![RecordBatch::try_new(
+                schema.clone(),
+                vec![Arc::new(Int32Array::from_iter_values(0..10))],
+            )],
+            schema,
+        )
     }

     #[tokio::test]
@@ -1365,7 +1422,7 @@ mod tests {
             .unwrap();

         assert_eq!(table.load_indices().await.unwrap().len(), 1);
-        assert_eq!(table.count_rows().await.unwrap(), 512);
+        assert_eq!(table.count_rows(None).await.unwrap(), 512);
         assert_eq!(table.name, "test");

         let indices = table.load_indices().await.unwrap();
@@ -35,6 +35,7 @@ pub struct MergeInsertBuilder {
     table: Arc<dyn MergeInsert>,
     pub(super) on: Vec<String>,
     pub(super) when_matched_update_all: bool,
+    pub(super) when_matched_update_all_filt: Option<String>,
     pub(super) when_not_matched_insert_all: bool,
     pub(super) when_not_matched_by_source_delete: bool,
     pub(super) when_not_matched_by_source_delete_filt: Option<String>,
@@ -46,6 +47,7 @@ impl MergeInsertBuilder {
             table,
             on,
             when_matched_update_all: false,
+            when_matched_update_all_filt: None,
             when_not_matched_insert_all: false,
             when_not_matched_by_source_delete: false,
             when_not_matched_by_source_delete_filt: None,
@@ -59,8 +61,22 @@ impl MergeInsertBuilder {
     /// If there are multiple matches then the behavior is undefined.
     /// Currently this causes multiple copies of the row to be created
    /// but that behavior is subject to change.
-    pub fn when_matched_update_all(&mut self) -> &mut Self {
+    ///
+    /// An optional condition may be specified.  If it is, then only
+    /// matched rows that satisfy the condtion will be updated.  Any
+    /// rows that do not satisfy the condition will be left as they
+    /// are.  Failing to satisfy the condition does not cause a
+    /// "matched row" to become a "not matched" row.
+    ///
+    /// The condition should be an SQL string.  Use the prefix
+    /// target. to refer to rows in the target table (old data)
+    /// and the prefix source. to refer to rows in the source
+    /// table (new data).
+    ///
+    /// For example, "target.last_update < source.last_update"
+    pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
         self.when_matched_update_all = true;
+        self.when_matched_update_all_filt = condition;
         self
     }

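Taken together with the FFI and NativeTable changes above, the new `condition` parameter enables conditional upserts. A minimal sketch mirroring the doc example in this diff; the `tbl` handle and the `new_data` record batch reader are illustrative assumptions:

    // Update matched rows only when the source row is newer; insert everything else.
    let mut merge_insert = tbl.merge_insert(&["id"]);
    merge_insert
        .when_matched_update_all(Some("target.last_update < source.last_update".to_string()))
        .when_not_matched_insert_all();
    merge_insert.execute(Box::new(new_data)).await.unwrap();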