Compare commits

...

15 Commits

Author SHA1 Message Date
Lance Release
82936c77ef [python] Bump version: 0.5.3 → 0.5.4 2024-02-09 22:56:45 +00:00
Weston Pace
dddcddcaf9 chore: bump lance version to 0.9.15 (#949) 2024-02-09 14:55:44 -08:00
Weston Pace
a9727eb318 feat: add support for filter during merge insert when matched (#948)
Closes #940
2024-02-09 10:26:14 -08:00
QianZhu
48d55bf952 added error msg to SaaS APIs (#852)
1. improved error msg for SaaS create_table and create_index

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
2024-02-09 10:07:47 -08:00
Weston Pace
d2e71c8b08 feat: add a filterable count_rows to all the lancedb APIs (#913)
A `count_rows` method that takes a filter was recently added to
`LanceTable`. This PR adds it everywhere else except `RemoteTable` (that
will come soon).
2024-02-08 09:40:29 -08:00
Nitish Sharma
f53aace89c Minor updates to FAQ (#935)
Based on discussion over discord, adding minor updates to the FAQ
section about benchmarks, practical data size and concurrency in LanceDB
2024-02-07 20:49:25 -08:00
Ayush Chaurasia
d982ee934a feat(python): Reranker DX improvements (#904)
- Most users might not know how to use `QueryBuilder` object. Instead we
should just pass the string query.
- Add new rerankers: Colbert, openai
2024-02-06 13:59:31 +05:30
Will Jones
57605a2d86 feat(python): add read_consistency_interval argument (#828)
This PR refactors how we handle read consistency: does the `LanceTable`
class always pick up modifications to the table made by other instance
or processes. Users have three options they can set at the connection
level:

1. (Default) `read_consistency_interval=None` means it will not check at
all. Users can call `table.checkout_latest()` to manually check for
updates.
2. `read_consistency_interval=timedelta(0)` means **always** check for
updates, giving strong read consistency.
3. `read_consistency_interval=timedelta(seconds=20)` means check for
updates every 20 seconds. This is eventual consistency, a compromise
between the two options above.

## Table reference state

There is now an explicit difference between a `LanceTable` that tracks
the current version and one that is fixed at a historical version. We
now enforce that users cannot write if they have checked out an old
version. They are instructed to call `checkout_latest()` before calling
the write methods.

Since `conn.open_table()` doesn't have a parameter for version, users
will only get fixed references if they call `table.checkout()`.

The difference between these two can be seen in the repr: Table that are
fixed at a particular version will have a `version` displayed in the
repr. Otherwise, the version will not be shown.

```python
>>> table
LanceTable(connection=..., name="my_table")
>>> table.checkout(1)
>>> table
LanceTable(connection=..., name="my_table", version=1)
```

I decided to not create different classes for these states, because I
think we already have enough complexity with the Cloud vs OSS table
references.

Based on #812
2024-02-05 08:12:19 -08:00
Ayush Chaurasia
738511c5f2 feat(python): add support new openai embedding functions (#912)
@PrashantDixit0

---------

Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
2024-02-04 18:19:42 -08:00
Lei Xu
0b0f42537e chore: add global cargo config to enable minimal cpu target (#925)
* Closes #895 
* Fix cargo clippy
2024-02-04 14:21:27 -08:00
QianZhu
e412194008 fix hybrid search example (#922) 2024-02-03 09:26:32 +05:30
Lance Release
a9088224c5 [python] Bump version: 0.5.2 → 0.5.3 2024-02-03 03:04:04 +00:00
Ayush Chaurasia
688c57a0d8 fix: revert safe_import_pandas usage (#921) 2024-02-02 18:57:13 -08:00
Lance Release
12a98deded Updating package-lock.json 2024-02-02 22:37:23 +00:00
Lance Release
e4bb042918 Updating package-lock.json 2024-02-02 21:57:07 +00:00
57 changed files with 1247 additions and 361 deletions

34
.cargo/config.toml Normal file
View File

@@ -0,0 +1,34 @@
[profile.release]
lto = "fat"
codegen-units = 1
[profile.release-with-debug]
inherits = "release"
debug = true
# Prioritize compile time over runtime performance
codegen-units = 16
lto = "thin"
[target.'cfg(all())']
rustflags = [
"-Wclippy::all",
"-Wclippy::style",
"-Wclippy::fallible_impl_from",
"-Wclippy::manual_let_else",
"-Wclippy::redundant_pub_crate",
"-Wclippy::string_add_assign",
"-Wclippy::string_add",
"-Wclippy::string_lit_as_bytes",
"-Wclippy::string_to_string",
"-Wclippy::use_self",
"-Dclippy::cargo",
"-Dclippy::dbg_macro",
# not too much we can do to avoid multiple crate versions
"-Aclippy::multiple-crate-versions",
]
[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"]
[target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]

View File

@@ -49,6 +49,9 @@ jobs:
test-node:
name: Test doc nodejs code
runs-on: "ubuntu-latest"
timeout-minutes: 45
strategy:
fail-fast: false
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -66,6 +69,12 @@ jobs:
uses: swatinem/rust-cache@v2
- name: Install node dependencies
run: |
sudo swapoff -a
sudo fallocate -l 8G /swapfile
sudo chmod 600 /swapfile
sudo mkswap /swapfile
sudo swapon /swapfile
sudo swapon --show
cd node
npm ci
npm run build-release

View File

@@ -6,15 +6,18 @@ resolver = "2"
[workspace.package]
edition = "2021"
authors = ["Lance Devs <dev@lancedb.com>"]
authors = ["LanceDB Devs <dev@lancedb.com>"]
license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"
description = "Serverless, low-latency vector database for AI applications"
keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.12" }
lance-linalg = { "version" = "=0.9.12" }
lance-testing = { "version" = "=0.9.12" }
lance = { "version" = "=0.9.15", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.15" }
lance-linalg = { "version" = "=0.9.15" }
lance-testing = { "version" = "=0.9.15" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"

View File

@@ -69,3 +69,19 @@ MinIO supports an S3 compatible API. In order to connect to a MinIO instance, yo
- Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
- Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential
- Call `lancedb.connect("s3://minio_bucket_name")`
### Where can I find benchmarks for LanceDB?
Refer to this [post](https://blog.lancedb.com/benchmarking-lancedb-92b01032874a) for recent benchmarks.
### How much data can LanceDB practically manage without effecting performance?
We target good performance on ~10-50 billion rows and ~10-30 TB of data.
### Does LanceDB support concurrent operations?
LanceDB can handle concurrent reads very well, and can scale horizontally. The main constraint is how well the [storage layer](https://lancedb.github.io/lancedb/concepts/storage/) you've chosen scales. For writes, we support concurrent writing, though too many concurrent writers can lead to failing writes as there is a limited number of times a writer retries a commit
!!! info "Multiprocessing with LanceDB"
For multiprocessing you should probably not use ```fork``` as lance is multi-threaded internally and ```fork``` and multi-thread do not work well.[Refer to this discussion](https://discuss.python.org/t/concerns-regarding-deprecation-of-fork-with-alive-threads/33555)

View File

@@ -6,17 +6,24 @@ LanceDB supports both semantic and keyword-based search. In real world applicati
You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic .
```python
import os
import lancedb
import openai
from lancedb.embeddings import get_registry
from lancedb.pydanatic import LanceModel, Vector
from lancedb.pydantic import LanceModel, Vector
db = lancedb.connect("~/.lancedb")
# Ingest embedding function in LanceDB table
# Configuring the environment variable OPENAI_API_KEY
if "OPENAI_API_KEY" not in os.environ:
# OR set the key here as a variable
openai.api_key = "sk-..."
embeddings = get_registry().get("openai").create()
class Documents(LanceModel):
vector: Vector(embeddings.ndims) = embeddings.VectorField()
vector: Vector(embeddings.ndims()) = embeddings.VectorField()
text: str = embeddings.SourceField()
table = db.create_table("documents", schema=Documents)
@@ -31,17 +38,19 @@ data = [
# ingest docs with auto-vectorization
table.add(data)
# Create a fts index before the hybrid search
table.create_fts_index("text")
# hybrid search with default re-ranker
results = table.search("flower moon", query_type="hybrid").to_pandas()
```
By default, LanceDB uses `LinearCombinationReranker(weights=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
### `rerank()` arguments
* `normalize`: `str`, default `"score"`:
The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
* `reranker`: `Reranker`, default `LinearCombinationReranker(weights=0.7)`.
* `reranker`: `Reranker`, default `LinearCombinationReranker(weight=0.7)`.
The reranker to use. If not specified, the default reranker is used.
@@ -55,7 +64,7 @@ This is the default re-ranker used by LanceDB. It combines the results of semant
```python
from lancedb.rerankers import LinearCombinationReranker
reranker = LinearCombinationReranker(weights=0.3) # Use 0.3 as the weight for vector search
reranker = LinearCombinationReranker(weight=0.3) # Use 0.3 as the weight for vector search
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
@@ -121,6 +130,60 @@ Arguments
Only returns `_relevance_score`. Does not support `return_score = "all"`.
### ColBERT Reranker
This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertrReranker()` to the `rerank()` method.
ColBERT reranker model calculates relevance of given docs against the query and don't take existing fts and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for `text` column to rerank the results. But you can specify the column name to use as input to the cross encoder model as described below.
```python
from lancedb.rerankers import ColbertReranker
reranker = ColbertReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
----------------
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
The name of the cross encoder model to use.
* `column` : `str`, default `"text"`
The name of the column to use as input to the cross encoder model.
* `return_score` : `str`, default `"relevance"`
options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
!!! Note
Only returns `_relevance_score`. Does not support `return_score = "all"`.
### OpenAI Reranker
This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
!!! Note
This prompts chat model to rerank results which is not a dedicated reranker model. This should be treated as experimental.
!!! Tip
You might run out of token limit so set the search `limits` based on your token limit.
```python
from lancedb.rerankers import OpenaiReranker
reranker = OpenaiReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
Arguments
----------------
`model_name` : `str`, default `"gpt-3.5-turbo-1106"`
The name of the cross encoder model to use.
`column` : `str`, default `"text"`
The name of the column to use as input to the cross encoder model.
`return_score` : `str`, default `"relevance"`
options are "relevance" or "all". Only "relevance" is supported for now.
`api_key` : `str`, default `None`
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
## Building Custom Rerankers
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
@@ -137,7 +200,7 @@ class MyReranker(Reranker):
self.param1 = param1
self.param2 = param2
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
# Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results)
@@ -159,7 +222,7 @@ import pyarrow as pa
class MyReranker(Reranker):
...
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
# Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results)

44
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.4.7",
"version": "0.4.8",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.4.7",
"version": "0.4.8",
"cpu": [
"x64",
"arm64"
@@ -53,11 +53,11 @@
"uuid": "^9.0.0"
},
"optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.7",
"@lancedb/vectordb-darwin-x64": "0.4.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
"@lancedb/vectordb-linux-x64-gnu": "0.4.7",
"@lancedb/vectordb-win32-x64-msvc": "0.4.7"
"@lancedb/vectordb-darwin-arm64": "0.4.8",
"@lancedb/vectordb-darwin-x64": "0.4.8",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.8",
"@lancedb/vectordb-linux-x64-gnu": "0.4.8",
"@lancedb/vectordb-win32-x64-msvc": "0.4.8"
}
},
"node_modules/@75lb/deep-merge": {
@@ -329,9 +329,9 @@
}
},
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.7.tgz",
"integrity": "sha512-kACOIytgjBfX8NRwjPKe311XRN3lbSN13B7avT5htMd3kYm3AnnMag9tZhlwoO7lIuvGaXhy7mApygJrjhfJ4g==",
"version": "0.4.8",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.8.tgz",
"integrity": "sha512-FpnJaw7KmNdD/FtOw9AcmPL5P+L04AcnfPj9ZyEjN8iCwB/qaOGYgdfBv+EbEtfHIsqA12q/1BRduu9KdB6BIA==",
"cpu": [
"arm64"
],
@@ -341,9 +341,9 @@
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.7.tgz",
"integrity": "sha512-vb74iK5uPWCwz5E60r3yWp/R/HSg54/Z9AZWYckYXqsPv4w/nfbkM5iZhfRqqR/9uE6JClWJKOtjbk7b8CFRFg==",
"version": "0.4.8",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.8.tgz",
"integrity": "sha512-RafOEYyZIgphp8wPGuVLFaTc8aAqo0NCO1LQMx0mB0xV96vrdo0Mooivs+dYN3RFfSHtTKPw9O1Jc957Vp1TLg==",
"cpu": [
"x64"
],
@@ -353,9 +353,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.7.tgz",
"integrity": "sha512-jHp7THm6S9sB8RaCxGoZXLAwGAUHnawUUilB1K3mvQsRdfB2bBs0f7wDehW+PDhr+Iog4LshaWbcnoQEUJWR+Q==",
"version": "0.4.8",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.8.tgz",
"integrity": "sha512-WlbYNfj4+v1hBHUluF+hnlG/A0ZaQFdXBTGDfHQniL11o+n3emWm4ujP5nSAoQHXjSH9DaOTGr/N4Mc9Xe+luw==",
"cpu": [
"arm64"
],
@@ -365,9 +365,9 @@
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.7.tgz",
"integrity": "sha512-LKbVe6Wrp/AGqCCjKliNDmYoeTNgY/wfb2DTLjrx41Jko/04ywLrJ6xSEAn3XD5RDCO5u3fyUdXHHHv5a3VAAQ==",
"version": "0.4.8",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.8.tgz",
"integrity": "sha512-z+qFJrDqnNEv4JcwYDyt51PHmWjuM/XaOlSjpBnyyuUImeY+QcwctMuyXt8+Q4zhuqQR1AhLKrMwCU+YmMfk5g==",
"cpu": [
"x64"
],
@@ -377,9 +377,9 @@
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.7.tgz",
"integrity": "sha512-C5ln4+wafeY1Sm4PeV0Ios9lUaQVVip5Mjl9XU7ngioSEMEuXI/XMVfIdVfDPppVNXPeQxg33wLA272uw88D1Q==",
"version": "0.4.8",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.8.tgz",
"integrity": "sha512-VjUryVvEA04r0j4lU9pJy84cmjuQm1GhBzbPc8kwbn5voT4A6BPglrlNsU0Zc+j8Fbjyvauzw2lMEcMsF4F0rw==",
"cpu": [
"x64"
],

View File

@@ -372,7 +372,7 @@ export interface Table<T = number[]> {
/**
* Returns the number of rows in this table.
*/
countRows: () => Promise<number>
countRows: (filter?: string) => Promise<number>
/**
* Delete rows from this table.
@@ -525,8 +525,19 @@ export interface MergeInsertArgs {
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*
* Optionally, a filter can be specified. This should be an SQL
* filter where fields with the prefix "target." refer to fields
* in the target table (old data) and fields with the prefix
* "source." refer to fields in the source table (new data). For
* example, the filter "target.lastUpdated < source.lastUpdated" will
* only update matched rows when the incoming `lastUpdated` value is
* newer.
*
* Rows that do not match the filter will not be updated. Rows that
* do not match the filter do become "not matched" rows.
*/
whenMatchedUpdateAll?: boolean
whenMatchedUpdateAll?: string | boolean
/**
* If true then rows that exist only in the source table (new data)
* will be inserted into the target table.
@@ -840,8 +851,8 @@ export class LocalTable<T = number[]> implements Table<T> {
/**
* Returns the number of rows in this table.
*/
async countRows (): Promise<number> {
return tableCountRows.call(this._tbl)
async countRows (filter?: string): Promise<number> {
return tableCountRows.call(this._tbl, filter)
}
/**
@@ -885,7 +896,14 @@ export class LocalTable<T = number[]> implements Table<T> {
}
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
let whenMatchedUpdateAll = false
let whenMatchedUpdateAllFilt = null
if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
whenMatchedUpdateAll = true
if (args.whenMatchedUpdateAll !== true) {
whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
}
}
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
let whenNotMatchedBySourceDelete = false
let whenNotMatchedBySourceDeleteFilt = null
@@ -909,6 +927,7 @@ export class LocalTable<T = number[]> implements Table<T> {
this._tbl,
on,
whenMatchedUpdateAll,
whenMatchedUpdateAllFilt,
whenNotMatchedInsertAll,
whenNotMatchedBySourceDelete,
whenNotMatchedBySourceDeleteFilt,

View File

@@ -286,8 +286,11 @@ export class RemoteTable<T = number[]> implements Table<T> {
const queryParams: any = {
on
}
if (args.whenMatchedUpdateAll ?? false) {
if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
queryParams.when_matched_update_all = 'true'
if (typeof args.whenMatchedUpdateAll === 'string') {
queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
}
} else {
queryParams.when_matched_update_all = 'false'
}

View File

@@ -294,6 +294,7 @@ describe('LanceDB client', function () {
})
assert.equal(table.name, 'vectors')
assert.equal(await table.countRows(), 10)
assert.equal(await table.countRows('vector IS NULL'), 0)
assert.deepEqual(await con.tableNames(), ['vectors'])
})
@@ -369,6 +370,7 @@ describe('LanceDB client', function () {
const table = await con.createTable('f16', data)
assert.equal(table.name, 'f16')
assert.equal(await table.countRows(), total)
assert.equal(await table.countRows('id < 5'), 5)
assert.deepEqual(await con.tableNames(), ['f16'])
assert.deepEqual(await table.schema, schema)
@@ -538,26 +540,36 @@ describe('LanceDB client', function () {
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
const table = await con.createTable('my_table', data)
// insert if not exists
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true
})
assert.equal(await table.countRows(), 3)
assert.equal((await table.filter('age = 2').execute()).length, 1)
assert.equal(await table.countRows('age = 2'), 1)
newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
// conditional update
newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
await table.mergeInsert('id', newData, {
whenMatchedUpdateAll: 'target.age = 1'
})
assert.equal(await table.countRows(), 3)
assert.equal(await table.countRows('age = 1'), 1)
assert.equal(await table.countRows('age = 3'), 1)
newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true
})
assert.equal(await table.countRows(), 4)
assert.equal((await table.filter('age = 3').execute()).length, 2)
assert.equal((await table.filter('age = 4').execute()).length, 2)
newData = [{ id: 5, age: 4 }]
newData = [{ id: 5, age: 5 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: 'age < 3'
whenNotMatchedBySourceDelete: 'age < 4'
})
assert.equal(await table.countRows(), 3)

View File

@@ -1,9 +1,12 @@
[package]
name = "vectordb-nodejs"
edition = "2021"
edition.workspace = true
version = "0.0.0"
license.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
[lib]
crate-type = ["cdylib"]
@@ -14,15 +17,11 @@ futures.workspace = true
lance-linalg.workspace = true
lance.workspace = true
vectordb = { path = "../rust/vectordb" }
napi = { version = "2.14", default-features = false, features = [
napi = { version = "2.15", default-features = false, features = [
"napi7",
"async"
] }
napi-derive = "2.14"
napi-derive = "2"
[build-dependencies]
napi-build = "2.1"
[profile.release]
lto = true
strip = "symbols"

View File

@@ -2,4 +2,6 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
};
moduleDirectories: ["node_modules", "./dist"],
moduleFileExtensions: ["js", "ts"],
};

View File

@@ -57,8 +57,8 @@ impl Table {
}
#[napi]
pub async fn count_rows(&self) -> napi::Result<usize> {
self.table.count_rows().await.map_err(|e| {
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
self.table.count_rows(filter).await.map_err(|e| {
napi::Error::from_reason(format!(
"Failed to count rows in table {}: {}",
self.table, e

View File

@@ -73,7 +73,7 @@ export class Table {
/** Return Schema as empty Arrow IPC file. */
schema(): Buffer
add(buf: Buffer): Promise<void>
countRows(): Promise<bigint>
countRows(filter?: string): Promise<bigint>
delete(predicate: string): Promise<void>
createIndex(): IndexBuilder
query(): Query

View File

@@ -50,8 +50,8 @@ export class Table {
}
/** Count the total number of rows in the dataset. */
async countRows(): Promise<bigint> {
return await this.inner.countRows();
async countRows(filter?: string): Promise<bigint> {
return await this.inner.countRows(filter);
}
/** Delete the rows that satisfy the predicate. */

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.5.2
current_version = 0.5.4
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -42,6 +42,12 @@ To run the unit tests:
pytest
```
To run the doc tests:
```bash
pytest --doctest-modules lancedb
```
To run linter and automatically fix all errors:
```bash

View File

@@ -13,6 +13,7 @@
import importlib.metadata
import os
from datetime import timedelta
from typing import Optional
__version__ = importlib.metadata.version("lancedb")
@@ -30,6 +31,7 @@ def connect(
api_key: Optional[str] = None,
region: str = "us-east-1",
host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
) -> DBConnection:
"""Connect to a LanceDB database.
@@ -45,6 +47,18 @@ def connect(
The region to use for LanceDB Cloud.
host_override: str, optional
The override url for LanceDB Cloud.
read_consistency_interval: timedelta, default None
(For LanceDB OSS only)
The interval at which to check for updates to the table from other
processes. If None, then consistency is not checked. For performance
reasons, this is the default. For strong consistency, set this to
zero seconds. Then every read will check for updates from other
processes. As a compromise, you can set this to a non-zero timedelta
for eventual consistency. If more than that interval has passed since
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
Examples
--------
@@ -73,4 +87,4 @@ def connect(
if api_key is None:
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
return RemoteDBConnection(uri, api_key, region, host_override)
return LanceDBConnection(uri)
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)

View File

@@ -16,9 +16,9 @@ from typing import Iterable, List, Union
import numpy as np
import pyarrow as pa
from .util import safe_import
from .util import safe_import_pandas
pd = safe_import("pandas")
pd = safe_import_pandas()
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]

View File

@@ -16,9 +16,9 @@ import deprecation
from . import __version__
from .exceptions import MissingColumnError, MissingValueError
from .util import safe_import
from .util import safe_import_pandas
pd = safe_import("pandas")
pd = safe_import_pandas()
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:

View File

@@ -26,6 +26,8 @@ from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
if TYPE_CHECKING:
from datetime import timedelta
from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data)
LanceTable(my_table)
LanceTable(connection=..., name="my_table")
>>> db["my_table"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
... "long": [-122.7, -74.1]
... })
>>> db.create_table("table2", data)
LanceTable(table2)
LanceTable(connection=..., name="table2")
>>> db["table2"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
... pa.field("long", pa.float32())
... ])
>>> db.create_table("table3", data, schema = custom_schema)
LanceTable(table3)
LanceTable(connection=..., name="table3")
>>> db["table3"].head()
pyarrow.Table
vector: fixed_size_list<item: float>[2]
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
... pa.field("price", pa.float32()),
... ])
>>> db.create_table("table4", make_batches(), schema=schema)
LanceTable(table4)
LanceTable(connection=..., name="table4")
"""
raise NotImplementedError
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
----------
uri: str or Path
The root uri of the database.
read_consistency_interval: timedelta, default None
The interval at which to check for updates to the table from other
processes. If None, then consistency is not checked. For performance
reasons, this is the default. For strong consistency, set this to
zero seconds. Then every read will check for updates from other
processes. As a compromise, you can set this to a non-zero timedelta
for eventual consistency. If more than that interval has passed since
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
Examples
--------
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
>>> db = lancedb.connect("./.lancedb")
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
... {"vector": [0.5, 1.3], "b": 4}])
LanceTable(my_table)
LanceTable(connection=..., name="my_table")
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
LanceTable(another_table)
LanceTable(connection=..., name="another_table")
>>> sorted(db.table_names())
['another_table', 'my_table']
>>> len(db)
2
>>> db["my_table"]
LanceTable(my_table)
LanceTable(connection=..., name="my_table")
>>> "my_table" in db
True
>>> db.drop_table("my_table")
>>> db.drop_table("another_table")
"""
def __init__(self, uri: URI):
def __init__(
self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
):
if not isinstance(uri, Path):
scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file"
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
self._uri = str(uri)
self._entered = False
self.read_consistency_interval = read_consistency_interval
def __repr__(self) -> str:
val = f"{self.__class__.__name__}({self._uri}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
return val
@property
def uri(self) -> str:

View File

@@ -12,7 +12,7 @@
# limitations under the License.
import os
from functools import cached_property
from typing import List, Union
from typing import List, Optional, Union
import numpy as np
@@ -30,10 +30,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
"""
name: str = "text-embedding-ada-002"
dim: Optional[int] = None
def ndims(self):
# TODO don't hardcode this
return 1536
return self._ndims
@cached_property
def _ndims(self):
if self.name == "text-embedding-ada-002":
return 1536
elif self.name == "text-embedding-3-large":
return self.dim or 3072
elif self.name == "text-embedding-3-small":
return self.dim or 1536
else:
raise ValueError(f"Unknown model name {self.name}")
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
@@ -47,7 +58,12 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
The texts to embed
"""
# TODO retry, rate limit, token limit
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
if self.name == "text-embedding-ada-002":
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
else:
rs = self._openai_client.embeddings.create(
input=texts, model=self.name, dimensions=self.ndims()
)
return [v.embedding for v in rs.data]
@cached_property

View File

@@ -26,10 +26,10 @@ import pyarrow as pa
from lance.vector import vec_to_table
from retry import retry
from ..util import safe_import
from ..util import safe_import_pandas
from ..utils.general import LOGGER
pd = safe_import("pandas")
pd = safe_import_pandas()
DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]

View File

@@ -32,11 +32,14 @@ class LanceMergeInsertBuilder(object):
self._table = table
self._on = on
self._when_matched_update_all = False
self._when_matched_update_all_condition = None
self._when_not_matched_insert_all = False
self._when_not_matched_by_source_delete = False
self._when_not_matched_by_source_condition = None
def when_matched_update_all(self) -> LanceMergeInsertBuilder:
def when_matched_update_all(
self, *, where: Optional[str] = None
) -> LanceMergeInsertBuilder:
"""
Rows that exist in both the source table (new data) and
the target table (old data) will be updated, replacing
@@ -47,6 +50,7 @@ class LanceMergeInsertBuilder(object):
but that behavior is subject to change.
"""
self._when_matched_update_all = True
self._when_matched_update_all_condition = where
return self
def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:

View File

@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
... name: str
... vector: Vector(2)
...
>>> db = lancedb.connect("/tmp")
>>> db = lancedb.connect("./example")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
>>> table.add([
... TestModel(name="test", vector=[1.0, 2.0])

View File

@@ -27,7 +27,7 @@ from . import __version__
from .common import VEC, VECTOR_COLUMN_NAME
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
from .util import safe_import
from .util import safe_import_pandas
if TYPE_CHECKING:
import PIL
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
from .pydantic import LanceModel
from .table import Table
pd = safe_import("pandas")
pd = safe_import_pandas()
class Query(pydantic.BaseModel):
@@ -626,7 +626,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str):
super().__init__(table)
self._validate_fts_index()
self._query = query
vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -679,12 +678,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
# rerankers might need to preserve this score to support `return_score="all"`
fts_results = self._normalize_scores(fts_results, "score")
results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
results = self._reranker.rerank_hybrid(
self._fts_query._query, vector_results, fts_results
)
if not isinstance(results, pa.Table): # Enforce type
raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
)
# apply limit after reranking
results = results.slice(length=self._limit)
if not self._with_row_id:
results = results.drop(["_rowid"])
return results
@@ -776,6 +781,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
"""
self._vector_query.limit(limit)
self._fts_query.limit(limit)
self._limit = limit
return self
def select(self, columns: list) -> LanceHybridQueryBuilder:

View File

@@ -118,6 +118,7 @@ class RemoteDBConnection(DBConnection):
schema: Optional[Union[pa.Schema, LanceModel]] = None,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
mode: Optional[str] = None,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table:
"""Create a [Table][lancedb.table.Table] in the database.
@@ -215,11 +216,13 @@ class RemoteDBConnection(DBConnection):
if data is None and schema is None:
raise ValueError("Either data or schema must be provided.")
if embedding_functions is not None:
raise NotImplementedError(
"embedding_functions is not supported for remote databases."
logging.warning(
"embedding_functions is not yet supported on LanceDB Cloud."
"Please vote https://github.com/lancedb/lancedb/issues/626 "
"for this feature."
)
if mode is not None:
logging.warning("mode is not yet supported on LanceDB Cloud.")
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import uuid
from functools import cached_property
from typing import Dict, Optional, Union
@@ -37,6 +38,9 @@ class RemoteTable(Table):
def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self._name})"
def __len__(self) -> int:
self.count_rows(None)
@cached_property
def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
@@ -54,17 +58,17 @@ class RemoteTable(Table):
return resp["version"]
def to_arrow(self) -> pa.Table:
"""to_arrow() is not supported on the LanceDB cloud"""
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
"""to_arrow() is not yet supported on LanceDB cloud."""
raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.")
def to_pandas(self):
"""to_pandas() is not supported on the LanceDB cloud"""
return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
"""to_pandas() is not yet supported on LanceDB cloud."""
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
def create_scalar_index(self, *args, **kwargs):
"""Creates a scalar index"""
return NotImplementedError(
"create_scalar_index() is not supported on the LanceDB cloud"
"create_scalar_index() is not yet supported on LanceDB cloud."
)
def create_index(
@@ -72,6 +76,10 @@ class RemoteTable(Table):
metric="L2",
vector_column_name: str = VECTOR_COLUMN_NAME,
index_cache_size: Optional[int] = None,
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
replace: Optional[bool] = None,
accelerator: Optional[str] = None,
):
"""Create an index on the table.
Currently, the only parameters that matter are
@@ -105,6 +113,28 @@ class RemoteTable(Table):
... )
>>> table.create_index("L2", "vector") # doctest: +SKIP
"""
if num_partitions is not None:
logging.warning(
"num_partitions is not supported on LanceDB cloud."
"This parameter will be tuned automatically."
)
if num_sub_vectors is not None:
logging.warning(
"num_sub_vectors is not supported on LanceDB cloud."
"This parameter will be tuned automatically."
)
if accelerator is not None:
logging.warning(
"GPU accelerator is not yet supported on LanceDB cloud."
"If you have 100M+ vectors to index,"
"please contact us at contact@lancedb.com"
)
if replace is not None:
logging.warning(
"replace is not supported on LanceDB cloud."
"Existing indexes will always be replaced."
)
index_type = "vector"
data = {
@@ -268,6 +298,10 @@ class RemoteTable(Table):
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
if merge._when_matched_update_all_condition is not None:
params[
"when_matched_update_all_filt"
] = merge._when_matched_update_all_condition
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
@@ -409,6 +443,13 @@ class RemoteTable(Table):
"compact_files() is not supported on the LanceDB cloud"
)
def count_rows(self, filter: Optional[str] = None) -> int:
# payload = {"filter": filter}
# self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
return NotImplementedError(
"count_rows() is not yet supported on the LanceDB cloud"
)
def add_index(tbl: pa.Table, i: int) -> pa.Table:
return tbl.add_column(

View File

@@ -1,11 +1,15 @@
from .base import Reranker
from .cohere import CohereReranker
from .colbert import ColbertReranker
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
__all__ = [
"Reranker",
"CrossEncoderReranker",
"CohereReranker",
"LinearCombinationReranker",
"OpenaiReranker",
"ColbertReranker",
]

View File

@@ -1,12 +1,8 @@
import typing
from abc import ABC, abstractmethod
import numpy as np
import pyarrow as pa
if typing.TYPE_CHECKING:
import lancedb
class Reranker(ABC):
def __init__(self, return_score: str = "relevance"):
@@ -30,7 +26,7 @@ class Reranker(ABC):
@abstractmethod
def rerank_hybrid(
query_builder: "lancedb.HybridQueryBuilder",
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
@@ -41,8 +37,8 @@ class Reranker(ABC):
Parameters
----------
query_builder : "lancedb.HybridQueryBuilder"
The query builder object that was used to generate the results
query : str
The input query
vector_results : pa.Table
The results from the vector search
fts_results : pa.Table
@@ -50,36 +46,6 @@ class Reranker(ABC):
"""
pass
def rerank_vector(
query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
):
"""
Rerank function receives the individual results from the vector search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.VectorQueryBuilder"
The query builder object that was used to generate the results
vector_results : pa.Table
The results from the vector search
"""
raise NotImplementedError("Vector Reranking is not implemented")
def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
"""
Rerank function receives the individual results from the FTS search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.FTSQueryBuilder"
The query builder object that was used to generate the results
fts_results : pa.Table
The results from the FTS search
"""
raise NotImplementedError("FTS Reranking is not implemented")
def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
"""
Merge the results from the vector and FTS search. This is a vanilla merging

View File

@@ -1,5 +1,4 @@
import os
import typing
from functools import cached_property
from typing import Union
@@ -8,9 +7,6 @@ import pyarrow as pa
from ..util import safe_import
from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CohereReranker(Reranker):
"""
@@ -55,14 +51,14 @@ class CohereReranker(Reranker):
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder",
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
results = self._client.rerank(
query=query_builder._query,
query=query,
documents=docs,
top_n=self.top_n,
model=self.model_name,

View File

@@ -0,0 +1,107 @@
from functools import cached_property
import pyarrow as pa
from ..util import safe_import
from .base import Reranker
class ColbertReranker(Reranker):
"""
Reranks the results using the ColBERT model.
Parameters
----------
model_name : str, default "colbert-ir/colbertv2.0"
The name of the cross encoder model to use.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
"""
def __init__(
self,
model_name: str = "colbert-ir/colbertv2.0",
column: str = "text",
return_score="relevance",
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.torch = safe_import("torch") # import here for faster ops later
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
tokenizer, model = self._model
# Encode the query
query_encoding = tokenizer(query, return_tensors="pt")
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
scores = []
# Get score for each document
for document in docs:
document_encoding = tokenizer(
document, return_tensors="pt", truncation=True, max_length=512
)
document_embedding = model(**document_encoding).last_hidden_state
# Calculate MaxSim score
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
scores.append(score.item())
# replace the self.column column with the docs
combined_results = combined_results.drop(self.column)
combined_results = combined_results.append_column(
self.column, pa.array(docs, type=pa.string())
)
# add the scores
combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
@cached_property
def _model(self):
transformers = safe_import("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
model = transformers.AutoModel.from_pretrained(self.model_name)
return tokenizer, model
def maxsim(self, query_embedding, document_embedding):
# Expand dimensions for broadcasting
# Query: [batch, length, size] -> [batch, query, 1, size]
# Document: [batch, length, size] -> [batch, 1, length, size]
expanded_query = query_embedding.unsqueeze(2)
expanded_doc = document_embedding.unsqueeze(1)
# Compute cosine similarity across the embedding dimension
sim_matrix = self.torch.nn.functional.cosine_similarity(
expanded_query, expanded_doc, dim=-1
)
# Take the maximum similarity for each query token (across all document tokens)
# sim_matrix shape: [batch_size, query_length, doc_length]
max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
# Average these maximum scores across all query tokens
avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
return avg_max_sim

View File

@@ -1,4 +1,3 @@
import typing
from functools import cached_property
from typing import Union
@@ -7,9 +6,6 @@ import pyarrow as pa
from ..util import safe_import
from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CrossEncoderReranker(Reranker):
"""
@@ -52,13 +48,13 @@ class CrossEncoderReranker(Reranker):
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder",
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
passages = combined_results[self.column].to_pylist()
cross_inp = [[query_builder._query, passage] for passage in passages]
cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp)
combined_results = combined_results.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())

View File

@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):
def rerank_hybrid(
self,
query_builder: "lancedb.HybridQueryBuilder", # noqa: F821
query: str, # noqa: F821
vector_results: pa.Table,
fts_results: pa.Table,
):

View File

@@ -0,0 +1,102 @@
import json
import os
from functools import cached_property
from typing import Optional
import pyarrow as pa
from ..util import safe_import
from .base import Reranker
class OpenaiReranker(Reranker):
"""
Reranks the results using the OpenAI API.
WARNING: This is a prompt based reranker that uses chat model that is
not a dedicated reranker API. This should be treated as experimental.
Parameters
----------
model_name : str, default "gpt-3.5-turbo-1106 "
The name of the cross encoder model to use.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
api_key : str, default None
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
"""
def __init__(
self,
model_name: str = "gpt-3.5-turbo-1106",
column: str = "text",
return_score="relevance",
api_key: Optional[str] = None,
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.api_key = api_key
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
response = self._client.chat.completions.create(
model=self.model_name,
response_format={"type": "json_object"},
temperature=0,
messages=[
{
"role": "system",
"content": "You are an expert relevance ranker. Given a list of\
documents and a query, your job is to determine the relevance\
each document is for answering the query. Your output is JSON,\
which is a list of documents. Each document has two fields,\
content and relevance_score. relevance_score is from 0.0 to\
1.0 indicating the relevance of the text to the given query.\
Make sure to include all documents in the response.",
},
{"role": "user", "content": f"Query: {query} Docs: {docs}"},
],
)
results = json.loads(response.choices[0].message.content)["documents"]
docs, scores = list(
zip(*[(result["content"], result["relevance_score"]) for result in results])
) # tuples
# replace the self.column column with the docs
combined_results = combined_results.drop(self.column)
combined_results = combined_results.append_column(
self.column, pa.array(docs, type=pa.string())
)
# add the scores
combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
@cached_property
def _client(self):
openai = safe_import("openai") # TODO: force version or handle versions < 1.0
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
raise ValueError(
"OPENAI_API_KEY not set. Either set it in your environment or \
pass it as `api_key` argument to the CohereReranker."
)
return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)

View File

@@ -14,7 +14,10 @@
from __future__ import annotations
import inspect
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -34,22 +37,21 @@ from .query import LanceQueryBuilder, Query
from .util import (
fs_from_uri,
join_uri,
safe_import,
safe_import_pandas,
safe_import_polars,
value_to_sql,
)
from .utils.events import register_event
if TYPE_CHECKING:
from datetime import timedelta
import PIL
from lance.dataset import CleanupStats, ReaderLike
from .db import LanceDBConnection
pd = safe_import("pandas")
pl = safe_import("polars")
pd = safe_import_pandas()
pl = safe_import_polars()
def _sanitize_data(
@@ -175,6 +177,18 @@ class Table(ABC):
"""
raise NotImplementedError
@abstractmethod
def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
Parameters
----------
filter: str, optional
A SQL where clause to filter the rows to count.
"""
raise NotImplementedError
def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
@@ -298,7 +312,7 @@ class Table(ABC):
import lance
dataset = lance.dataset("/tmp/images.lance")
dataset = lance.dataset("./images.lance")
dataset.create_scalar_index("category")
"""
raise NotImplementedError
@@ -445,7 +459,7 @@ class Table(ABC):
*default "vector"*
query_type: str
*default "auto"*.
Acceptable types are: "vector", "fts", or "auto"
Acceptable types are: "vector", "fts", "hybrid", or "auto"
- If "auto" then the query type is inferred from the query;
@@ -641,23 +655,145 @@ class Table(ABC):
"""
class _LanceDatasetRef(ABC):
@property
@abstractmethod
def dataset(self) -> LanceDataset:
pass
@property
@abstractmethod
def dataset_mut(self) -> LanceDataset:
pass
@dataclass
class _LanceLatestDatasetRef(_LanceDatasetRef):
"""Reference to the latest version of a LanceDataset."""
uri: str
read_consistency_interval: Optional[timedelta] = None
last_consistency_check: Optional[float] = None
_dataset: Optional[LanceDataset] = None
@property
def dataset(self) -> LanceDataset:
if not self._dataset:
self._dataset = lance.dataset(self.uri)
self.last_consistency_check = time.monotonic()
elif self.read_consistency_interval is not None:
now = time.monotonic()
diff = timedelta(seconds=now - self.last_consistency_check)
if (
self.last_consistency_check is None
or diff > self.read_consistency_interval
):
self._dataset = self._dataset.checkout_version(
self._dataset.latest_version
)
self.last_consistency_check = time.monotonic()
return self._dataset
@dataset.setter
def dataset(self, value: LanceDataset):
self._dataset = value
self.last_consistency_check = time.monotonic()
@property
def dataset_mut(self) -> LanceDataset:
return self.dataset
@dataclass
class _LanceTimeTravelRef(_LanceDatasetRef):
uri: str
version: int
_dataset: Optional[LanceDataset] = None
@property
def dataset(self) -> LanceDataset:
if not self._dataset:
self._dataset = lance.dataset(self.uri, version=self.version)
return self._dataset
@dataset.setter
def dataset(self, value: LanceDataset):
self._dataset = value
self.version = value.version
@property
def dataset_mut(self) -> LanceDataset:
raise ValueError(
"Cannot mutate table reference fixed at version "
f"{self.version}. Call checkout_latest() to get a mutable "
"table reference."
)
class LanceTable(Table):
"""
A table in a LanceDB database.
This can be opened in two modes: standard and time-travel.
Standard mode is the default. In this mode, the table is mutable and tracks
the latest version of the table. The level of read consistency is controlled
by the `read_consistency_interval` parameter on the connection.
Time-travel mode is activated by specifying a version number. In this mode,
the table is immutable and fixed to a specific version. This is useful for
querying historical versions of the table.
"""
def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
def __init__(
self,
connection: "LanceDBConnection",
name: str,
version: Optional[int] = None,
):
self._conn = connection
self.name = name
self._version = version
def _reset_dataset(self, version=None):
try:
if "_dataset" in self.__dict__:
del self.__dict__["_dataset"]
self._version = version
except AttributeError:
pass
if version is not None:
self._ref = _LanceTimeTravelRef(
uri=self._dataset_uri,
version=version,
)
else:
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=connection.read_consistency_interval,
)
@classmethod
def open(cls, db, name, **kwargs):
tbl = cls(db, name, **kwargs)
fs, path = fs_from_uri(tbl._dataset_uri)
file_info = fs.get_file_info(path)
if file_info.type != pa.fs.FileType.Directory:
raise FileNotFoundError(
f"Table {name} does not exist."
f"Please first call db.create_table({name}, data)"
)
register_event("open_table")
return tbl
@property
def _dataset_uri(self) -> str:
return join_uri(self._conn.uri, f"{self.name}.lance")
@property
def _dataset(self) -> LanceDataset:
return self._ref.dataset
@property
def _dataset_mut(self) -> LanceDataset:
return self._ref.dataset_mut
def to_lance(self) -> LanceDataset:
"""Return the LanceDataset backing this table."""
return self._dataset
@property
def schema(self) -> pa.Schema:
@@ -685,6 +821,9 @@ class LanceTable(Table):
keep writing to the dataset starting from an old version, then use
the `restore` function.
Calling this method will set the table into time-travel mode. If you
wish to return to standard mode, call `checkout_latest`.
Parameters
----------
version : int
@@ -709,15 +848,13 @@ class LanceTable(Table):
vector type
0 [1.1, 0.9] vector
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
max_ver = self._dataset.latest_version
if version < 1 or version > max_ver:
raise ValueError(f"Invalid version {version}")
self._reset_dataset(version=version)
try:
# Accessing the property updates the cached value
_ = self._dataset
except Exception as e:
ds = self._dataset.checkout_version(version)
except IOError as e:
if "not found" in str(e):
raise ValueError(
f"Version {version} no longer exists. Was it cleaned up?"
@@ -725,6 +862,27 @@ class LanceTable(Table):
else:
raise e
self._ref = _LanceTimeTravelRef(
uri=self._dataset_uri,
version=version,
)
# We've already loaded the version so we can populate it directly.
self._ref.dataset = ds
def checkout_latest(self):
"""Checkout the latest version of the table. This is an in-place operation.
The table will be set back into standard mode, and will track the latest
version of the table.
"""
self.checkout(self._dataset.latest_version)
ds = self._ref.dataset
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=self._conn.read_consistency_interval,
)
self._ref.dataset = ds
def restore(self, version: int = None):
"""Restore a version of the table. This is an in-place operation.
@@ -759,7 +917,7 @@ class LanceTable(Table):
>>> len(table.list_versions())
4
"""
max_ver = max([v["version"] for v in self._dataset.versions()])
max_ver = self._dataset.latest_version
if version is None:
version = self.version
elif version < 1 or version > max_ver:
@@ -767,29 +925,30 @@ class LanceTable(Table):
else:
self.checkout(version)
if version == max_ver:
# no-op if restoring the latest version
return
ds = self._dataset
self._dataset.restore()
self._reset_dataset()
# no-op if restoring the latest version
if version != max_ver:
ds.restore()
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=self._conn.read_consistency_interval,
)
self._ref.dataset = ds
def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
Parameters
----------
filter: str, optional
A SQL where clause to filter the rows to count.
"""
return self._dataset.count_rows(filter)
def __len__(self):
return self.count_rows()
def __repr__(self) -> str:
return f"LanceTable({self.name})"
val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
if isinstance(self._ref, _LanceTimeTravelRef):
val += f", version={self._ref.version}"
val += ")"
return val
def __str__(self) -> str:
return self.__repr__()
@@ -839,10 +998,6 @@ class LanceTable(Table):
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
)
@property
def _dataset_uri(self) -> str:
return join_uri(self._conn.uri, f"{self.name}.lance")
def create_index(
self,
metric="L2",
@@ -854,7 +1009,7 @@ class LanceTable(Table):
index_cache_size: Optional[int] = None,
):
"""Create an index on the table."""
self._dataset.create_index(
self._dataset_mut.create_index(
column=vector_column_name,
index_type="IVF_PQ",
metric=metric,
@@ -864,11 +1019,12 @@ class LanceTable(Table):
accelerator=accelerator,
index_cache_size=index_cache_size,
)
self._reset_dataset()
register_event("create_index")
def create_scalar_index(self, column: str, *, replace: bool = True):
self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
self._dataset_mut.create_scalar_index(
column, index_type="BTREE", replace=replace
)
def create_fts_index(
self,
@@ -911,14 +1067,6 @@ class LanceTable(Table):
def _get_fts_index_path(self):
return join_uri(self._dataset_uri, "_indices", "tantivy")
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri, version=self._version)
def to_lance(self) -> LanceDataset:
"""Return the LanceDataset backing this table."""
return self._dataset
def add(
self,
data: DATA,
@@ -957,8 +1105,11 @@ class LanceTable(Table):
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
self._reset_dataset()
# Access the dataset_mut property to ensure that the dataset is mutable.
self._ref.dataset_mut
self._ref.dataset = lance.write_dataset(
data, self._dataset_uri, schema=self.schema, mode=mode
)
register_event("add")
def merge(
@@ -1019,10 +1170,9 @@ class LanceTable(Table):
other_table = other_table.to_lance()
if isinstance(other_table, LanceDataset):
other_table = other_table.to_table()
self._dataset.merge(
self._ref.dataset = self._dataset_mut.merge(
other_table, left_on=left_on, right_on=right_on, schema=schema
)
self._reset_dataset()
register_event("merge")
@cached_property
@@ -1225,22 +1375,8 @@ class LanceTable(Table):
register_event("create_table")
return new_table
@classmethod
def open(cls, db, name):
tbl = cls(db, name)
fs, path = fs_from_uri(tbl._dataset_uri)
file_info = fs.get_file_info(path)
if file_info.type != pa.fs.FileType.Directory:
raise FileNotFoundError(
f"Table {name} does not exist."
f"Please first call db.create_table({name}, data)"
)
register_event("open_table")
return tbl
def delete(self, where: str):
self._dataset.delete(where)
self._dataset_mut.delete(where)
def update(
self,
@@ -1294,8 +1430,7 @@ class LanceTable(Table):
if values is not None:
values_sql = {k: value_to_sql(v) for k, v in values.items()}
self.to_lance().update(values_sql, where)
self._reset_dataset()
self._dataset_mut.update(values_sql, where)
register_event("update")
def _execute_query(self, query: Query) -> pa.Table:
@@ -1332,7 +1467,7 @@ class LanceTable(Table):
ds = self.to_lance()
builder = ds.merge_insert(merge._on)
if merge._when_matched_update_all:
builder.when_matched_update_all()
builder.when_matched_update_all(merge._when_matched_update_all_condition)
if merge._when_not_matched_insert_all:
builder.when_not_matched_insert_all()
if merge._when_not_matched_by_source_delete:

View File

@@ -134,6 +134,24 @@ def safe_import(module: str, mitigation=None):
raise ImportError(f"Please install {mitigation or module}")
def safe_import_pandas():
try:
import pandas as pd
return pd
except ImportError:
return None
def safe_import_polars():
try:
import polars as pl
return pl
except ImportError:
return None
@singledispatch
def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type")

View File

@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.5.2"
version = "0.5.4"
dependencies = [
"deprecation",
"pylance==0.9.12",
"pylance==0.9.15",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",
@@ -48,7 +48,7 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]

View File

@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
assert np.allclose(actual, expected)
@pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model):
class Schema(LanceModel):

View File

@@ -23,11 +23,6 @@ import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
# These are integration tests for embedding functions.
# They are slow because they require downloading models
# or connection to external api
@@ -210,6 +205,13 @@ def test_gemini_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
@pytest.mark.skipif(
_mlx is None,
reason="mlx tests only required for apple users.",
@@ -266,3 +268,49 @@ def test_bedrock_embedding(tmp_path):
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_embedding(tmp_path):
def _get_table(model):
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
return tbl
model = get_registry().get("openai").create(max_retries=0)
tbl = _get_table(model)
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large")
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large", dim=1024)
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"

View File

@@ -7,7 +7,12 @@ import lancedb
from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
from lancedb.rerankers import (
CohereReranker,
ColbertReranker,
CrossEncoderReranker,
OpenaiReranker,
)
from lancedb.table import LanceTable
@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
return table, MyTable
## These tests are pretty loose, we should also check for correctness
def test_linear_combination(tmp_path):
table, schema = get_test_table(tmp_path)
# The default reranker
@@ -95,14 +99,19 @@ def test_linear_combination(tmp_path):
assert result1 == result3 # 2 & 3 should be the same as they use score as score
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
table.search((query_vector, query))
.limit(30)
.rerank(normalize="score")
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank", reranker=CohereReranker())
.rerank(reranker=CohereReranker())
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
table.search((query_vector, query))
.limit(30)
.rerank(reranker=CohereReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="rank", reranker=CrossEncoderReranker())
.rerank(reranker=CrossEncoderReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search("Our father who art in heaven", query_type="hybrid")
.limit(50)
table.search((query_vector, query), query_type="hybrid")
.limit(30)
.rerank(reranker=CrossEncoderReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _score column of the results returned by the reranker "
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=ColbertReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=ColbertReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=ColbertReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=OpenaiReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=OpenaiReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=OpenaiReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)

View File

@@ -12,8 +12,10 @@
# limitations under the License.
import functools
from copy import copy
from datetime import date, datetime, timedelta
from pathlib import Path
from time import sleep
from typing import List
from unittest.mock import PropertyMock, patch
@@ -25,6 +27,7 @@ import pyarrow as pa
import pytest
from pydantic import BaseModel
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
class MockDB:
def __init__(self, uri: Path):
self.uri = uri
self.read_consistency_interval = None
@functools.cached_property
def is_managed_remote(self) -> bool:
@@ -267,39 +271,38 @@ def test_versioning(db):
def test_create_index_method():
with patch.object(LanceTable, "_reset_dataset", return_value=None):
with patch.object(
LanceTable, "_dataset", new_callable=PropertyMock
) as mock_dataset:
# Setup mock responses
mock_dataset.return_value.create_index.return_value = None
with patch.object(
LanceTable, "_dataset_mut", new_callable=PropertyMock
) as mock_dataset:
# Setup mock responses
mock_dataset.return_value.create_index.return_value = None
# Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table")
# Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table")
# Call the create_index method
table.create_index(
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name="vector",
replace=True,
index_cache_size=256,
)
# Call the create_index method
table.create_index(
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name="vector",
replace=True,
index_cache_size=256,
)
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
metric="L2",
num_partitions=256,
num_sub_vectors=96,
replace=True,
accelerator=None,
index_cache_size=256,
)
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
metric="L2",
num_partitions=256,
num_sub_vectors=96,
replace=True,
accelerator=None,
index_cache_size=256,
)
def test_add_with_nans(db):
@@ -510,8 +513,15 @@ def test_merge_insert(db):
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
# These `sort_by` calls can be removed once lance#1892
# is merged (it fixes the ordering)
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
# conditional update
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
new_data
)
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
assert table.to_arrow().sort_by("a") == expected
table.restore(version)
@@ -792,3 +802,48 @@ def test_hybrid_search(db):
"Our father who art in heaven", query_type="hybrid"
).to_pydantic(MyTable)
assert result1 == result3
@pytest.mark.parametrize(
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
)
def test_consistency(tmp_path, consistency_interval):
db = lancedb.connect(tmp_path)
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
table2 = db2.open_table("my_table")
assert table2.version == table.version
table.add([{"id": 1}])
if consistency_interval is None:
assert table2.version == table.version - 1
table2.checkout_latest()
assert table2.version == table.version
elif consistency_interval == timedelta(seconds=0):
assert table2.version == table.version
else:
# (consistency_interval == timedelta(seconds=0.1)
assert table2.version == table.version - 1
sleep(0.1)
assert table2.version == table.version
def test_restore_consistency(tmp_path):
db = lancedb.connect(tmp_path)
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
table2 = db2.open_table("my_table")
assert table2.version == table.version
# If we call checkout, it should lose consistency
table_fixed = copy(table2)
table_fixed.checkout(table.version)
# But if we call checkout_latest, it should be consistent again
table_ref_latest = copy(table_fixed)
table_ref_latest.checkout_latest()
table.add([{"id": 2}])
assert table_fixed.version == table.version - 1
assert table_ref_latest.version == table.version

View File

@@ -2,8 +2,11 @@
name = "vectordb-node"
version = "0.4.8"
description = "Serverless, low-latency vector database for AI applications"
license = "Apache-2.0"
edition = "2018"
license.workspace = true
edition.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
exclude = ["index.node"]
[lib]

View File

@@ -22,7 +22,7 @@ use arrow_schema::SchemaRef;
use crate::error::Result;
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
pub fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
let mut batches: Vec<RecordBatch> = Vec::new();
let file_reader = FileReader::try_new(Cursor::new(slice), None)?;
let schema = file_reader.schema();
@@ -33,7 +33,7 @@ pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBa
Ok((batches, schema))
}
pub(crate) fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
pub fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
if batches.is_empty() {
return Ok(Vec::new());
}

View File

@@ -17,7 +17,7 @@ use neon::types::buffer::TypedArray;
use crate::error::ResultExt;
pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
pub fn vec_str_to_array<'a, C: Context<'a>>(
vec: &Vec<String>,
cx: &mut C,
) -> JsResult<'a, JsArray> {
@@ -29,7 +29,7 @@ pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
Ok(a)
}
pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
pub fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
let mut query_vec: Vec<f32> = Vec::new();
for i in 0..array.len(cx) {
let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
@@ -39,7 +39,7 @@ pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<
}
// Creates a new JsBuffer from a rust buffer with a special logic for electron
pub(crate) fn new_js_buffer<'a>(
pub fn new_js_buffer<'a>(
buffer: Vec<u8>,
cx: &mut TaskContext<'a>,
is_electron: bool,

View File

@@ -18,7 +18,6 @@ use neon::prelude::NeonResult;
use snafu::Snafu;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum Error {
#[snafu(display("column '{name}' is missing"))]
MissingColumn { name: String },

View File

@@ -21,7 +21,7 @@ use neon::{
use crate::{error::ResultExt, runtime, table::JsTable};
use vectordb::Table;
pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let column = cx.argument::<JsString>(0)?.value(&mut cx);
let replace = cx.argument::<JsBoolean>(1)?.value(&mut cx);

View File

@@ -24,7 +24,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
use crate::runtime;
use crate::table::JsTable;
pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let index_params = cx.argument::<JsObject>(0)?;

View File

@@ -13,7 +13,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
use crate::table::JsTable;
use crate::{convert, runtime};
pub(crate) struct JsQuery {}
pub struct JsQuery {}
impl JsQuery {
pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult<JsPromise> {

View File

@@ -28,7 +28,7 @@ use vectordb::TableRef;
use crate::error::ResultExt;
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
pub(crate) struct JsTable {
pub struct JsTable {
pub table: TableRef,
}
@@ -36,7 +36,7 @@ impl Finalize for JsTable {}
impl From<TableRef> for JsTable {
fn from(table: TableRef) -> Self {
JsTable { table }
Self { table }
}
}
@@ -85,14 +85,14 @@ impl JsTable {
deferred.settle_with(&channel, move |mut cx| {
let table = table_rst.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
Ok(cx.boxed(Self::from(table)))
});
});
Ok(promise)
}
pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let buffer = cx.argument::<JsBuffer>(0)?;
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
let (batches, schema) =
@@ -125,21 +125,34 @@ impl JsTable {
deferred.settle_with(&channel, move |mut cx| {
add_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
Ok(cx.boxed(Self::from(table)))
});
});
Ok(promise)
}
pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let filter = cx
.argument_opt(0)
.and_then(|filt| {
if filt.is_a::<JsUndefined, _>(&mut cx) || filt.is_a::<JsNull, _>(&mut cx) {
None
} else {
Some(
filt.downcast_or_throw::<JsString, _>(&mut cx)
.map(|js_filt| js_filt.deref().value(&mut cx)),
)
}
})
.transpose()?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let table = js_table.table.clone();
rt.spawn(async move {
let num_rows_result = table.count_rows().await;
let num_rows_result = table.count_rows(filter).await;
deferred.settle_with(&channel, move |mut cx| {
let num_rows = num_rows_result.or_throw(&mut cx)?;
@@ -150,7 +163,7 @@ impl JsTable {
}
pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -162,14 +175,14 @@ impl JsTable {
deferred.settle_with(&channel, move |mut cx| {
delete_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
Ok(cx.boxed(Self::from(table)))
})
});
Ok(promise)
}
pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
@@ -178,28 +191,34 @@ impl JsTable {
let key = cx.argument::<JsString>(0)?.value(&mut cx);
let mut builder = table.merge_insert(&[&key]);
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
builder.when_matched_update_all();
}
if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
builder.when_not_matched_insert_all();
let filter = cx.argument_opt(2).unwrap();
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_matched_update_all(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_matched_update_all(Some(filter));
}
}
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
if let Some(filter) = cx.argument_opt(4) {
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_not_matched_by_source_delete(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_not_matched_by_source_delete(Some(filter));
}
} else {
builder.when_not_matched_insert_all();
}
if cx.argument::<JsBoolean>(4)?.value(&mut cx) {
let filter = cx.argument_opt(5).unwrap();
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_not_matched_by_source_delete(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_not_matched_by_source_delete(Some(filter));
}
}
let buffer = cx.argument::<JsBuffer>(5)?;
let buffer = cx.argument::<JsBuffer>(6)?;
let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
@@ -209,14 +228,14 @@ impl JsTable {
deferred.settle_with(&channel, move |mut cx| {
merge_insert_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
Ok(cx.boxed(Self::from(table)))
})
});
Ok(promise)
}
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let table = js_table.table.clone();
let rt = runtime(&mut cx)?;
@@ -275,7 +294,7 @@ impl JsTable {
.await;
deferred.settle_with(&channel, move |mut cx| {
update_result.or_throw(&mut cx)?;
Ok(cx.boxed(JsTable::from(table)))
Ok(cx.boxed(Self::from(table)))
})
});
@@ -283,7 +302,7 @@ impl JsTable {
}
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let table = js_table.table.clone();
@@ -321,7 +340,7 @@ impl JsTable {
let old_versions = cx.number(prune_stats.old_versions as f64);
output_metrics.set(&mut cx, "oldVersions", old_versions)?;
let output_table = cx.boxed(JsTable::from(table));
let output_table = cx.boxed(Self::from(table));
let output = JsObject::new(&mut cx);
output.set(&mut cx, "metrics", output_metrics)?;
@@ -334,7 +353,7 @@ impl JsTable {
}
pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let table = js_table.table.clone();
@@ -393,7 +412,7 @@ impl JsTable {
let files_added = cx.number(stats.files_added as f64);
output_metrics.set(&mut cx, "filesAdded", files_added)?;
let output_table = cx.boxed(JsTable::from(table));
let output_table = cx.boxed(Self::from(table));
let output = JsObject::new(&mut cx);
output.set(&mut cx, "metrics", output_metrics)?;
@@ -406,7 +425,7 @@ impl JsTable {
}
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -445,7 +464,7 @@ impl JsTable {
}
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -493,7 +512,7 @@ impl JsTable {
}
pub(crate) fn js_schema(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();

View File

@@ -1,12 +1,12 @@
[package]
name = "vectordb"
version = "0.4.8"
edition = "2021"
edition.workspace = true
description = "LanceDB: A serverless, low-latency vector database for AI applications"
license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"
keywords = ["lancedb", "lance", "database", "search"]
categories = ["database-implementations"]
license.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

View File

@@ -188,12 +188,12 @@ impl Database {
/// # Returns
///
/// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Database> {
pub async fn connect(uri: &str) -> Result<Self> {
let options = ConnectOptions::new(uri);
Self::connect_with_options(&options).await
}
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> {
let uri = &options.uri;
let parse_res = url::Url::parse(uri);
@@ -276,7 +276,7 @@ impl Database {
None => None,
};
Ok(Database {
Ok(Self {
uri: table_base_uri,
query_string,
base_path,
@@ -288,7 +288,7 @@ impl Database {
}
}
async fn open_path(path: &str) -> Result<Database> {
async fn open_path(path: &str) -> Result<Self> {
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path })?;

View File

@@ -69,7 +69,7 @@ pub struct IndexBuilder {
impl IndexBuilder {
pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
IndexBuilder {
Self {
table,
columns: columns.iter().map(|c| c.to_string()).collect(),
name: None,
@@ -197,7 +197,7 @@ impl IndexBuilder {
let num_partitions = if let Some(n) = self.num_partitions {
n
} else {
suggested_num_partitions(self.table.count_rows().await?)
suggested_num_partitions(self.table.count_rows(None).await?)
};
let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
n

View File

@@ -23,13 +23,13 @@ pub struct VectorIndex {
}
impl VectorIndex {
pub fn new_from_format(manifest: &Manifest, index: &Index) -> VectorIndex {
pub fn new_from_format(manifest: &Manifest, index: &Index) -> Self {
let fields = index
.fields
.iter()
.map(|i| manifest.schema.fields[*i as usize].name.clone())
.collect();
VectorIndex {
Self {
columns: fields,
index_name: index.name.clone(),
index_uuid: index.uuid.to_string(),

View File

@@ -372,7 +372,7 @@ mod test {
// leave this here for easy debugging
let t = res.unwrap();
assert_eq!(t.count_rows().await.unwrap(), 100);
assert_eq!(t.count_rows(None).await.unwrap(), 100);
let q = t
.search(&[0.1, 0.1, 0.1, 0.1])

View File

@@ -62,7 +62,7 @@ impl Query {
/// * `dataset` - Lance dataset.
///
pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
Query {
Self {
dataset,
query_vector: None,
column: None,

View File

@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
};
pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
@@ -102,7 +102,11 @@ pub trait Table: std::fmt::Display + Send + Sync {
fn schema(&self) -> SchemaRef;
/// Count the number of rows in this dataset.
async fn count_rows(&self) -> Result<usize>;
///
/// # Arguments
///
/// * `filter` if present, only count rows matching the filter
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
/// Insert new records into this Table
///
@@ -234,7 +238,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// schema.clone());
/// // Perform an upsert operation
/// let mut merge_insert = tbl.merge_insert(&["id"]);
/// merge_insert.when_matched_update_all()
/// merge_insert.when_matched_update_all(None)
/// .when_not_matched_insert_all();
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
/// # });
@@ -385,7 +389,7 @@ impl NativeTable {
message: e.to_string(),
},
})?;
Ok(NativeTable {
Ok(Self {
name: name.to_string(),
uri: uri.to_string(),
dataset: Arc::new(Mutex::new(dataset)),
@@ -427,7 +431,7 @@ impl NativeTable {
message: e.to_string(),
},
})?;
Ok(NativeTable {
Ok(Self {
name: name.to_string(),
uri: uri.to_string(),
dataset: Arc::new(Mutex::new(dataset)),
@@ -501,7 +505,7 @@ impl NativeTable {
message: e.to_string(),
},
})?;
Ok(NativeTable {
Ok(Self {
name: name.to_string(),
uri: uri.to_string(),
dataset: Arc::new(Mutex::new(dataset)),
@@ -673,11 +677,14 @@ impl MergeInsert for NativeTable {
) -> Result<()> {
let dataset = Arc::new(self.clone_inner_dataset());
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
if params.when_matched_update_all {
builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
} else {
builder.when_matched(lance::dataset::WhenMatched::DoNothing);
}
match (
params.when_matched_update_all,
params.when_matched_update_all_filt,
) {
(false, _) => builder.when_matched(WhenMatched::DoNothing),
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
};
if params.when_not_matched_insert_all {
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
} else {
@@ -719,9 +726,15 @@ impl Table for NativeTable {
Arc::new(Schema::from(&lance_schema))
}
async fn count_rows(&self) -> Result<usize> {
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
let dataset = { self.dataset.lock().expect("lock poison").clone() };
Ok(dataset.count_rows().await?)
if let Some(filter) = filter {
let mut scanner = dataset.scan();
scanner.filter(&filter)?;
Ok(scanner.count_rows().await? as usize)
} else {
Ok(dataset.count_rows().await?)
}
}
async fn add(
@@ -814,6 +827,7 @@ impl Table for NativeTable {
#[cfg(test)]
mod tests {
use std::iter;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@@ -886,6 +900,23 @@ mod tests {
));
}
#[tokio::test]
async fn test_count_rows() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches = make_test_batches();
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(
table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
5
);
}
#[tokio::test]
async fn test_add() {
let tmp_dir = tempdir().unwrap();
@@ -896,7 +927,7 @@ mod tests {
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
@@ -910,7 +941,7 @@ mod tests {
);
table.add(Box::new(new_batches), None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 20);
assert_eq!(table.count_rows(None).await.unwrap(), 20);
assert_eq!(table.name, "test");
}
@@ -920,30 +951,44 @@ mod tests {
let uri = tmp_dir.path().to_str().unwrap();
// Create a dataset with i=0..10
let batches = make_test_batches_with_offset(0);
let batches = merge_insert_test_batches(0, 0);
let table = NativeTable::create(&uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// Create new data with i=5..15
let new_batches = Box::new(make_test_batches_with_offset(5));
let new_batches = Box::new(merge_insert_test_batches(5, 1));
// Perform a "insert if not exists"
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_not_matched_insert_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// Only 5 rows should actually be inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
assert_eq!(table.count_rows(None).await.unwrap(), 15);
// Create new data with i=15..25 (no id matches)
let new_batches = Box::new(make_test_batches_with_offset(15));
let new_batches = Box::new(merge_insert_test_batches(15, 2));
// Perform a "bulk update" (should not affect anything)
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all();
merge_insert_builder.when_matched_update_all(None);
merge_insert_builder.execute(new_batches).await.unwrap();
// No new rows should have been inserted
assert_eq!(table.count_rows().await.unwrap(), 15);
assert_eq!(table.count_rows(None).await.unwrap(), 15);
assert_eq!(
table.count_rows(Some("age = 2".to_string())).await.unwrap(),
0
);
// Conditional update that only replaces the age=0 data
let new_batches = Box::new(merge_insert_test_batches(5, 3));
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string()));
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
5
);
}
#[tokio::test]
@@ -956,7 +1001,7 @@ mod tests {
let table = NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
@@ -975,7 +1020,7 @@ mod tests {
};
table.add(Box::new(new_batches), Some(param)).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(table.name, "test");
}
@@ -1292,23 +1337,35 @@ mod tests {
assert!(wrapper.called());
}
fn make_test_batches_with_offset(
fn merge_insert_test_batches(
offset: i32,
age: i32,
) -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
let schema = Arc::new(Schema::new(vec![
Field::new("i", DataType::Int32, false),
Field::new("age", DataType::Int32, false),
]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(
offset..(offset + 10),
))],
vec![
Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
],
)],
schema,
)
}
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
make_test_batches_with_offset(0)
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)],
schema,
)
}
#[tokio::test]
@@ -1365,7 +1422,7 @@ mod tests {
.unwrap();
assert_eq!(table.load_indices().await.unwrap().len(), 1);
assert_eq!(table.count_rows().await.unwrap(), 512);
assert_eq!(table.count_rows(None).await.unwrap(), 512);
assert_eq!(table.name, "test");
let indices = table.load_indices().await.unwrap();

View File

@@ -35,6 +35,7 @@ pub struct MergeInsertBuilder {
table: Arc<dyn MergeInsert>,
pub(super) on: Vec<String>,
pub(super) when_matched_update_all: bool,
pub(super) when_matched_update_all_filt: Option<String>,
pub(super) when_not_matched_insert_all: bool,
pub(super) when_not_matched_by_source_delete: bool,
pub(super) when_not_matched_by_source_delete_filt: Option<String>,
@@ -46,6 +47,7 @@ impl MergeInsertBuilder {
table,
on,
when_matched_update_all: false,
when_matched_update_all_filt: None,
when_not_matched_insert_all: false,
when_not_matched_by_source_delete: false,
when_not_matched_by_source_delete_filt: None,
@@ -59,8 +61,22 @@ impl MergeInsertBuilder {
/// If there are multiple matches then the behavior is undefined.
/// Currently this causes multiple copies of the row to be created
/// but that behavior is subject to change.
pub fn when_matched_update_all(&mut self) -> &mut Self {
///
/// An optional condition may be specified. If it is, then only
/// matched rows that satisfy the condtion will be updated. Any
/// rows that do not satisfy the condition will be left as they
/// are. Failing to satisfy the condition does not cause a
/// "matched row" to become a "not matched" row.
///
/// The condition should be an SQL string. Use the prefix
/// target. to refer to rows in the target table (old data)
/// and the prefix source. to refer to rows in the source
/// table (new data).
///
/// For example, "target.last_update < source.last_update"
pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
self.when_matched_update_all = true;
self.when_matched_update_all_filt = condition;
self
}