Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-23 05:19:58 +00:00)

Compare commits: python-v0. ... python-v0.
16 Commits
| SHA1 |
|---|
| 82936c77ef |
| dddcddcaf9 |
| a9727eb318 |
| 48d55bf952 |
| d2e71c8b08 |
| f53aace89c |
| d982ee934a |
| 57605a2d86 |
| 738511c5f2 |
| 0b0f42537e |
| e412194008 |
| a9088224c5 |
| 688c57a0d8 |
| 12a98deded |
| e4bb042918 |
| 04e1662681 |
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.7
+current_version = 0.4.8
 commit = True
 message = Bump version: {current_version} → {new_version}
 tag = True
34  .cargo/config.toml  (new file)
@@ -0,0 +1,34 @@
[profile.release]
lto = "fat"
codegen-units = 1

[profile.release-with-debug]
inherits = "release"
debug = true
# Prioritize compile time over runtime performance
codegen-units = 16
lto = "thin"

[target.'cfg(all())']
rustflags = [
    "-Wclippy::all",
    "-Wclippy::style",
    "-Wclippy::fallible_impl_from",
    "-Wclippy::manual_let_else",
    "-Wclippy::redundant_pub_crate",
    "-Wclippy::string_add_assign",
    "-Wclippy::string_add",
    "-Wclippy::string_lit_as_bytes",
    "-Wclippy::string_to_string",
    "-Wclippy::use_self",
    "-Dclippy::cargo",
    "-Dclippy::dbg_macro",
    # not too much we can do to avoid multiple crate versions
    "-Aclippy::multiple-crate-versions",
]

[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"]

[target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]
9  .github/workflows/docs_test.yml  (vendored)
@@ -49,6 +49,9 @@ jobs:
   test-node:
     name: Test doc nodejs code
     runs-on: "ubuntu-latest"
+    timeout-minutes: 45
+    strategy:
+      fail-fast: false
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -66,6 +69,12 @@ jobs:
         uses: swatinem/rust-cache@v2
       - name: Install node dependencies
         run: |
+          sudo swapoff -a
+          sudo fallocate -l 8G /swapfile
+          sudo chmod 600 /swapfile
+          sudo mkswap /swapfile
+          sudo swapon /swapfile
+          sudo swapon --show
           cd node
           npm ci
           npm run build-release

13  Cargo.toml
@@ -6,15 +6,18 @@ resolver = "2"
 
 [workspace.package]
 edition = "2021"
-authors = ["Lance Devs <dev@lancedb.com>"]
+authors = ["LanceDB Devs <dev@lancedb.com>"]
 license = "Apache-2.0"
 repository = "https://github.com/lancedb/lancedb"
 description = "Serverless, low-latency vector database for AI applications"
+keywords = ["lancedb", "lance", "database", "vector", "search"]
+categories = ["database-implementations"]
 
 [workspace.dependencies]
-lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
-lance-index = { "version" = "=0.9.12" }
-lance-linalg = { "version" = "=0.9.12" }
-lance-testing = { "version" = "=0.9.12" }
+lance = { "version" = "=0.9.15", "features" = ["dynamodb"] }
+lance-index = { "version" = "=0.9.15" }
+lance-linalg = { "version" = "=0.9.15" }
+lance-testing = { "version" = "=0.9.15" }
 # Note that this one does not include pyarrow
 arrow = { version = "50.0", optional = false }
 arrow-array = "50.0"
 
@@ -69,3 +69,19 @@ MinIO supports an S3 compatible API. In order to connect to a MinIO instance, yo
 - Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
 - Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credentials
 - Call `lancedb.connect("s3://minio_bucket_name")`
+
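A minimal sketch of that setup in Python. The endpoint URL, keys, and bucket name below are placeholders for your own MinIO deployment, not values from this repository:

```python
import os

import lancedb

# Point LanceDB's S3 client at the MinIO server instead of AWS.
os.environ["AWS_ENDPOINT"] = "http://localhost:9000"      # your MinIO API URL
os.environ["AWS_ACCESS_KEY_ID"] = "minioadmin"            # your MinIO access key
os.environ["AWS_SECRET_ACCESS_KEY"] = "minioadmin"        # your MinIO secret key

db = lancedb.connect("s3://minio_bucket_name")
print(db.table_names())
```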
+### Where can I find benchmarks for LanceDB?
+
+Refer to this [post](https://blog.lancedb.com/benchmarking-lancedb-92b01032874a) for recent benchmarks.
+
+### How much data can LanceDB practically manage without affecting performance?
+
+We target good performance on ~10-50 billion rows and ~10-30 TB of data.
+
+### Does LanceDB support concurrent operations?
+
+LanceDB can handle concurrent reads very well, and can scale horizontally. The main constraint is how well the [storage layer](https://lancedb.github.io/lancedb/concepts/storage/) you've chosen scales. For writes, we support concurrent writing, though too many concurrent writers can lead to failing writes, as there is a limited number of times a writer retries a commit.
+
+!!! info "Multiprocessing with LanceDB"
+
+    For multiprocessing you should probably not use `fork`, as Lance is multi-threaded internally and `fork` and multi-threading do not work well together. [Refer to this discussion](https://discuss.python.org/t/concerns-regarding-deprecation-of-fork-with-alive-threads/33555)
+
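If you do use Python multiprocessing, a small sketch of opting into the `spawn` start method instead (standard-library behaviour, not a LanceDB-specific API; paths and table names are placeholders):

```python
import multiprocessing as mp

import lancedb


def count(uri: str, table_name: str) -> int:
    # Each worker opens its own connection instead of inheriting forked state.
    db = lancedb.connect(uri)
    return db.open_table(table_name).count_rows()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # avoid fork with Lance's internal threads
    with ctx.Pool(2) as pool:
        print(pool.starmap(count, [("./.lancedb", "my_table")] * 2))
```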
@@ -6,17 +6,24 @@ LanceDB supports both semantic and keyword-based search. In real world applicati
 You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case needs more sophisticated logic.
 
 ```python
 import os
 
 import lancedb
 import openai
 from lancedb.embeddings import get_registry
-from lancedb.pydanatic import LanceModel, Vector
+from lancedb.pydantic import LanceModel, Vector
 
 db = lancedb.connect("~/.lancedb")
 
 # Ingest embedding function in LanceDB table
 # Configuring the environment variable OPENAI_API_KEY
 if "OPENAI_API_KEY" not in os.environ:
     # OR set the key here as a variable
     openai.api_key = "sk-..."
 embeddings = get_registry().get("openai").create()
 
 class Documents(LanceModel):
-    vector: Vector(embeddings.ndims) = embeddings.VectorField()
+    vector: Vector(embeddings.ndims()) = embeddings.VectorField()
     text: str = embeddings.SourceField()
 
 table = db.create_table("documents", schema=Documents)
@@ -31,17 +38,19 @@ data = [
 # ingest docs with auto-vectorization
 table.add(data)
 
+# Create a fts index before the hybrid search
+table.create_fts_index("text")
 # hybrid search with default re-ranker
 results = table.search("flower moon", query_type="hybrid").to_pandas()
 ```
 
-By default, LanceDB uses `LinearCombinationReranker(weights=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
+By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
 
 
 ### `rerank()` arguments
 * `normalize`: `str`, default `"score"`:
     The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
-* `reranker`: `Reranker`, default `LinearCombinationReranker(weights=0.7)`.
+* `reranker`: `Reranker`, default `LinearCombinationReranker(weight=0.7)`.
     The reranker to use. If not specified, the default reranker is used.
 
 
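A short sketch of passing both arguments to `rerank()`, reusing the `table` from the ingestion example above (which values work best is use-case dependent):

```python
from lancedb.rerankers import LinearCombinationReranker

reranker = LinearCombinationReranker(weight=0.7)

# Normalize by rank instead of raw score, and supply the reranker explicitly.
results = (
    table.search("flower moon", query_type="hybrid")
    .rerank(normalize="rank", reranker=reranker)
    .to_pandas()
)
```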
@@ -55,7 +64,7 @@ This is the default re-ranker used by LanceDB. It combines the results of semant
 ```python
 from lancedb.rerankers import LinearCombinationReranker
 
-reranker = LinearCombinationReranker(weights=0.3) # Use 0.3 as the weight for vector search
+reranker = LinearCombinationReranker(weight=0.3)  # Use 0.3 as the weight for vector search
 
 results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
 ```
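Assuming the `weight` applies to the normalized vector score and its complement to the full-text score (the comment above confirms the first half; the exact normalization is an implementation detail), the combined score is roughly:

$$ \text{relevance} \;=\; w \cdot s_{\text{vector}} \;+\; (1 - w) \cdot s_{\text{fts}}, \qquad w = \texttt{weight} $$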
@@ -121,6 +130,60 @@ Arguments
     Only returns `_relevance_score`. Does not support `return_score = "all"`.
 
 
+### ColBERT Reranker
+This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertReranker()` to the `rerank()` method.
+
+The ColBERT reranker model calculates the relevance of the given docs against the query and doesn't take the existing fts and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for a `text` column to rerank the results, but you can specify the column name to use as input to the cross encoder model as described below.
+
+```python
+from lancedb.rerankers import ColbertReranker
+
+reranker = ColbertReranker()
+
+results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
+```
+
+Arguments
+----------------
+* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
+    The name of the cross encoder model to use.
+* `column` : `str`, default `"text"`
+    The name of the column to use as input to the cross encoder model.
+* `return_score` : `str`, default `"relevance"`
+    options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
+
+!!! Note
+    Only returns `_relevance_score`. Does not support `return_score = "all"`.
+
+### OpenAI Reranker
+This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
+
+!!! Note
+    This prompts a chat model to rerank results; it is not a dedicated reranker model and should be treated as experimental.
+
+!!! Tip
+    You might run into the model's token limit, so set the search `limit` based on your token limit.
+
+```python
+from lancedb.rerankers import OpenaiReranker
+
+reranker = OpenaiReranker()
+
+results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
+```
+
+Arguments
+----------------
+`model_name` : `str`, default `"gpt-3.5-turbo-1106"`
+    The name of the cross encoder model to use.
+`column` : `str`, default `"text"`
+    The name of the column to use as input to the cross encoder model.
+`return_score` : `str`, default `"relevance"`
+    options are "relevance" or "all". Only "relevance" is supported for now.
+`api_key` : `str`, default `None`
+    The API key to use. If None, will use the OPENAI_API_KEY environment variable.
+
+
 ## Building Custom Rerankers
 You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
 
@@ -137,7 +200,7 @@ class MyReranker(Reranker):
         self.param1 = param1
         self.param2 = param2
 
-    def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
+    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
         # Use the built-in merging function
         combined_result = self.merge_results(vector_results, fts_results)
 
@@ -159,7 +222,7 @@ import pyarrow as pa
 class MyReranker(Reranker):
     ...
 
-    def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
+    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table, filter: str):
         # Use the built-in merging function
         combined_result = self.merge_results(vector_results, fts_results)
 
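For reference, a self-contained sketch of what a complete custom reranker along these lines could look like. The `weight` parameter and the score arithmetic are illustrative only; the sketch assumes `merge_results()` leaves the `_distance` (vector) and `score` (FTS) columns in place, as the built-in rerankers in this change do:

```python
import pyarrow as pa
import pyarrow.compute as pc

from lancedb.rerankers import Reranker


class WeightedReranker(Reranker):
    """Illustrative reranker: blends vector and FTS scores linearly."""

    def __init__(self, weight: float = 0.5):
        super().__init__(return_score="relevance")
        self.weight = weight

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table) -> pa.Table:
        combined = self.merge_results(vector_results, fts_results)
        # Smaller distance is better, larger FTS score is better
        # (assumes distances roughly in [0, 1] for this toy blend).
        vector_score = pc.subtract(1.0, combined["_distance"])
        relevance = pc.add(
            pc.multiply(vector_score, self.weight),
            pc.multiply(combined["score"], 1.0 - self.weight),
        )
        combined = combined.append_column("_relevance_score", relevance)
        return combined.sort_by([("_relevance_score", "descending")])
```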
44  node/package-lock.json  (generated)
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.4.7",
+  "version": "0.4.8",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.4.7",
+      "version": "0.4.8",
       "cpu": [
         "x64",
         "arm64"
@@ -53,11 +53,11 @@
         "uuid": "^9.0.0"
       },
       "optionalDependencies": {
-        "@lancedb/vectordb-darwin-arm64": "0.4.7",
-        "@lancedb/vectordb-darwin-x64": "0.4.7",
-        "@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
-        "@lancedb/vectordb-linux-x64-gnu": "0.4.7",
-        "@lancedb/vectordb-win32-x64-msvc": "0.4.7"
+        "@lancedb/vectordb-darwin-arm64": "0.4.8",
+        "@lancedb/vectordb-darwin-x64": "0.4.8",
+        "@lancedb/vectordb-linux-arm64-gnu": "0.4.8",
+        "@lancedb/vectordb-linux-x64-gnu": "0.4.8",
+        "@lancedb/vectordb-win32-x64-msvc": "0.4.8"
       }
     },
     "node_modules/@75lb/deep-merge": {
@@ -329,9 +329,9 @@
       }
     },
     "node_modules/@lancedb/vectordb-darwin-arm64": {
-      "version": "0.4.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.7.tgz",
-      "integrity": "sha512-kACOIytgjBfX8NRwjPKe311XRN3lbSN13B7avT5htMd3kYm3AnnMag9tZhlwoO7lIuvGaXhy7mApygJrjhfJ4g==",
+      "version": "0.4.8",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.8.tgz",
+      "integrity": "sha512-FpnJaw7KmNdD/FtOw9AcmPL5P+L04AcnfPj9ZyEjN8iCwB/qaOGYgdfBv+EbEtfHIsqA12q/1BRduu9KdB6BIA==",
       "cpu": [
         "arm64"
       ],
@@ -341,9 +341,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-darwin-x64": {
-      "version": "0.4.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.7.tgz",
-      "integrity": "sha512-vb74iK5uPWCwz5E60r3yWp/R/HSg54/Z9AZWYckYXqsPv4w/nfbkM5iZhfRqqR/9uE6JClWJKOtjbk7b8CFRFg==",
+      "version": "0.4.8",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.8.tgz",
+      "integrity": "sha512-RafOEYyZIgphp8wPGuVLFaTc8aAqo0NCO1LQMx0mB0xV96vrdo0Mooivs+dYN3RFfSHtTKPw9O1Jc957Vp1TLg==",
       "cpu": [
         "x64"
       ],
@@ -353,9 +353,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-arm64-gnu": {
-      "version": "0.4.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.7.tgz",
-      "integrity": "sha512-jHp7THm6S9sB8RaCxGoZXLAwGAUHnawUUilB1K3mvQsRdfB2bBs0f7wDehW+PDhr+Iog4LshaWbcnoQEUJWR+Q==",
+      "version": "0.4.8",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.8.tgz",
+      "integrity": "sha512-WlbYNfj4+v1hBHUluF+hnlG/A0ZaQFdXBTGDfHQniL11o+n3emWm4ujP5nSAoQHXjSH9DaOTGr/N4Mc9Xe+luw==",
      "cpu": [
        "arm64"
      ],
@@ -365,9 +365,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-linux-x64-gnu": {
-      "version": "0.4.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.7.tgz",
-      "integrity": "sha512-LKbVe6Wrp/AGqCCjKliNDmYoeTNgY/wfb2DTLjrx41Jko/04ywLrJ6xSEAn3XD5RDCO5u3fyUdXHHHv5a3VAAQ==",
+      "version": "0.4.8",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.8.tgz",
+      "integrity": "sha512-z+qFJrDqnNEv4JcwYDyt51PHmWjuM/XaOlSjpBnyyuUImeY+QcwctMuyXt8+Q4zhuqQR1AhLKrMwCU+YmMfk5g==",
       "cpu": [
         "x64"
       ],
@@ -377,9 +377,9 @@
       ]
     },
     "node_modules/@lancedb/vectordb-win32-x64-msvc": {
-      "version": "0.4.7",
-      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.7.tgz",
-      "integrity": "sha512-C5ln4+wafeY1Sm4PeV0Ios9lUaQVVip5Mjl9XU7ngioSEMEuXI/XMVfIdVfDPppVNXPeQxg33wLA272uw88D1Q==",
+      "version": "0.4.8",
+      "resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.8.tgz",
+      "integrity": "sha512-VjUryVvEA04r0j4lU9pJy84cmjuQm1GhBzbPc8kwbn5voT4A6BPglrlNsU0Zc+j8Fbjyvauzw2lMEcMsF4F0rw==",
       "cpu": [
         "x64"
       ],
 
@@ -1,6 +1,6 @@
 {
   "name": "vectordb",
-  "version": "0.4.7",
+  "version": "0.4.8",
   "description": " Serverless, low-latency vector database for AI applications",
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
@@ -85,10 +85,10 @@
     }
   },
   "optionalDependencies": {
-    "@lancedb/vectordb-darwin-arm64": "0.4.7",
-    "@lancedb/vectordb-darwin-x64": "0.4.7",
-    "@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
-    "@lancedb/vectordb-linux-x64-gnu": "0.4.7",
-    "@lancedb/vectordb-win32-x64-msvc": "0.4.7"
+    "@lancedb/vectordb-darwin-arm64": "0.4.8",
+    "@lancedb/vectordb-darwin-x64": "0.4.8",
+    "@lancedb/vectordb-linux-arm64-gnu": "0.4.8",
+    "@lancedb/vectordb-linux-x64-gnu": "0.4.8",
+    "@lancedb/vectordb-win32-x64-msvc": "0.4.8"
   }
 }
 
@@ -372,7 +372,7 @@ export interface Table<T = number[]> {
   /**
    * Returns the number of rows in this table.
    */
-  countRows: () => Promise<number>
+  countRows: (filter?: string) => Promise<number>
 
   /**
    * Delete rows from this table.
@@ -525,8 +525,19 @@ export interface MergeInsertArgs {
    * If there are multiple matches then the behavior is undefined.
    * Currently this causes multiple copies of the row to be created
    * but that behavior is subject to change.
+   *
+   * Optionally, a filter can be specified. This should be an SQL
+   * filter where fields with the prefix "target." refer to fields
+   * in the target table (old data) and fields with the prefix
+   * "source." refer to fields in the source table (new data). For
+   * example, the filter "target.lastUpdated < source.lastUpdated" will
+   * only update matched rows when the incoming `lastUpdated` value is
+   * newer.
+   *
+   * Rows that do not match the filter will not be updated. Rows that
+   * do not match the filter do become "not matched" rows.
    */
-  whenMatchedUpdateAll?: boolean
+  whenMatchedUpdateAll?: string | boolean
   /**
    * If true then rows that exist only in the source table (new data)
    * will be inserted into the target table.
@@ -840,8 +851,8 @@ export class LocalTable<T = number[]> implements Table<T> {
   /**
    * Returns the number of rows in this table.
    */
-  async countRows (): Promise<number> {
-    return tableCountRows.call(this._tbl)
+  async countRows (filter?: string): Promise<number> {
+    return tableCountRows.call(this._tbl, filter)
   }
 
   /**
@@ -885,7 +896,14 @@ export class LocalTable<T = number[]> implements Table<T> {
   }
 
   async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
-    const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
+    let whenMatchedUpdateAll = false
+    let whenMatchedUpdateAllFilt = null
+    if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
+      whenMatchedUpdateAll = true
+      if (args.whenMatchedUpdateAll !== true) {
+        whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
+      }
+    }
     const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
     let whenNotMatchedBySourceDelete = false
     let whenNotMatchedBySourceDeleteFilt = null
@@ -909,6 +927,7 @@ export class LocalTable<T = number[]> implements Table<T> {
       this._tbl,
       on,
       whenMatchedUpdateAll,
+      whenMatchedUpdateAllFilt,
       whenNotMatchedInsertAll,
       whenNotMatchedBySourceDelete,
       whenNotMatchedBySourceDeleteFilt,
 
@@ -286,8 +286,11 @@ export class RemoteTable<T = number[]> implements Table<T> {
     const queryParams: any = {
       on
     }
-    if (args.whenMatchedUpdateAll ?? false) {
+    if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
       queryParams.when_matched_update_all = 'true'
+      if (typeof args.whenMatchedUpdateAll === 'string') {
+        queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
+      }
     } else {
       queryParams.when_matched_update_all = 'false'
     }
 
@@ -294,6 +294,7 @@ describe('LanceDB client', function () {
     })
     assert.equal(table.name, 'vectors')
     assert.equal(await table.countRows(), 10)
+    assert.equal(await table.countRows('vector IS NULL'), 0)
     assert.deepEqual(await con.tableNames(), ['vectors'])
   })
 
@@ -369,6 +370,7 @@ describe('LanceDB client', function () {
     const table = await con.createTable('f16', data)
     assert.equal(table.name, 'f16')
     assert.equal(await table.countRows(), total)
+    assert.equal(await table.countRows('id < 5'), 5)
     assert.deepEqual(await con.tableNames(), ['f16'])
     assert.deepEqual(await table.schema, schema)
 
@@ -538,26 +540,36 @@ describe('LanceDB client', function () {
     const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
     const table = await con.createTable('my_table', data)
 
+    // insert if not exists
     let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
     await table.mergeInsert('id', newData, {
       whenNotMatchedInsertAll: true
     })
     assert.equal(await table.countRows(), 3)
-    assert.equal((await table.filter('age = 2').execute()).length, 1)
+    assert.equal(await table.countRows('age = 2'), 1)
 
-    newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
+    // conditional update
+    newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
+    await table.mergeInsert('id', newData, {
+      whenMatchedUpdateAll: 'target.age = 1'
+    })
+    assert.equal(await table.countRows(), 3)
+    assert.equal(await table.countRows('age = 1'), 1)
+    assert.equal(await table.countRows('age = 3'), 1)
+
+    newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
     await table.mergeInsert('id', newData, {
       whenNotMatchedInsertAll: true,
       whenMatchedUpdateAll: true
     })
     assert.equal(await table.countRows(), 4)
    assert.equal((await table.filter('age = 3').execute()).length, 2)
    assert.equal((await table.filter('age = 4').execute()).length, 2)

-    newData = [{ id: 5, age: 4 }]
+    newData = [{ id: 5, age: 5 }]
     await table.mergeInsert('id', newData, {
       whenNotMatchedInsertAll: true,
       whenMatchedUpdateAll: true,
-      whenNotMatchedBySourceDelete: 'age < 3'
+      whenNotMatchedBySourceDelete: 'age < 4'
     })
     assert.equal(await table.countRows(), 3)
 
@@ -1,9 +1,12 @@
 [package]
 name = "vectordb-nodejs"
-edition = "2021"
+edition.workspace = true
 version = "0.0.0"
 license.workspace = true
 description.workspace = true
 repository.workspace = true
 keywords.workspace = true
 categories.workspace = true
 
 [lib]
 crate-type = ["cdylib"]
@@ -14,15 +17,11 @@ futures.workspace = true
 lance-linalg.workspace = true
 lance.workspace = true
 vectordb = { path = "../rust/vectordb" }
-napi = { version = "2.14", default-features = false, features = [
+napi = { version = "2.15", default-features = false, features = [
     "napi7",
     "async"
 ] }
-napi-derive = "2.14"
+napi-derive = "2"
 
 [build-dependencies]
 napi-build = "2.1"
-
-[profile.release]
-lto = true
-strip = "symbols"
 
@@ -2,4 +2,6 @@
 module.exports = {
   preset: 'ts-jest',
   testEnvironment: 'node',
-};
+  moduleDirectories: ["node_modules", "./dist"],
+  moduleFileExtensions: ["js", "ts"],
+};
 
@@ -57,8 +57,8 @@ impl Table {
     }
 
     #[napi]
-    pub async fn count_rows(&self) -> napi::Result<usize> {
-        self.table.count_rows().await.map_err(|e| {
+    pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
+        self.table.count_rows(filter).await.map_err(|e| {
             napi::Error::from_reason(format!(
                 "Failed to count rows in table {}: {}",
                 self.table, e
 
2  nodejs/vectordb/native.d.ts  (vendored)
@@ -73,7 +73,7 @@ export class Table {
   /** Return Schema as empty Arrow IPC file. */
   schema(): Buffer
   add(buf: Buffer): Promise<void>
-  countRows(): Promise<bigint>
+  countRows(filter?: string): Promise<bigint>
   delete(predicate: string): Promise<void>
   createIndex(): IndexBuilder
   query(): Query
 
@@ -50,8 +50,8 @@ export class Table {
   }
 
   /** Count the total number of rows in the dataset. */
-  async countRows(): Promise<bigint> {
-    return await this.inner.countRows();
+  async countRows(filter?: string): Promise<bigint> {
+    return await this.inner.countRows(filter);
   }
 
   /** Delete the rows that satisfy the predicate. */
 
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.5.2
+current_version = 0.5.4
 commit = True
 message = [python] Bump version: {current_version} → {new_version}
 tag = True
 
@@ -42,6 +42,12 @@ To run the unit tests:
 pytest
 ```
 
+To run the doc tests:
+
+```bash
+pytest --doctest-modules lancedb
+```
+
 To run linter and automatically fix all errors:
 
 ```bash
 
@@ -13,6 +13,7 @@
 
 import importlib.metadata
 import os
+from datetime import timedelta
 from typing import Optional
 
 __version__ = importlib.metadata.version("lancedb")
@@ -30,6 +31,7 @@ def connect(
     api_key: Optional[str] = None,
     region: str = "us-east-1",
     host_override: Optional[str] = None,
+    read_consistency_interval: Optional[timedelta] = None,
 ) -> DBConnection:
     """Connect to a LanceDB database.
 
@@ -45,6 +47,18 @@ def connect(
         The region to use for LanceDB Cloud.
     host_override: str, optional
         The override url for LanceDB Cloud.
+    read_consistency_interval: timedelta, default None
+        (For LanceDB OSS only)
+        The interval at which to check for updates to the table from other
+        processes. If None, then consistency is not checked. For performance
+        reasons, this is the default. For strong consistency, set this to
+        zero seconds. Then every read will check for updates from other
+        processes. As a compromise, you can set this to a non-zero timedelta
+        for eventual consistency. If more than that interval has passed since
+        the last check, then the table will be checked for updates. Note: this
+        consistency only applies to read operations. Write operations are
+        always consistent.
+
 
     Examples
     --------
@@ -73,4 +87,4 @@ def connect(
         if api_key is None:
             raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
         return RemoteDBConnection(uri, api_key, region, host_override)
-    return LanceDBConnection(uri)
+    return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
 
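A quick sketch of the new parameter in use, straight from the docstring above (LanceDB OSS only; the path is a placeholder):

```python
from datetime import timedelta

import lancedb

# Strong consistency: every read checks for updates from other processes.
db_strong = lancedb.connect("./.lancedb", read_consistency_interval=timedelta(seconds=0))

# Eventual consistency: re-check at most every 5 seconds.
db_eventual = lancedb.connect("./.lancedb", read_consistency_interval=timedelta(seconds=5))
```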
@@ -16,9 +16,9 @@ from typing import Iterable, List, Union
 import numpy as np
 import pyarrow as pa
 
-from .util import safe_import
+from .util import safe_import_pandas
 
-pd = safe_import("pandas")
+pd = safe_import_pandas()
 
 DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
 VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
 
@@ -16,9 +16,9 @@ import deprecation
 
 from . import __version__
 from .exceptions import MissingColumnError, MissingValueError
-from .util import safe_import
+from .util import safe_import_pandas
 
-pd = safe_import("pandas")
+pd = safe_import_pandas()
 
 
 def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
 
@@ -26,6 +26,8 @@ from .table import LanceTable, Table
 from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
 
 if TYPE_CHECKING:
+    from datetime import timedelta
+
     from .common import DATA, URI
     from .embeddings import EmbeddingFunctionConfig
     from .pydantic import LanceModel
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
         >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
         ...         {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
         >>> db.create_table("my_table", data)
-        LanceTable(my_table)
+        LanceTable(connection=..., name="my_table")
         >>> db["my_table"].head()
         pyarrow.Table
         vector: fixed_size_list<item: float>[2]
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
         ...     "long": [-122.7, -74.1]
         ... })
         >>> db.create_table("table2", data)
-        LanceTable(table2)
+        LanceTable(connection=..., name="table2")
         >>> db["table2"].head()
         pyarrow.Table
         vector: fixed_size_list<item: float>[2]
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
         ...     pa.field("long", pa.float32())
         ... ])
         >>> db.create_table("table3", data, schema = custom_schema)
-        LanceTable(table3)
+        LanceTable(connection=..., name="table3")
         >>> db["table3"].head()
         pyarrow.Table
         vector: fixed_size_list<item: float>[2]
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
         ...     pa.field("price", pa.float32()),
         ... ])
         >>> db.create_table("table4", make_batches(), schema=schema)
-        LanceTable(table4)
+        LanceTable(connection=..., name="table4")
 
         """
         raise NotImplementedError
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
     ----------
     uri: str or Path
         The root uri of the database.
+    read_consistency_interval: timedelta, default None
+        The interval at which to check for updates to the table from other
+        processes. If None, then consistency is not checked. For performance
+        reasons, this is the default. For strong consistency, set this to
+        zero seconds. Then every read will check for updates from other
+        processes. As a compromise, you can set this to a non-zero timedelta
+        for eventual consistency. If more than that interval has passed since
+        the last check, then the table will be checked for updates. Note: this
+        consistency only applies to read operations. Write operations are
+        always consistent.
 
     Examples
     --------
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
     >>> db = lancedb.connect("./.lancedb")
     >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
     ...                                   {"vector": [0.5, 1.3], "b": 4}])
-    LanceTable(my_table)
+    LanceTable(connection=..., name="my_table")
     >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
-    LanceTable(another_table)
+    LanceTable(connection=..., name="another_table")
     >>> sorted(db.table_names())
     ['another_table', 'my_table']
     >>> len(db)
     2
     >>> db["my_table"]
-    LanceTable(my_table)
+    LanceTable(connection=..., name="my_table")
     >>> "my_table" in db
     True
    >>> db.drop_table("my_table")
    >>> db.drop_table("another_table")
    """

-    def __init__(self, uri: URI):
+    def __init__(
+        self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
+    ):
         if not isinstance(uri, Path):
             scheme = get_uri_scheme(uri)
         is_local = isinstance(uri, Path) or scheme == "file"
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
         self._uri = str(uri)
 
         self._entered = False
+        self.read_consistency_interval = read_consistency_interval
+
+    def __repr__(self) -> str:
+        val = f"{self.__class__.__name__}({self._uri}"
+        if self.read_consistency_interval is not None:
+            val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
+        val += ")"
+        return val
 
     @property
     def uri(self) -> str:
 
@@ -12,7 +12,7 @@
 # limitations under the License.
 import os
 from functools import cached_property
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 
@@ -30,10 +30,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
     """
 
     name: str = "text-embedding-ada-002"
+    dim: Optional[int] = None
 
     def ndims(self):
-        # TODO don't hardcode this
-        return 1536
+        return self._ndims
+
+    @cached_property
+    def _ndims(self):
+        if self.name == "text-embedding-ada-002":
+            return 1536
+        elif self.name == "text-embedding-3-large":
+            return self.dim or 3072
+        elif self.name == "text-embedding-3-small":
+            return self.dim or 1536
+        else:
+            raise ValueError(f"Unknown model name {self.name}")
 
     def generate_embeddings(
         self, texts: Union[List[str], np.ndarray]
@@ -47,7 +58,12 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
             The texts to embed
         """
         # TODO retry, rate limit, token limit
-        rs = self._openai_client.embeddings.create(input=texts, model=self.name)
+        if self.name == "text-embedding-ada-002":
+            rs = self._openai_client.embeddings.create(input=texts, model=self.name)
+        else:
+            rs = self._openai_client.embeddings.create(
+                input=texts, model=self.name, dimensions=self.ndims()
+            )
         return [v.embedding for v in rs.data]
 
     @cached_property
 
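A small sketch of the new model/dimension options through the embedding registry. The model names come from the code above; that `create()` forwards `name` and `dim` as keyword arguments is an assumption based on how the registry instantiates embedding functions:

```python
from lancedb.embeddings import get_registry

registry = get_registry()

# Default model (1536 dimensions).
ada = registry.get("openai").create()

# Newer model with a reduced embedding size via the new `dim` field.
te3_large = registry.get("openai").create(name="text-embedding-3-large", dim=1024)

print(ada.ndims(), te3_large.ndims())  # expected: 1536 1024
```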
@@ -26,10 +26,10 @@ import pyarrow as pa
 from lance.vector import vec_to_table
 from retry import retry
 
-from ..util import safe_import
+from ..util import safe_import_pandas
 from ..utils.general import LOGGER
 
-pd = safe_import("pandas")
+pd = safe_import_pandas()
 
 DATA = Union[pa.Table, "pd.DataFrame"]
 TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
 
@@ -32,11 +32,14 @@ class LanceMergeInsertBuilder(object):
         self._table = table
         self._on = on
         self._when_matched_update_all = False
+        self._when_matched_update_all_condition = None
         self._when_not_matched_insert_all = False
         self._when_not_matched_by_source_delete = False
         self._when_not_matched_by_source_condition = None
 
-    def when_matched_update_all(self) -> LanceMergeInsertBuilder:
+    def when_matched_update_all(
+        self, *, where: Optional[str] = None
+    ) -> LanceMergeInsertBuilder:
         """
         Rows that exist in both the source table (new data) and
         the target table (old data) will be updated, replacing
@@ -47,6 +50,7 @@ class LanceMergeInsertBuilder(object):
         but that behavior is subject to change.
         """
         self._when_matched_update_all = True
        self._when_matched_update_all_condition = where
         return self
 
     def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
 
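A short sketch of the resulting Python API, using a hypothetical `users` table keyed on `id`. The filter string follows the `target.`/`source.` convention documented for the TypeScript `MergeInsertArgs` above, and the sketch assumes the builder is run with `execute()`:

```python
new_users = [
    {"id": 1, "name": "Ada", "last_updated": 2},
    {"id": 2, "name": "Bob", "last_updated": 1},
]

(
    users.merge_insert("id")
    # Only overwrite matched rows when the incoming record is newer.
    .when_matched_update_all(where="target.last_updated < source.last_updated")
    .when_not_matched_insert_all()
    .execute(new_users)
)
```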
@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
     ...     name: str
     ...     vector: Vector(2)
     ...
-    >>> db = lancedb.connect("/tmp")
+    >>> db = lancedb.connect("./example")
     >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
     >>> table.add([
     ...     TestModel(name="test", vector=[1.0, 2.0])
 
@@ -27,7 +27,7 @@ from . import __version__
 from .common import VEC, VECTOR_COLUMN_NAME
 from .rerankers.base import Reranker
 from .rerankers.linear_combination import LinearCombinationReranker
-from .util import safe_import
+from .util import safe_import_pandas
 
 if TYPE_CHECKING:
     import PIL
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from .pydantic import LanceModel
     from .table import Table
 
-pd = safe_import("pandas")
+pd = safe_import_pandas()
 
 
 class Query(pydantic.BaseModel):
@@ -626,7 +626,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
     def __init__(self, table: "Table", query: str, vector_column: str):
         super().__init__(table)
         self._validate_fts_index()
-        self._query = query
         vector_query, fts_query = self._validate_query(query)
         self._fts_query = LanceFtsQueryBuilder(table, fts_query)
         vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -679,12 +678,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         # rerankers might need to preserve this score to support `return_score="all"`
         fts_results = self._normalize_scores(fts_results, "score")
 
-        results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
+        results = self._reranker.rerank_hybrid(
+            self._fts_query._query, vector_results, fts_results
+        )
+
+        if not isinstance(results, pa.Table):  # Enforce type
+            raise TypeError(
+                f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
+            )
 
         # apply limit after reranking
         results = results.slice(length=self._limit)
 
         if not self._with_row_id:
             results = results.drop(["_rowid"])
         return results
@@ -776,6 +781,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
         """
         self._vector_query.limit(limit)
         self._fts_query.limit(limit)
+        self._limit = limit
+
         return self
 
     def select(self, columns: list) -> LanceHybridQueryBuilder:
 
@@ -118,6 +118,7 @@ class RemoteDBConnection(DBConnection):
         schema: Optional[Union[pa.Schema, LanceModel]] = None,
         on_bad_vectors: str = "error",
         fill_value: float = 0.0,
+        mode: Optional[str] = None,
         embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
     ) -> Table:
         """Create a [Table][lancedb.table.Table] in the database.
@@ -215,11 +216,13 @@ class RemoteDBConnection(DBConnection):
         if data is None and schema is None:
             raise ValueError("Either data or schema must be provided.")
         if embedding_functions is not None:
-            raise NotImplementedError(
-                "embedding_functions is not supported for remote databases."
+            logging.warning(
+                "embedding_functions is not yet supported on LanceDB Cloud."
+                "Please vote https://github.com/lancedb/lancedb/issues/626 "
+                "for this feature."
             )
+        if mode is not None:
+            logging.warning("mode is not yet supported on LanceDB Cloud.")
 
         if inspect.isclass(schema) and issubclass(schema, LanceModel):
             # convert LanceModel to pyarrow schema
 
@@ -11,6 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import uuid
 from functools import cached_property
 from typing import Dict, Optional, Union
@@ -37,6 +38,9 @@ class RemoteTable(Table):
     def __repr__(self) -> str:
         return f"RemoteTable({self._conn.db_name}.{self._name})"
 
+    def __len__(self) -> int:
+        self.count_rows(None)
+
     @cached_property
     def schema(self) -> pa.Schema:
         """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
@@ -54,17 +58,17 @@ class RemoteTable(Table):
         return resp["version"]
 
     def to_arrow(self) -> pa.Table:
-        """to_arrow() is not supported on the LanceDB cloud"""
-        raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
+        """to_arrow() is not yet supported on LanceDB cloud."""
+        raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.")
 
     def to_pandas(self):
-        """to_pandas() is not supported on the LanceDB cloud"""
-        return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
+        """to_pandas() is not yet supported on LanceDB cloud."""
+        return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
 
     def create_scalar_index(self, *args, **kwargs):
         """Creates a scalar index"""
         return NotImplementedError(
-            "create_scalar_index() is not supported on the LanceDB cloud"
+            "create_scalar_index() is not yet supported on LanceDB cloud."
         )
 
     def create_index(
@@ -72,6 +76,10 @@ class RemoteTable(Table):
         metric="L2",
         vector_column_name: str = VECTOR_COLUMN_NAME,
         index_cache_size: Optional[int] = None,
+        num_partitions: Optional[int] = None,
+        num_sub_vectors: Optional[int] = None,
+        replace: Optional[bool] = None,
+        accelerator: Optional[str] = None,
     ):
         """Create an index on the table.
         Currently, the only parameters that matter are
@@ -105,6 +113,28 @@ class RemoteTable(Table):
         ... )
         >>> table.create_index("L2", "vector") # doctest: +SKIP
         """
+
+        if num_partitions is not None:
+            logging.warning(
+                "num_partitions is not supported on LanceDB cloud."
+                "This parameter will be tuned automatically."
+            )
+        if num_sub_vectors is not None:
+            logging.warning(
+                "num_sub_vectors is not supported on LanceDB cloud."
+                "This parameter will be tuned automatically."
+            )
+        if accelerator is not None:
+            logging.warning(
+                "GPU accelerator is not yet supported on LanceDB cloud."
+                "If you have 100M+ vectors to index,"
+                "please contact us at contact@lancedb.com"
+            )
+        if replace is not None:
+            logging.warning(
+                "replace is not supported on LanceDB cloud."
+                "Existing indexes will always be replaced."
+            )
         index_type = "vector"
 
         data = {
@@ -268,6 +298,10 @@ class RemoteTable(Table):
             )
         params["on"] = merge._on[0]
         params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
+        if merge._when_matched_update_all_condition is not None:
+            params[
+                "when_matched_update_all_filt"
+            ] = merge._when_matched_update_all_condition
         params["when_not_matched_insert_all"] = str(
             merge._when_not_matched_insert_all
         ).lower()
@@ -409,6 +443,13 @@ class RemoteTable(Table):
             "compact_files() is not supported on the LanceDB cloud"
         )
 
+    def count_rows(self, filter: Optional[str] = None) -> int:
+        # payload = {"filter": filter}
+        # self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
+        return NotImplementedError(
+            "count_rows() is not yet supported on the LanceDB cloud"
+        )
+
 
 def add_index(tbl: pa.Table, i: int) -> pa.Table:
     return tbl.add_column(
 
@@ -1,11 +1,15 @@
 from .base import Reranker
 from .cohere import CohereReranker
+from .colbert import ColbertReranker
 from .cross_encoder import CrossEncoderReranker
 from .linear_combination import LinearCombinationReranker
+from .openai import OpenaiReranker
 
 __all__ = [
     "Reranker",
     "CrossEncoderReranker",
     "CohereReranker",
     "LinearCombinationReranker",
+    "OpenaiReranker",
+    "ColbertReranker",
 ]
 
@@ -1,12 +1,8 @@
-import typing
 from abc import ABC, abstractmethod
 
-import numpy as np
 import pyarrow as pa
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
 
 class Reranker(ABC):
     def __init__(self, return_score: str = "relevance"):
@@ -30,7 +26,7 @@ class Reranker(ABC):
 
     @abstractmethod
     def rerank_hybrid(
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
@@ -41,8 +37,8 @@ class Reranker(ABC):
 
         Parameters
         ----------
-        query_builder : "lancedb.HybridQueryBuilder"
-            The query builder object that was used to generate the results
+        query : str
+            The input query
         vector_results : pa.Table
             The results from the vector search
         fts_results : pa.Table
@@ -50,36 +46,6 @@ class Reranker(ABC):
         """
         pass
 
-    def rerank_vector(
-        query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
-    ):
-        """
-        Rerank function receives the individual results from the vector search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.VectorQueryBuilder"
-            The query builder object that was used to generate the results
-        vector_results : pa.Table
-            The results from the vector search
-        """
-        raise NotImplementedError("Vector Reranking is not implemented")
-
-    def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
-        """
-        Rerank function receives the individual results from the FTS search.
-        This isn't mandatory to implement
-
-        Parameters
-        ----------
-        query_builder : "lancedb.FTSQueryBuilder"
-            The query builder object that was used to generate the results
-        fts_results : pa.Table
-            The results from the FTS search
-        """
-        raise NotImplementedError("FTS Reranking is not implemented")
-
     def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
         """
         Merge the results from the vector and FTS search. This is a vanilla merging
 
@@ -1,5 +1,4 @@
 import os
-import typing
 from functools import cached_property
 from typing import Union
 
@@ -8,9 +7,6 @@
 from ..util import safe_import
 from .base import Reranker
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
 
 class CohereReranker(Reranker):
     """
@@ -55,14 +51,14 @@ class CohereReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         docs = combined_results[self.column].to_pylist()
         results = self._client.rerank(
-            query=query_builder._query,
+            query=query,
             documents=docs,
             top_n=self.top_n,
             model=self.model_name,
 
107  python/lancedb/rerankers/colbert.py  (new file)
@@ -0,0 +1,107 @@
from functools import cached_property

import pyarrow as pa

from ..util import safe_import
from .base import Reranker


class ColbertReranker(Reranker):
    """
    Reranks the results using the ColBERT model.

    Parameters
    ----------
    model_name : str, default "colbert-ir/colbertv2.0"
        The name of the cross encoder model to use.
    column : str, default "text"
        The name of the column to use as input to the cross encoder model.
    return_score : str, default "relevance"
        options are "relevance" or "all". Only "relevance" is supported for now.
    """

    def __init__(
        self,
        model_name: str = "colbert-ir/colbertv2.0",
        column: str = "text",
        return_score="relevance",
    ):
        super().__init__(return_score)
        self.model_name = model_name
        self.column = column
        self.torch = safe_import("torch")  # import here for faster ops later

    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ):
        combined_results = self.merge_results(vector_results, fts_results)
        docs = combined_results[self.column].to_pylist()

        tokenizer, model = self._model

        # Encode the query
        query_encoding = tokenizer(query, return_tensors="pt")
        query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
        scores = []
        # Get score for each document
        for document in docs:
            document_encoding = tokenizer(
                document, return_tensors="pt", truncation=True, max_length=512
            )
            document_embedding = model(**document_encoding).last_hidden_state
            # Calculate MaxSim score
            score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
            scores.append(score.item())

        # replace the self.column column with the docs
        combined_results = combined_results.drop(self.column)
        combined_results = combined_results.append_column(
            self.column, pa.array(docs, type=pa.string())
        )
        # add the scores
        combined_results = combined_results.append_column(
            "_relevance_score", pa.array(scores, type=pa.float32())
        )
        if self.score == "relevance":
            combined_results = combined_results.drop_columns(["score", "_distance"])
        elif self.score == "all":
            raise NotImplementedError(
                "OpenAI Reranker does not support score='all' yet"
            )

        combined_results = combined_results.sort_by(
            [("_relevance_score", "descending")]
        )

        return combined_results

    @cached_property
    def _model(self):
        transformers = safe_import("transformers")
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        model = transformers.AutoModel.from_pretrained(self.model_name)

        return tokenizer, model

    def maxsim(self, query_embedding, document_embedding):
        # Expand dimensions for broadcasting
        # Query: [batch, length, size] -> [batch, query, 1, size]
        # Document: [batch, length, size] -> [batch, 1, length, size]
        expanded_query = query_embedding.unsqueeze(2)
        expanded_doc = document_embedding.unsqueeze(1)

        # Compute cosine similarity across the embedding dimension
        sim_matrix = self.torch.nn.functional.cosine_similarity(
            expanded_query, expanded_doc, dim=-1
        )

        # Take the maximum similarity for each query token (across all document tokens)
        # sim_matrix shape: [batch_size, query_length, doc_length]
        max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)

        # Average these maximum scores across all query tokens
        avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
        return avg_max_sim
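Restating the scoring implemented in `maxsim()` above: the query's token embeddings are first mean-pooled into a single vector $\bar{q}$, and each document with token embeddings $d_1, \dots, d_m$ is then scored by its best-matching token,

$$ \text{score}(q, d) \;=\; \max_{j} \,\cos\!\left(\bar{q},\, d_j\right), \qquad \bar{q} = \frac{1}{n}\sum_{i=1}^{n} q_i $$

(the final mean over query tokens in the code is taken over this single pooled vector, so it does not change the value).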
@@ -1,4 +1,3 @@
-import typing
 from functools import cached_property
 from typing import Union
 
@@ -7,9 +6,6 @@ import pyarrow as pa
 from ..util import safe_import
 from .base import Reranker
 
-if typing.TYPE_CHECKING:
-    import lancedb
-
 
 class CrossEncoderReranker(Reranker):
     """
@@ -52,13 +48,13 @@ class CrossEncoderReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",
+        query: str,
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
         combined_results = self.merge_results(vector_results, fts_results)
         passages = combined_results[self.column].to_pylist()
-        cross_inp = [[query_builder._query, passage] for passage in passages]
+        cross_inp = [[query, passage] for passage in passages]
         cross_scores = self.model.predict(cross_inp)
         combined_results = combined_results.append_column(
             "_relevance_score", pa.array(cross_scores, type=pa.float32())
@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):
 
     def rerank_hybrid(
         self,
-        query_builder: "lancedb.HybridQueryBuilder",  # noqa: F821
+        query: str,  # noqa: F821
         vector_results: pa.Table,
         fts_results: pa.Table,
     ):
 
102
python/lancedb/rerankers/openai.py
Normal file
102
python/lancedb/rerankers/openai.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import json
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class OpenaiReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using the OpenAI API.
|
||||
WARNING: This is a prompt based reranker that uses chat model that is
|
||||
not a dedicated reranker API. This should be treated as experimental.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "gpt-3.5-turbo-1106 "
|
||||
The name of the cross encoder model to use.
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
api_key : str, default None
|
||||
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "gpt-3.5-turbo-1106",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.api_key = api_key
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
docs = combined_results[self.column].to_pylist()
|
||||
response = self._client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an expert relevance ranker. Given a list of\
|
||||
documents and a query, your job is to determine the relevance\
|
||||
each document is for answering the query. Your output is JSON,\
|
||||
which is a list of documents. Each document has two fields,\
|
||||
content and relevance_score. relevance_score is from 0.0 to\
|
||||
1.0 indicating the relevance of the text to the given query.\
|
||||
Make sure to include all documents in the response.",
|
||||
},
|
||||
{"role": "user", "content": f"Query: {query} Docs: {docs}"},
|
||||
],
|
||||
)
|
||||
results = json.loads(response.choices[0].message.content)["documents"]
|
||||
docs, scores = list(
|
||||
zip(*[(result["content"], result["relevance_score"]) for result in results])
|
||||
) # tuples
|
||||
# replace the self.column column with the docs
|
||||
combined_results = combined_results.drop(self.column)
|
||||
combined_results = combined_results.append_column(
|
||||
self.column, pa.array(docs, type=pa.string())
|
||||
)
|
||||
# add the scores
|
||||
combined_results = combined_results.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"OpenAI Reranker does not support score='all' yet"
|
||||
)
|
||||
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
openai = safe_import("openai") # TODO: force version or handle versions < 1.0
|
||||
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"OPENAI_API_KEY not set. Either set it in your environment or \
|
||||
pass it as `api_key` argument to the CohereReranker."
|
||||
)
|
||||
return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)
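To make the new reranker concrete, here is a minimal sketch of wiring it into a hybrid query. The database path, table name, and columns are hypothetical, and OPENAI_API_KEY is assumed to be set; this is an illustration, not part of the commit.

import lancedb
from lancedb.rerankers import OpenaiReranker

db = lancedb.connect("./my_db")     # hypothetical local database
tbl = db.open_table("docs")         # assumes "text" and "vector" columns
tbl.create_fts_index("text")        # hybrid search needs a full-text index

reranked = (
    tbl.search("how do I build an index?", query_type="hybrid")
    .limit(10)
    .rerank(reranker=OpenaiReranker(column="text"))
    .to_pandas()
)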
|
||||
@@ -14,7 +14,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@@ -34,22 +37,21 @@ from .query import LanceQueryBuilder, Query
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
join_uri,
|
||||
safe_import,
|
||||
safe_import_pandas,
|
||||
safe_import_polars,
|
||||
value_to_sql,
|
||||
)
|
||||
from .utils.events import register_event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
import PIL
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
|
||||
from .db import LanceDBConnection
|
||||
|
||||
|
||||
pd = safe_import("pandas")
|
||||
pl = safe_import("polars")
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
|
||||
|
||||
def _sanitize_data(
|
||||
@@ -175,6 +177,18 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
Count the number of rows in the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter: str, optional
|
||||
A SQL where clause to filter the rows to count.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
@@ -298,7 +312,7 @@ class Table(ABC):
|
||||
|
||||
import lance
|
||||
|
||||
dataset = lance.dataset("/tmp/images.lance")
|
||||
dataset = lance.dataset("./images.lance")
|
||||
dataset.create_scalar_index("category")
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -445,7 +459,7 @@ class Table(ABC):
|
||||
*default "vector"*
|
||||
query_type: str
|
||||
*default "auto"*.
|
||||
Acceptable types are: "vector", "fts", or "auto"
|
||||
Acceptable types are: "vector", "fts", "hybrid", or "auto"
|
||||
|
||||
- If "auto" then the query type is inferred from the query;
|
||||
|
||||
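For the new "hybrid" option, a small illustrative sketch (the query and table are hypothetical; the table needs a vector column and a full-text index):

# tbl is an open LanceTable with embeddings and an FTS index
results = (
    tbl.search("golden retriever puppies", query_type="hybrid")
    .limit(10)
    .to_pandas()
)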
@@ -641,23 +655,145 @@ class Table(ABC):
|
||||
"""
|
||||
|
||||
|
||||
class _LanceDatasetRef(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def dataset(self) -> LanceDataset:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dataset_mut(self) -> LanceDataset:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LanceLatestDatasetRef(_LanceDatasetRef):
|
||||
"""Reference to the latest version of a LanceDataset."""
|
||||
|
||||
uri: str
|
||||
read_consistency_interval: Optional[timedelta] = None
|
||||
last_consistency_check: Optional[float] = None
|
||||
_dataset: Optional[LanceDataset] = None
|
||||
|
||||
@property
|
||||
def dataset(self) -> LanceDataset:
|
||||
if not self._dataset:
|
||||
self._dataset = lance.dataset(self.uri)
|
||||
self.last_consistency_check = time.monotonic()
|
||||
elif self.read_consistency_interval is not None:
|
||||
now = time.monotonic()
|
||||
diff = timedelta(seconds=now - self.last_consistency_check)
|
||||
if (
|
||||
self.last_consistency_check is None
|
||||
or diff > self.read_consistency_interval
|
||||
):
|
||||
self._dataset = self._dataset.checkout_version(
|
||||
self._dataset.latest_version
|
||||
)
|
||||
self.last_consistency_check = time.monotonic()
|
||||
return self._dataset
|
||||
|
||||
@dataset.setter
|
||||
def dataset(self, value: LanceDataset):
|
||||
self._dataset = value
|
||||
self.last_consistency_check = time.monotonic()
|
||||
|
||||
@property
|
||||
def dataset_mut(self) -> LanceDataset:
|
||||
return self.dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LanceTimeTravelRef(_LanceDatasetRef):
|
||||
uri: str
|
||||
version: int
|
||||
_dataset: Optional[LanceDataset] = None
|
||||
|
||||
@property
|
||||
def dataset(self) -> LanceDataset:
|
||||
if not self._dataset:
|
||||
self._dataset = lance.dataset(self.uri, version=self.version)
|
||||
return self._dataset
|
||||
|
||||
@dataset.setter
|
||||
def dataset(self, value: LanceDataset):
|
||||
self._dataset = value
|
||||
self.version = value.version
|
||||
|
||||
@property
|
||||
def dataset_mut(self) -> LanceDataset:
|
||||
raise ValueError(
|
||||
"Cannot mutate table reference fixed at version "
|
||||
f"{self.version}. Call checkout_latest() to get a mutable "
|
||||
"table reference."
|
||||
)
|
||||
|
||||
|
||||
class LanceTable(Table):
|
||||
"""
|
||||
A table in a LanceDB database.
|
||||
|
||||
This can be opened in two modes: standard and time-travel.
|
||||
|
||||
Standard mode is the default. In this mode, the table is mutable and tracks
|
||||
the latest version of the table. The level of read consistency is controlled
|
||||
by the `read_consistency_interval` parameter on the connection.
|
||||
|
||||
Time-travel mode is activated by specifying a version number. In this mode,
|
||||
the table is immutable and fixed to a specific version. This is useful for
|
||||
querying historical versions of the table.
|
||||
"""
|
||||
|
||||
def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
|
||||
def __init__(
|
||||
self,
|
||||
connection: "LanceDBConnection",
|
||||
name: str,
|
||||
version: Optional[int] = None,
|
||||
):
|
||||
self._conn = connection
|
||||
self.name = name
|
||||
self._version = version
|
||||
|
||||
def _reset_dataset(self, version=None):
|
||||
try:
|
||||
if "_dataset" in self.__dict__:
|
||||
del self.__dict__["_dataset"]
|
||||
self._version = version
|
||||
except AttributeError:
|
||||
pass
|
||||
if version is not None:
|
||||
self._ref = _LanceTimeTravelRef(
|
||||
uri=self._dataset_uri,
|
||||
version=version,
|
||||
)
|
||||
else:
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=connection.read_consistency_interval,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name, **kwargs):
|
||||
tbl = cls(db, name, **kwargs)
|
||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
||||
file_info = fs.get_file_info(path)
|
||||
if file_info.type != pa.fs.FileType.Directory:
|
||||
raise FileNotFoundError(
|
||||
f"Table {name} does not exist."
|
||||
f"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
register_event("open_table")
|
||||
|
||||
return tbl
|
||||
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
@property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return self._ref.dataset
|
||||
|
||||
@property
|
||||
def _dataset_mut(self) -> LanceDataset:
|
||||
return self._ref.dataset_mut
|
||||
|
||||
def to_lance(self) -> LanceDataset:
|
||||
"""Return the LanceDataset backing this table."""
|
||||
return self._dataset
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
@@ -685,6 +821,9 @@ class LanceTable(Table):
|
||||
keep writing to the dataset starting from an old version, then use
|
||||
the `restore` function.
|
||||
|
||||
Calling this method will set the table into time-travel mode. If you
|
||||
wish to return to standard mode, call `checkout_latest`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : int
|
||||
@@ -709,15 +848,13 @@ class LanceTable(Table):
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
max_ver = self._dataset.latest_version
|
||||
if version < 1 or version > max_ver:
|
||||
raise ValueError(f"Invalid version {version}")
|
||||
self._reset_dataset(version=version)
|
||||
|
||||
try:
|
||||
# Accessing the property updates the cached value
|
||||
_ = self._dataset
|
||||
except Exception as e:
|
||||
ds = self._dataset.checkout_version(version)
|
||||
except IOError as e:
|
||||
if "not found" in str(e):
|
||||
raise ValueError(
|
||||
f"Version {version} no longer exists. Was it cleaned up?"
|
||||
@@ -725,6 +862,27 @@ class LanceTable(Table):
|
||||
else:
|
||||
raise e
|
||||
|
||||
self._ref = _LanceTimeTravelRef(
|
||||
uri=self._dataset_uri,
|
||||
version=version,
|
||||
)
|
||||
# We've already loaded the version so we can populate it directly.
|
||||
self._ref.dataset = ds
|
||||
|
||||
def checkout_latest(self):
|
||||
"""Checkout the latest version of the table. This is an in-place operation.
|
||||
|
||||
The table will be set back into standard mode, and will track the latest
|
||||
version of the table.
|
||||
"""
|
||||
self.checkout(self._dataset.latest_version)
|
||||
ds = self._ref.dataset
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=self._conn.read_consistency_interval,
|
||||
)
|
||||
self._ref.dataset = ds
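Taken together, checkout() and checkout_latest() support a workflow like the following sketch (version numbers are illustrative):

tbl.checkout(2)         # pin to version 2; the reference becomes read-only
old = tbl.to_pandas()   # queries run against the pinned version
# tbl.add(...)          # would raise: the reference is fixed at version 2
tbl.checkout_latest()   # drop the pin and track the latest version again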
|
||||
|
||||
def restore(self, version: int = None):
|
||||
"""Restore a version of the table. This is an in-place operation.
|
||||
|
||||
@@ -759,7 +917,7 @@ class LanceTable(Table):
|
||||
>>> len(table.list_versions())
|
||||
4
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
max_ver = self._dataset.latest_version
|
||||
if version is None:
|
||||
version = self.version
|
||||
elif version < 1 or version > max_ver:
|
||||
@@ -767,29 +925,30 @@ class LanceTable(Table):
|
||||
else:
|
||||
self.checkout(version)
|
||||
|
||||
if version == max_ver:
|
||||
# no-op if restoring the latest version
|
||||
return
|
||||
ds = self._dataset
|
||||
|
||||
self._dataset.restore()
|
||||
self._reset_dataset()
|
||||
# no-op if restoring the latest version
|
||||
if version != max_ver:
|
||||
ds.restore()
|
||||
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=self._conn.read_consistency_interval,
|
||||
)
|
||||
self._ref.dataset = ds
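For orientation, a sketch of how restore() is meant to pair with checkout() (version numbers are illustrative):

tbl.checkout(2)   # inspect an old snapshot
tbl.restore()     # re-commit it as the new latest version; the table is writable again
# or, equivalently, in a single call:
tbl.restore(2)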
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
Count the number of rows in the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter: str, optional
|
||||
A SQL where clause to filter the rows to count.
|
||||
"""
|
||||
return self._dataset.count_rows(filter)
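For example (the filter expression is hypothetical and depends on the table's schema):

total = tbl.count_rows()
sports = tbl.count_rows("category = 'sports'")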
|
||||
|
||||
def __len__(self):
|
||||
return self.count_rows()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LanceTable({self.name})"
|
||||
val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
|
||||
if isinstance(self._ref, _LanceTimeTravelRef):
|
||||
val += f", version={self._ref.version}"
|
||||
val += ")"
|
||||
return val
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
@@ -839,10 +998,6 @@ class LanceTable(Table):
|
||||
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
|
||||
)
|
||||
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
@@ -854,7 +1009,7 @@ class LanceTable(Table):
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on the table."""
|
||||
self._dataset.create_index(
|
||||
self._dataset_mut.create_index(
|
||||
column=vector_column_name,
|
||||
index_type="IVF_PQ",
|
||||
metric=metric,
|
||||
@@ -864,11 +1019,12 @@ class LanceTable(Table):
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
self._reset_dataset()
|
||||
register_event("create_index")
|
||||
|
||||
def create_scalar_index(self, column: str, *, replace: bool = True):
|
||||
self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
|
||||
self._dataset_mut.create_scalar_index(
|
||||
column, index_type="BTREE", replace=replace
|
||||
)
|
||||
|
||||
def create_fts_index(
|
||||
self,
|
||||
@@ -911,14 +1067,6 @@ class LanceTable(Table):
|
||||
def _get_fts_index_path(self):
|
||||
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
||||
|
||||
@cached_property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return lance.dataset(self._dataset_uri, version=self._version)
|
||||
|
||||
def to_lance(self) -> LanceDataset:
|
||||
"""Return the LanceDataset backing this table."""
|
||||
return self._dataset
|
||||
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
@@ -957,8 +1105,11 @@ class LanceTable(Table):
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
||||
self._reset_dataset()
|
||||
# Access the dataset_mut property to ensure that the dataset is mutable.
|
||||
self._ref.dataset_mut
|
||||
self._ref.dataset = lance.write_dataset(
|
||||
data, self._dataset_uri, schema=self.schema, mode=mode
|
||||
)
|
||||
register_event("add")
|
||||
|
||||
def merge(
|
||||
@@ -1019,10 +1170,9 @@ class LanceTable(Table):
|
||||
other_table = other_table.to_lance()
|
||||
if isinstance(other_table, LanceDataset):
|
||||
other_table = other_table.to_table()
|
||||
self._dataset.merge(
|
||||
self._ref.dataset = self._dataset_mut.merge(
|
||||
other_table, left_on=left_on, right_on=right_on, schema=schema
|
||||
)
|
||||
self._reset_dataset()
|
||||
register_event("merge")
|
||||
|
||||
@cached_property
|
||||
@@ -1225,22 +1375,8 @@ class LanceTable(Table):
|
||||
register_event("create_table")
|
||||
return new_table
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name):
|
||||
tbl = cls(db, name)
|
||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
||||
file_info = fs.get_file_info(path)
|
||||
if file_info.type != pa.fs.FileType.Directory:
|
||||
raise FileNotFoundError(
|
||||
f"Table {name} does not exist."
|
||||
f"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
register_event("open_table")
|
||||
|
||||
return tbl
|
||||
|
||||
def delete(self, where: str):
|
||||
self._dataset.delete(where)
|
||||
self._dataset_mut.delete(where)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -1294,8 +1430,7 @@ class LanceTable(Table):
|
||||
if values is not None:
|
||||
values_sql = {k: value_to_sql(v) for k, v in values.items()}
|
||||
|
||||
self.to_lance().update(values_sql, where)
|
||||
self._reset_dataset()
|
||||
self._dataset_mut.update(values_sql, where)
|
||||
register_event("update")
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
@@ -1332,7 +1467,7 @@ class LanceTable(Table):
|
||||
ds = self.to_lance()
|
||||
builder = ds.merge_insert(merge._on)
|
||||
if merge._when_matched_update_all:
|
||||
builder.when_matched_update_all()
|
||||
builder.when_matched_update_all(merge._when_matched_update_all_condition)
|
||||
if merge._when_not_matched_insert_all:
|
||||
builder.when_not_matched_insert_all()
|
||||
if merge._when_not_matched_by_source_delete:
|
||||
|
||||
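To illustrate the new condition wired up in the builder above, a hedged sketch of the Python call path (column names and data are hypothetical):

(
    tbl.merge_insert("id")
    .when_matched_update_all(where="target.last_update < source.last_update")
    .when_not_matched_insert_all()
    .execute(new_rows)   # new_rows: any table-like data keyed by "id"
)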
@@ -134,6 +134,24 @@ def safe_import(module: str, mitigation=None):
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
|
||||
def safe_import_pandas():
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
return pd
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_polars():
|
||||
try:
|
||||
import polars as pl
|
||||
|
||||
return pl
|
||||
except ImportError:
|
||||
return None
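These helpers let callers degrade gracefully when an optional dependency is missing, for example (a sketch):

pd = safe_import_pandas()
if pd is not None:
    df = pd.DataFrame({"text": ["hello", "world"]})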
|
||||
|
||||
|
||||
@singledispatch
|
||||
def value_to_sql(value):
|
||||
raise NotImplementedError("SQL conversion is not implemented for this type")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.5.2"
|
||||
version = "0.5.4"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.9.12",
|
||||
"pylance==0.9.15",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.27.0",
|
||||
@@ -48,7 +48,7 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
|
||||
@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_embedding_function_rate_limit(tmp_path):
|
||||
def _get_schema_from_model(model):
|
||||
class Schema(LanceModel):
|
||||
|
||||
@@ -23,11 +23,6 @@ import lancedb
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
try:
|
||||
if importlib.util.find_spec("mlx.core") is not None:
|
||||
_mlx = True
|
||||
except ImportError:
|
||||
_mlx = None
|
||||
# These are integration tests for embedding functions.
|
||||
# They are slow because they require downloading models
|
||||
# or connection to external api
|
||||
@@ -210,6 +205,13 @@ def test_gemini_embedding(tmp_path):
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
|
||||
try:
|
||||
if importlib.util.find_spec("mlx.core") is not None:
|
||||
_mlx = True
|
||||
except ImportError:
|
||||
_mlx = None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_mlx is None,
|
||||
reason="mlx tests only required for apple users.",
|
||||
@@ -266,3 +268,49 @@ def test_bedrock_embedding(tmp_path):
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
)
|
||||
def test_openai_embedding(tmp_path):
|
||||
def _get_table(model):
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
return tbl
|
||||
|
||||
model = get_registry().get("openai").create(max_retries=0)
|
||||
tbl = _get_table(model)
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
model = (
|
||||
get_registry()
|
||||
.get("openai")
|
||||
.create(max_retries=0, name="text-embedding-3-large")
|
||||
)
|
||||
tbl = _get_table(model)
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
model = (
|
||||
get_registry()
|
||||
.get("openai")
|
||||
.create(max_retries=0, name="text-embedding-3-large", dim=1024)
|
||||
)
|
||||
tbl = _get_table(model)
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
@@ -7,7 +7,12 @@ import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
|
||||
from lancedb.rerankers import (
|
||||
CohereReranker,
|
||||
ColbertReranker,
|
||||
CrossEncoderReranker,
|
||||
OpenaiReranker,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
|
||||
return table, MyTable
|
||||
|
||||
|
||||
## These tests are pretty loose, we should also check for correctness
|
||||
def test_linear_combination(tmp_path):
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
@@ -95,14 +99,19 @@ def test_linear_combination(tmp_path):
|
||||
|
||||
assert result1 == result3 # 2 & 3 should be the same as they use score as score
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank", reranker=CohereReranker())
|
||||
.rerank(reranker=CohereReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=CohereReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank", reranker=CrossEncoderReranker())
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query), query_type="hybrid")
|
||||
.limit(30)
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
|
||||
def test_colbert_reranker(tmp_path):
|
||||
pytest.importorskip("transformers")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=ColbertReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=ColbertReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=ColbertReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
)
|
||||
def test_openai_reranker(tmp_path):
|
||||
pytest.importorskip("openai")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=OpenaiReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=OpenaiReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=OpenaiReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from copy import copy
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
@@ -25,6 +27,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
|
||||
class MockDB:
|
||||
def __init__(self, uri: Path):
|
||||
self.uri = uri
|
||||
self.read_consistency_interval = None
|
||||
|
||||
@functools.cached_property
|
||||
def is_managed_remote(self) -> bool:
|
||||
@@ -267,39 +271,38 @@ def test_versioning(db):
|
||||
|
||||
|
||||
def test_create_index_method():
|
||||
with patch.object(LanceTable, "_reset_dataset", return_value=None):
|
||||
with patch.object(
|
||||
LanceTable, "_dataset", new_callable=PropertyMock
|
||||
) as mock_dataset:
|
||||
# Setup mock responses
|
||||
mock_dataset.return_value.create_index.return_value = None
|
||||
with patch.object(
|
||||
LanceTable, "_dataset_mut", new_callable=PropertyMock
|
||||
) as mock_dataset:
|
||||
# Setup mock responses
|
||||
mock_dataset.return_value.create_index.return_value = None
|
||||
|
||||
# Create a LanceTable object
|
||||
connection = LanceDBConnection(uri="mock.uri")
|
||||
table = LanceTable(connection, "test_table")
|
||||
# Create a LanceTable object
|
||||
connection = LanceDBConnection(uri="mock.uri")
|
||||
table = LanceTable(connection, "test_table")
|
||||
|
||||
# Call the create_index method
|
||||
table.create_index(
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name="vector",
|
||||
replace=True,
|
||||
index_cache_size=256,
|
||||
)
|
||||
# Call the create_index method
|
||||
table.create_index(
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name="vector",
|
||||
replace=True,
|
||||
index_cache_size=256,
|
||||
)
|
||||
|
||||
# Check that the _dataset.create_index method was called
|
||||
# with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
replace=True,
|
||||
accelerator=None,
|
||||
index_cache_size=256,
|
||||
)
|
||||
# Check that the _dataset.create_index method was called
|
||||
# with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
replace=True,
|
||||
accelerator=None,
|
||||
index_cache_size=256,
|
||||
)
|
||||
|
||||
|
||||
def test_add_with_nans(db):
|
||||
@@ -510,8 +513,15 @@ def test_merge_insert(db):
|
||||
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
|
||||
# These `sort_by` calls can be removed once lance#1892
|
||||
# is merged (it fixes the ordering)
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
# conditional update
|
||||
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
|
||||
new_data
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
@@ -792,3 +802,48 @@ def test_hybrid_search(db):
|
||||
"Our father who art in heaven", query_type="hybrid"
|
||||
).to_pydantic(MyTable)
|
||||
assert result1 == result3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
|
||||
)
|
||||
def test_consistency(tmp_path, consistency_interval):
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
|
||||
|
||||
db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
|
||||
table2 = db2.open_table("my_table")
|
||||
assert table2.version == table.version
|
||||
|
||||
table.add([{"id": 1}])
|
||||
|
||||
if consistency_interval is None:
|
||||
assert table2.version == table.version - 1
|
||||
table2.checkout_latest()
|
||||
assert table2.version == table.version
|
||||
elif consistency_interval == timedelta(seconds=0):
|
||||
assert table2.version == table.version
|
||||
else:
|
||||
# consistency_interval == timedelta(seconds=0.1)
|
||||
assert table2.version == table.version - 1
|
||||
sleep(0.1)
|
||||
assert table2.version == table.version
|
||||
|
||||
|
||||
def test_restore_consistency(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
|
||||
|
||||
db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table2 = db2.open_table("my_table")
|
||||
assert table2.version == table.version
|
||||
|
||||
# If we call checkout, it should lose consistency
|
||||
table_fixed = copy(table2)
|
||||
table_fixed.checkout(table.version)
|
||||
# But if we call checkout_latest, it should be consistent again
|
||||
table_ref_latest = copy(table_fixed)
|
||||
table_ref_latest.checkout_latest()
|
||||
table.add([{"id": 2}])
|
||||
assert table_fixed.version == table.version - 1
|
||||
assert table_ref_latest.version == table.version
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.4.7"
|
||||
version = "0.4.8"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
license.workspace = true
|
||||
edition.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
exclude = ["index.node"]
|
||||
|
||||
[lib]
|
||||
|
||||
@@ -22,7 +22,7 @@ use arrow_schema::SchemaRef;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
|
||||
pub fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
|
||||
let mut batches: Vec<RecordBatch> = Vec::new();
|
||||
let file_reader = FileReader::try_new(Cursor::new(slice), None)?;
|
||||
let schema = file_reader.schema();
|
||||
@@ -33,7 +33,7 @@ pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBa
|
||||
Ok((batches, schema))
|
||||
}
|
||||
|
||||
pub(crate) fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
|
||||
pub fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
|
||||
if batches.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ use neon::types::buffer::TypedArray;
|
||||
|
||||
use crate::error::ResultExt;
|
||||
|
||||
pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
|
||||
pub fn vec_str_to_array<'a, C: Context<'a>>(
|
||||
vec: &Vec<String>,
|
||||
cx: &mut C,
|
||||
) -> JsResult<'a, JsArray> {
|
||||
@@ -29,7 +29,7 @@ pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
|
||||
pub fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
|
||||
let mut query_vec: Vec<f32> = Vec::new();
|
||||
for i in 0..array.len(cx) {
|
||||
let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
|
||||
@@ -39,7 +39,7 @@ pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<
|
||||
}
|
||||
|
||||
// Creates a new JsBuffer from a Rust buffer, with special logic for Electron
|
||||
pub(crate) fn new_js_buffer<'a>(
|
||||
pub fn new_js_buffer<'a>(
|
||||
buffer: Vec<u8>,
|
||||
cx: &mut TaskContext<'a>,
|
||||
is_electron: bool,
|
||||
|
||||
@@ -18,7 +18,6 @@ use neon::prelude::NeonResult;
|
||||
use snafu::Snafu;
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub(crate)))]
|
||||
pub enum Error {
|
||||
#[snafu(display("column '{name}' is missing"))]
|
||||
MissingColumn { name: String },
|
||||
|
||||
@@ -21,7 +21,7 @@ use neon::{
|
||||
use crate::{error::ResultExt, runtime, table::JsTable};
|
||||
use vectordb::Table;
|
||||
|
||||
pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let column = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let replace = cx.argument::<JsBoolean>(1)?.value(&mut cx);
|
||||
|
||||
@@ -24,7 +24,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
|
||||
use crate::runtime;
|
||||
use crate::table::JsTable;
|
||||
|
||||
pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let index_params = cx.argument::<JsObject>(0)?;
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
|
||||
use crate::table::JsTable;
|
||||
use crate::{convert, runtime};
|
||||
|
||||
pub(crate) struct JsQuery {}
|
||||
pub struct JsQuery {}
|
||||
|
||||
impl JsQuery {
|
||||
pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
|
||||
@@ -28,7 +28,7 @@ use vectordb::TableRef;
|
||||
use crate::error::ResultExt;
|
||||
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
|
||||
|
||||
pub(crate) struct JsTable {
|
||||
pub struct JsTable {
|
||||
pub table: TableRef,
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ impl Finalize for JsTable {}
|
||||
|
||||
impl From<TableRef> for JsTable {
|
||||
fn from(table: TableRef) -> Self {
|
||||
JsTable { table }
|
||||
Self { table }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,14 +85,14 @@ impl JsTable {
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let table = table_rst.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
});
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let buffer = cx.argument::<JsBuffer>(0)?;
|
||||
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
|
||||
let (batches, schema) =
|
||||
@@ -125,21 +125,34 @@ impl JsTable {
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
add_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
});
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let filter = cx
|
||||
.argument_opt(0)
|
||||
.and_then(|filt| {
|
||||
if filt.is_a::<JsUndefined, _>(&mut cx) || filt.is_a::<JsNull, _>(&mut cx) {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
filt.downcast_or_throw::<JsString, _>(&mut cx)
|
||||
.map(|js_filt| js_filt.deref().value(&mut cx)),
|
||||
)
|
||||
}
|
||||
})
|
||||
.transpose()?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
let table = js_table.table.clone();
|
||||
|
||||
rt.spawn(async move {
|
||||
let num_rows_result = table.count_rows().await;
|
||||
let num_rows_result = table.count_rows(filter).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let num_rows = num_rows_result.or_throw(&mut cx)?;
|
||||
@@ -150,7 +163,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
@@ -162,14 +175,14 @@ impl JsTable {
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
delete_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
})
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
@@ -178,28 +191,34 @@ impl JsTable {
|
||||
let key = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let mut builder = table.merge_insert(&[&key]);
|
||||
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
|
||||
builder.when_matched_update_all();
|
||||
}
|
||||
if cx.argument::<JsBoolean>(2)?.value(&mut cx) {
|
||||
builder.when_not_matched_insert_all();
|
||||
let filter = cx.argument_opt(2).unwrap();
|
||||
if filter.is_a::<JsNull, _>(&mut cx) {
|
||||
builder.when_matched_update_all(None);
|
||||
} else {
|
||||
let filter = filter
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?
|
||||
.deref()
|
||||
.value(&mut cx);
|
||||
builder.when_matched_update_all(Some(filter));
|
||||
}
|
||||
}
|
||||
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
|
||||
if let Some(filter) = cx.argument_opt(4) {
|
||||
if filter.is_a::<JsNull, _>(&mut cx) {
|
||||
builder.when_not_matched_by_source_delete(None);
|
||||
} else {
|
||||
let filter = filter
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?
|
||||
.deref()
|
||||
.value(&mut cx);
|
||||
builder.when_not_matched_by_source_delete(Some(filter));
|
||||
}
|
||||
} else {
|
||||
builder.when_not_matched_insert_all();
|
||||
}
|
||||
if cx.argument::<JsBoolean>(4)?.value(&mut cx) {
|
||||
let filter = cx.argument_opt(5).unwrap();
|
||||
if filter.is_a::<JsNull, _>(&mut cx) {
|
||||
builder.when_not_matched_by_source_delete(None);
|
||||
} else {
|
||||
let filter = filter
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?
|
||||
.deref()
|
||||
.value(&mut cx);
|
||||
builder.when_not_matched_by_source_delete(Some(filter));
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = cx.argument::<JsBuffer>(5)?;
|
||||
let buffer = cx.argument::<JsBuffer>(6)?;
|
||||
let (batches, schema) =
|
||||
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
|
||||
|
||||
@@ -209,14 +228,14 @@ impl JsTable {
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
merge_insert_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
})
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let table = js_table.table.clone();
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
@@ -275,7 +294,7 @@ impl JsTable {
|
||||
.await;
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
update_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
})
|
||||
});
|
||||
|
||||
@@ -283,7 +302,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let table = js_table.table.clone();
|
||||
@@ -321,7 +340,7 @@ impl JsTable {
|
||||
let old_versions = cx.number(prune_stats.old_versions as f64);
|
||||
output_metrics.set(&mut cx, "oldVersions", old_versions)?;
|
||||
|
||||
let output_table = cx.boxed(JsTable::from(table));
|
||||
let output_table = cx.boxed(Self::from(table));
|
||||
|
||||
let output = JsObject::new(&mut cx);
|
||||
output.set(&mut cx, "metrics", output_metrics)?;
|
||||
@@ -334,7 +353,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let table = js_table.table.clone();
|
||||
@@ -393,7 +412,7 @@ impl JsTable {
|
||||
let files_added = cx.number(stats.files_added as f64);
|
||||
output_metrics.set(&mut cx, "filesAdded", files_added)?;
|
||||
|
||||
let output_table = cx.boxed(JsTable::from(table));
|
||||
let output_table = cx.boxed(Self::from(table));
|
||||
|
||||
let output = JsObject::new(&mut cx);
|
||||
output.set(&mut cx, "metrics", output_metrics)?;
|
||||
@@ -406,7 +425,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
@@ -445,7 +464,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
@@ -493,7 +512,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_schema(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
[package]
|
||||
name = "vectordb"
|
||||
version = "0.4.7"
|
||||
edition = "2021"
|
||||
version = "0.4.8"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
keywords = ["lancedb", "lance", "database", "search"]
|
||||
categories = ["database-implementations"]
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[dependencies]
|
||||
|
||||
@@ -188,12 +188,12 @@ impl Database {
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Database] object.
|
||||
pub async fn connect(uri: &str) -> Result<Database> {
|
||||
pub async fn connect(uri: &str) -> Result<Self> {
|
||||
let options = ConnectOptions::new(uri);
|
||||
Self::connect_with_options(&options).await
|
||||
}
|
||||
|
||||
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
|
||||
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> {
|
||||
let uri = &options.uri;
|
||||
let parse_res = url::Url::parse(uri);
|
||||
|
||||
@@ -276,7 +276,7 @@ impl Database {
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Database {
|
||||
Ok(Self {
|
||||
uri: table_base_uri,
|
||||
query_string,
|
||||
base_path,
|
||||
@@ -288,7 +288,7 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
async fn open_path(path: &str) -> Result<Database> {
|
||||
async fn open_path(path: &str) -> Result<Self> {
|
||||
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(path).context(CreateDirSnafu { path })?;
|
||||
|
||||
@@ -69,7 +69,7 @@ pub struct IndexBuilder {
|
||||
|
||||
impl IndexBuilder {
|
||||
pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
|
||||
IndexBuilder {
|
||||
Self {
|
||||
table,
|
||||
columns: columns.iter().map(|c| c.to_string()).collect(),
|
||||
name: None,
|
||||
@@ -197,7 +197,7 @@ impl IndexBuilder {
|
||||
let num_partitions = if let Some(n) = self.num_partitions {
|
||||
n
|
||||
} else {
|
||||
suggested_num_partitions(self.table.count_rows().await?)
|
||||
suggested_num_partitions(self.table.count_rows(None).await?)
|
||||
};
|
||||
let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
|
||||
n
|
||||
|
||||
@@ -23,13 +23,13 @@ pub struct VectorIndex {
|
||||
}
|
||||
|
||||
impl VectorIndex {
|
||||
pub fn new_from_format(manifest: &Manifest, index: &Index) -> VectorIndex {
|
||||
pub fn new_from_format(manifest: &Manifest, index: &Index) -> Self {
|
||||
let fields = index
|
||||
.fields
|
||||
.iter()
|
||||
.map(|i| manifest.schema.fields[*i as usize].name.clone())
|
||||
.collect();
|
||||
VectorIndex {
|
||||
Self {
|
||||
columns: fields,
|
||||
index_name: index.name.clone(),
|
||||
index_uuid: index.uuid.to_string(),
|
||||
|
||||
@@ -372,7 +372,7 @@ mod test {
|
||||
// leave this here for easy debugging
|
||||
let t = res.unwrap();
|
||||
|
||||
assert_eq!(t.count_rows().await.unwrap(), 100);
|
||||
assert_eq!(t.count_rows(None).await.unwrap(), 100);
|
||||
|
||||
let q = t
|
||||
.search(&[0.1, 0.1, 0.1, 0.1])
|
||||
|
||||
@@ -62,7 +62,7 @@ impl Query {
|
||||
/// * `dataset` - Lance dataset.
|
||||
///
|
||||
pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
|
||||
Query {
|
||||
Self {
|
||||
dataset,
|
||||
query_vector: None,
|
||||
column: None,
|
||||
|
||||
@@ -27,7 +27,7 @@ use lance::dataset::optimize::{
|
||||
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
|
||||
};
|
||||
pub use lance::dataset::ReadParams;
|
||||
use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
|
||||
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams};
|
||||
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
|
||||
use lance::io::WrappingObjectStore;
|
||||
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
|
||||
@@ -102,7 +102,11 @@ pub trait Table: std::fmt::Display + Send + Sync {
|
||||
fn schema(&self) -> SchemaRef;
|
||||
|
||||
/// Count the number of rows in this dataset.
|
||||
async fn count_rows(&self) -> Result<usize>;
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `filter` if present, only count rows matching the filter
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
|
||||
|
||||
/// Insert new records into this Table
|
||||
///
|
||||
@@ -234,7 +238,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
|
||||
/// schema.clone());
|
||||
/// // Perform an upsert operation
|
||||
/// let mut merge_insert = tbl.merge_insert(&["id"]);
|
||||
/// merge_insert.when_matched_update_all()
|
||||
/// merge_insert.when_matched_update_all(None)
|
||||
/// .when_not_matched_insert_all();
|
||||
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
|
||||
/// # });
|
||||
@@ -385,7 +389,7 @@ impl NativeTable {
|
||||
message: e.to_string(),
|
||||
},
|
||||
})?;
|
||||
Ok(NativeTable {
|
||||
Ok(Self {
|
||||
name: name.to_string(),
|
||||
uri: uri.to_string(),
|
||||
dataset: Arc::new(Mutex::new(dataset)),
|
||||
@@ -427,7 +431,7 @@ impl NativeTable {
|
||||
message: e.to_string(),
|
||||
},
|
||||
})?;
|
||||
Ok(NativeTable {
|
||||
Ok(Self {
|
||||
name: name.to_string(),
|
||||
uri: uri.to_string(),
|
||||
dataset: Arc::new(Mutex::new(dataset)),
|
||||
@@ -501,7 +505,7 @@ impl NativeTable {
|
||||
message: e.to_string(),
|
||||
},
|
||||
})?;
|
||||
Ok(NativeTable {
|
||||
Ok(Self {
|
||||
name: name.to_string(),
|
||||
uri: uri.to_string(),
|
||||
dataset: Arc::new(Mutex::new(dataset)),
|
||||
@@ -673,11 +677,14 @@ impl MergeInsert for NativeTable {
|
||||
) -> Result<()> {
|
||||
let dataset = Arc::new(self.clone_inner_dataset());
|
||||
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
|
||||
if params.when_matched_update_all {
|
||||
builder.when_matched(lance::dataset::WhenMatched::UpdateAll);
|
||||
} else {
|
||||
builder.when_matched(lance::dataset::WhenMatched::DoNothing);
|
||||
}
|
||||
match (
|
||||
params.when_matched_update_all,
|
||||
params.when_matched_update_all_filt,
|
||||
) {
|
||||
(false, _) => builder.when_matched(WhenMatched::DoNothing),
|
||||
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
|
||||
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
|
||||
};
|
||||
if params.when_not_matched_insert_all {
|
||||
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
|
||||
} else {
|
||||
@@ -719,9 +726,15 @@ impl Table for NativeTable {
|
||||
Arc::new(Schema::from(&lance_schema))
|
||||
}
|
||||
|
||||
async fn count_rows(&self) -> Result<usize> {
|
||||
async fn count_rows(&self, filter: Option<String>) -> Result<usize> {
|
||||
let dataset = { self.dataset.lock().expect("lock poison").clone() };
|
||||
Ok(dataset.count_rows().await?)
|
||||
if let Some(filter) = filter {
|
||||
let mut scanner = dataset.scan();
|
||||
scanner.filter(&filter)?;
|
||||
Ok(scanner.count_rows().await? as usize)
|
||||
} else {
|
||||
Ok(dataset.count_rows().await?)
|
||||
}
|
||||
}
|
||||
|
||||
async fn add(
|
||||
@@ -814,6 +827,7 @@ impl Table for NativeTable {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::iter;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -886,6 +900,23 @@ mod tests {
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_count_rows() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
let batches = make_test_batches();
|
||||
let table = NativeTable::create(&uri, "test", batches, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
assert_eq!(
|
||||
table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
|
||||
5
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add() {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
@@ -896,7 +927,7 @@ mod tests {
|
||||
let table = NativeTable::create(&uri, "test", batches, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
|
||||
let new_batches = RecordBatchIterator::new(
|
||||
vec![RecordBatch::try_new(
|
||||
@@ -910,7 +941,7 @@ mod tests {
|
||||
);
|
||||
|
||||
table.add(Box::new(new_batches), None).await.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 20);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 20);
|
||||
assert_eq!(table.name, "test");
|
||||
}
|
||||
|
||||
@@ -920,30 +951,44 @@ mod tests {
|
||||
let uri = tmp_dir.path().to_str().unwrap();
|
||||
|
||||
// Create a dataset with i=0..10
|
||||
let batches = make_test_batches_with_offset(0);
|
||||
let batches = merge_insert_test_batches(0, 0);
|
||||
let table = NativeTable::create(&uri, "test", batches, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(table.count_rows().await.unwrap(), 10);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 10);
|
||||
|
||||
// Create new data with i=5..15
|
||||
let new_batches = Box::new(make_test_batches_with_offset(5));
|
||||
let new_batches = Box::new(merge_insert_test_batches(5, 1));
|
||||
|
||||
// Perform a "insert if not exists"
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
merge_insert_builder.when_not_matched_insert_all();
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
// Only 5 rows should actually be inserted
|
||||
assert_eq!(table.count_rows().await.unwrap(), 15);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 15);
|
||||
|
||||
// Create new data with i=15..25 (no id matches)
|
||||
let new_batches = Box::new(make_test_batches_with_offset(15));
|
||||
let new_batches = Box::new(merge_insert_test_batches(15, 2));
|
||||
// Perform a "bulk update" (should not affect anything)
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
merge_insert_builder.when_matched_update_all();
|
||||
merge_insert_builder.when_matched_update_all(None);
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
// No new rows should have been inserted
|
||||
assert_eq!(table.count_rows().await.unwrap(), 15);
|
||||
assert_eq!(table.count_rows(None).await.unwrap(), 15);
|
||||
assert_eq!(
|
||||
table.count_rows(Some("age = 2".to_string())).await.unwrap(),
|
||||
0
|
||||
);
|
||||
|
||||
// Conditional update that only replaces the age=0 data
|
||||
let new_batches = Box::new(merge_insert_test_batches(5, 3));
|
||||
let mut merge_insert_builder = table.merge_insert(&["i"]);
|
||||
merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string()));
|
||||
merge_insert_builder.execute(new_batches).await.unwrap();
|
||||
assert_eq!(
|
||||
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
|
||||
5
|
||||
);
|
||||
}

#[tokio::test]
@@ -956,7 +1001,7 @@ mod tests {
let table = NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);

let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
@@ -975,7 +1020,7 @@ mod tests {
};

table.add(Box::new(new_batches), Some(param)).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(table.name, "test");
}

@@ -1292,23 +1337,35 @@ mod tests {
assert!(wrapper.called());
}

fn make_test_batches_with_offset(
fn merge_insert_test_batches(
offset: i32,
age: i32,
) -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
let schema = Arc::new(Schema::new(vec![
Field::new("i", DataType::Int32, false),
Field::new("age", DataType::Int32, false),
]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(
offset..(offset + 10),
))],
vec![
Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
],
)],
schema,
)
}

fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
make_test_batches_with_offset(0)
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)],
schema,
)
}
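
merge_insert_test_batches now emits a single batch with an increasing i column plus a constant age column, while make_test_batches keeps the original one-column shape. A small sketch of what the helper yields, assuming only the arrow types already imported in this module (a RecordBatchReader is an iterator of Result<RecordBatch, _>):

// Drain the reader and inspect the generated batch.
let batches: Vec<_> = merge_insert_test_batches(0, 7)
    .collect::<Result<Vec<_>, _>>()
    .unwrap();
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].num_rows(), 10);   // i = 0..10
assert_eq!(batches[0].num_columns(), 2); // "i" and "age", with age = 7 in every row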

#[tokio::test]
@@ -1365,7 +1422,7 @@ mod tests {
.unwrap();

assert_eq!(table.load_indices().await.unwrap().len(), 1);
assert_eq!(table.count_rows().await.unwrap(), 512);
assert_eq!(table.count_rows(None).await.unwrap(), 512);
assert_eq!(table.name, "test");

let indices = table.load_indices().await.unwrap();

@@ -35,6 +35,7 @@ pub struct MergeInsertBuilder {
table: Arc<dyn MergeInsert>,
pub(super) on: Vec<String>,
pub(super) when_matched_update_all: bool,
pub(super) when_matched_update_all_filt: Option<String>,
pub(super) when_not_matched_insert_all: bool,
pub(super) when_not_matched_by_source_delete: bool,
pub(super) when_not_matched_by_source_delete_filt: Option<String>,
@@ -59,8 +61,22 @@ impl MergeInsertBuilder {
/// If there are multiple matches then the behavior is undefined.
/// Currently this causes multiple copies of the row to be created
/// but that behavior is subject to change.
pub fn when_matched_update_all(&mut self) -> &mut Self {
///
/// An optional condition may be specified. If it is, then only
/// matched rows that satisfy the condition will be updated. Any
/// rows that do not satisfy the condition will be left as they
/// are. Failing to satisfy the condition does not cause a
/// "matched row" to become a "not matched" row.
///
/// The condition should be an SQL string. Use the prefix
/// target. to refer to rows in the target table (old data)
/// and the prefix source. to refer to rows in the source
/// table (new data).
///
/// For example, "target.last_update < source.last_update"
pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
self.when_matched_update_all = true;
self.when_matched_update_all_filt = condition;
self
}
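
Given the doc comment above, a conditional bulk update looks roughly like the sketch below; the key column and the boxed reader new_data are placeholders, and the condition reuses the target./source. example from the comment:

// Only overwrite rows whose stored copy is older than the incoming one.
let mut builder = table.merge_insert(&["id"]);
builder.when_matched_update_all(Some(
    "target.last_update < source.last_update".to_string(),
));
builder.execute(new_data).await.unwrap();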