Compare commits

..

1 Commits

Author SHA1 Message Date
qzhu
c3be2e3962 small fix for the guides/table page 2024-02-01 14:41:00 -08:00
83 changed files with 679 additions and 3881 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.4.10 current_version = 0.4.7
commit = True commit = True
message = Bump version: {current_version} → {new_version} message = Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -1,35 +0,0 @@
[profile.release]
lto = "fat"
codegen-units = 1
[profile.release-with-debug]
inherits = "release"
debug = true
# Prioritize compile time over runtime performance
codegen-units = 16
lto = "thin"
[target.'cfg(all())']
rustflags = [
"-Wclippy::all",
"-Wclippy::style",
"-Wclippy::fallible_impl_from",
"-Wclippy::manual_let_else",
"-Wclippy::redundant_pub_crate",
"-Wclippy::string_add_assign",
"-Wclippy::string_add",
"-Wclippy::string_lit_as_bytes",
"-Wclippy::string_to_string",
"-Wclippy::use_self",
"-Dclippy::cargo",
"-Dclippy::dbg_macro",
# not too much we can do to avoid multiple crate versions
"-Aclippy::multiple-crate-versions",
"-Aclippy::wildcard_dependencies",
]
[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"]
[target.aarch64-apple-darwin]
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]

View File

@@ -49,9 +49,6 @@ jobs:
test-node: test-node:
name: Test doc nodejs code name: Test doc nodejs code
runs-on: "ubuntu-latest" runs-on: "ubuntu-latest"
timeout-minutes: 45
strategy:
fail-fast: false
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -69,12 +66,6 @@ jobs:
uses: swatinem/rust-cache@v2 uses: swatinem/rust-cache@v2
- name: Install node dependencies - name: Install node dependencies
run: | run: |
sudo swapoff -a
sudo fallocate -l 8G /swapfile
sudo chmod 600 /swapfile
sudo mkswap /swapfile
sudo swapon /swapfile
sudo swapon --show
cd node cd node
npm ci npm ci
npm run build-release npm run build-release

View File

@@ -80,7 +80,7 @@ jobs:
- arch: x86_64 - arch: x86_64
runner: ubuntu-latest runner: ubuntu-latest
- arch: aarch64 - arch: aarch64
runner: buildjet-8vcpu-ubuntu-2204-arm runner: buildjet-4vcpu-ubuntu-2204-arm
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -6,18 +6,15 @@ resolver = "2"
[workspace.package] [workspace.package]
edition = "2021" edition = "2021"
authors = ["LanceDB Devs <dev@lancedb.com>"] authors = ["Lance Devs <dev@lancedb.com>"]
license = "Apache-2.0" license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
description = "Serverless, low-latency vector database for AI applications"
keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies] [workspace.dependencies]
lance = { "version" = "=0.9.15", "features" = ["dynamodb"] } lance = { "version" = "=0.9.10", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.15" } lance-index = { "version" = "=0.9.10" }
lance-linalg = { "version" = "=0.9.15" } lance-linalg = { "version" = "=0.9.10" }
lance-testing = { "version" = "=0.9.15" } lance-testing = { "version" = "=0.9.10" }
# Note that this one does not include pyarrow # Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false } arrow = { version = "50.0", optional = false }
arrow-array = "50.0" arrow-array = "50.0"

View File

@@ -90,9 +90,7 @@ nav:
- Building an ANN index: ann_indexes.md - Building an ANN index: ann_indexes.md
- Vector Search: search.md - Vector Search: search.md
- Full-text search: fts.md - Full-text search: fts.md
- Hybrid search: - Hybrid search: hybrid_search.md
- Overview: hybrid_search/hybrid_search.md
- Airbnb financial data example: notebooks/hybrid_search.ipynb
- Filtering: sql.md - Filtering: sql.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb - Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Configuring Storage: guides/storage.md - Configuring Storage: guides/storage.md
@@ -154,9 +152,7 @@ nav:
- Building an ANN index: ann_indexes.md - Building an ANN index: ann_indexes.md
- Vector Search: search.md - Vector Search: search.md
- Full-text search: fts.md - Full-text search: fts.md
- Hybrid search: - Hybrid search: hybrid_search.md
- Overview: hybrid_search/hybrid_search.md
- Airbnb financial data example: notebooks/hybrid_search.ipynb
- Filtering: sql.md - Filtering: sql.md
- Versioning & Reproducibility: notebooks/reproducibility.ipynb - Versioning & Reproducibility: notebooks/reproducibility.ipynb
- Configuring Storage: guides/storage.md - Configuring Storage: guides/storage.md

View File

@@ -17,7 +17,6 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
```python ```python
from lancedb.embeddings import register from lancedb.embeddings import register
from lancedb.util import attempt_import_or_raise
@register("sentence-transformers") @register("sentence-transformers")
class SentenceTransformerEmbeddings(TextEmbeddingFunction): class SentenceTransformerEmbeddings(TextEmbeddingFunction):
@@ -82,7 +81,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
open_clip = attempt_import_or_raise("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found open_clip = self.safe_import("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
model, _, preprocess = open_clip.create_model_and_transforms( model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained self.name, pretrained=self.pretrained
) )
@@ -110,14 +109,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str): if isinstance(query, str):
return [self.generate_text_embeddings(query)] return [self.generate_text_embeddings(query)]
else: else:
PIL = attempt_import_or_raise("PIL", "pillow") PIL = self.safe_import("PIL", "pillow")
if isinstance(query, PIL.Image.Image): if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)] return [self.generate_image_embedding(query)]
else: else:
raise TypeError("OpenClip supports str or PIL Image as query") raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray: def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = attempt_import_or_raise("torch") torch = self.safe_import("torch")
text = self.sanitize_input(text) text = self.sanitize_input(text)
text = self._tokenizer(text) text = self._tokenizer(text)
text.to(self.device) text.to(self.device)
@@ -176,7 +175,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
The image to embed. If the image is a str, it is treated as a uri. The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes. If the image is bytes, it is treated as the raw image bytes.
""" """
torch = attempt_import_or_raise("torch") torch = self.safe_import("torch")
# TODO handle retry and errors for https # TODO handle retry and errors for https
image = self._to_pil(image) image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0) image = self._preprocess(image).unsqueeze(0)
@@ -184,7 +183,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image) return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]): def _to_pil(self, image: Union[str, bytes]):
PIL = attempt_import_or_raise("PIL", "pillow") PIL = self.safe_import("PIL", "pillow")
if isinstance(image, bytes): if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image)) return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):

View File

@@ -9,9 +9,6 @@ Contains the text embedding functions registered by default.
### Sentence transformers ### Sentence transformers
Allows you to set parameters when registering a `sentence-transformers` object. Allows you to set parameters when registering a `sentence-transformers` object.
!!! info
Sentence transformer embeddings are normalized by default. It is recommended to use normalized embeddings for similarity search.
| Parameter | Type | Default Value | Description | | Parameter | Type | Default Value | Description |
|---|---|---|---| |---|---|---|---|
| `name` | `str` | `all-MiniLM-L6-v2` | The name of the model | | `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |

View File

@@ -69,19 +69,3 @@ MinIO supports an S3 compatible API. In order to connect to a MinIO instance, yo
- Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API - Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
- Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential - Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential
- Call `lancedb.connect("s3://minio_bucket_name")` - Call `lancedb.connect("s3://minio_bucket_name")`
### Where can I find benchmarks for LanceDB?
Refer to this [post](https://blog.lancedb.com/benchmarking-lancedb-92b01032874a) for recent benchmarks.
### How much data can LanceDB practically manage without effecting performance?
We target good performance on ~10-50 billion rows and ~10-30 TB of data.
### Does LanceDB support concurrent operations?
LanceDB can handle concurrent reads very well, and can scale horizontally. The main constraint is how well the [storage layer](https://lancedb.github.io/lancedb/concepts/storage/) you've chosen scales. For writes, we support concurrent writing, though too many concurrent writers can lead to failing writes as there is a limited number of times a writer retries a commit
!!! info "Multiprocessing with LanceDB"
For multiprocessing you should probably not use ```fork``` as lance is multi-threaded internally and ```fork``` and multi-thread do not work well.[Refer to this discussion](https://discuss.python.org/t/concerns-regarding-deprecation-of-fork-with-alive-threads/33555)

View File

@@ -100,9 +100,7 @@ This guide will show how to create tables, insert data into them, and update the
db["my_table"].head() db["my_table"].head()
``` ```
!!! info "Note" !!! info "Note"
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly. Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
The **`vector`** column needs to be a [Vector](../python/pydantic.md#vector-field) (defined as [pyarrow.FixedSizeList](https://arrow.apache.org/docs/python/generated/pyarrow.list_.html)) type.
```python ```python
custom_schema = pa.schema([ custom_schema = pa.schema([

View File

@@ -1,29 +1,22 @@
# Hybrid Search # Hybrid Search
LanceDB supports both semantic and keyword-based search (also termed full-text search, or FTS). In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques. LanceDB supports both semantic and keyword-based search. In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques.
## Hybrid search in LanceDB ## Hybrid search in LanceDB
You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic . You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic .
```python ```python
import os
import lancedb import lancedb
import openai
from lancedb.embeddings import get_registry from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydanatic import LanceModel, Vector
db = lancedb.connect("~/.lancedb") db = lancedb.connect("~/.lancedb")
# Ingest embedding function in LanceDB table # Ingest embedding function in LanceDB table
# Configuring the environment variable OPENAI_API_KEY
if "OPENAI_API_KEY" not in os.environ:
# OR set the key here as a variable
openai.api_key = "sk-..."
embeddings = get_registry().get("openai").create() embeddings = get_registry().get("openai").create()
class Documents(LanceModel): class Documents(LanceModel):
vector: Vector(embeddings.ndims()) = embeddings.VectorField() vector: Vector(embeddings.ndims) = embeddings.VectorField()
text: str = embeddings.SourceField() text: str = embeddings.SourceField()
table = db.create_table("documents", schema=Documents) table = db.create_table("documents", schema=Documents)
@@ -38,19 +31,17 @@ data = [
# ingest docs with auto-vectorization # ingest docs with auto-vectorization
table.add(data) table.add(data)
# Create a fts index before the hybrid search
table.create_fts_index("text")
# hybrid search with default re-ranker # hybrid search with default re-ranker
results = table.search("flower moon", query_type="hybrid").to_pandas() results = table.search("flower moon", query_type="hybrid").to_pandas()
``` ```
By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers: By default, LanceDB uses `LinearCombinationReranker(weights=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
### `rerank()` arguments ### `rerank()` arguments
* `normalize`: `str`, default `"score"`: * `normalize`: `str`, default `"score"`:
The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly. The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
* `reranker`: `Reranker`, default `LinearCombinationReranker(weight=0.7)`. * `reranker`: `Reranker`, default `LinearCombinationReranker(weights=0.7)`.
The reranker to use. If not specified, the default reranker is used. The reranker to use. If not specified, the default reranker is used.
@@ -64,12 +55,12 @@ This is the default re-ranker used by LanceDB. It combines the results of semant
```python ```python
from lancedb.rerankers import LinearCombinationReranker from lancedb.rerankers import LinearCombinationReranker
reranker = LinearCombinationReranker(weight=0.3) # Use 0.3 as the weight for vector search reranker = LinearCombinationReranker(weights=0.3) # Use 0.3 as the weight for vector search
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas() results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
``` ```
### Arguments Arguments
---------------- ----------------
* `weight`: `float`, default `0.7`: * `weight`: `float`, default `0.7`:
The weight to use for the semantic search score. The weight for the full-text search score is `1 - weights`. The weight to use for the semantic search score. The weight for the full-text search score is `1 - weights`.
@@ -91,9 +82,9 @@ reranker = CohereReranker()
results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas() results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
``` ```
### Arguments Arguments
---------------- ----------------
* `model_name` : str, default `"rerank-english-v2.0"` * `model_name`` : str, default `"rerank-english-v2.0"``
The name of the cross encoder model to use. Available cohere models are: The name of the cross encoder model to use. Available cohere models are:
- rerank-english-v2.0 - rerank-english-v2.0
- rerank-multilingual-v2.0 - rerank-multilingual-v2.0
@@ -117,7 +108,7 @@ results = table.search("harmony hall", query_type="hybrid").rerank(reranker=rera
``` ```
### Arguments Arguments
---------------- ----------------
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"` * `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html) The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
@@ -130,61 +121,6 @@ results = table.search("harmony hall", query_type="hybrid").rerank(reranker=rera
Only returns `_relevance_score`. Does not support `return_score = "all"`. Only returns `_relevance_score`. Does not support `return_score = "all"`.
### ColBERT Reranker
This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertrReranker()` to the `rerank()` method.
ColBERT reranker model calculates relevance of given docs against the query and don't take existing fts and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for `text` column to rerank the results. But you can specify the column name to use as input to the cross encoder model as described below.
```python
from lancedb.rerankers import ColbertReranker
reranker = ColbertReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
### Arguments
----------------
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
The name of the cross encoder model to use.
* `column` : `str`, default `"text"`
The name of the column to use as input to the cross encoder model.
* `return_score` : `str`, default `"relevance"`
options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
!!! Note
Only returns `_relevance_score`. Does not support `return_score = "all"`.
### OpenAI Reranker
This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
!!! Note
This prompts chat model to rerank results which is not a dedicated reranker model. This should be treated as experimental.
!!! Tip
- You might run out of token limit so set the search `limits` based on your token limit.
- It is recommended to use gpt-4-turbo-preview, the default model, older models might lead to undesired behaviour
```python
from lancedb.rerankers import OpenaiReranker
reranker = OpenaiReranker()
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```
### Arguments
----------------
* `model_name` : `str`, default `"gpt-4-turbo-preview"`
The name of the cross encoder model to use.
* `column` : `str`, default `"text"`
The name of the column to use as input to the cross encoder model.
* `return_score` : `str`, default `"relevance"`
options are "relevance" or "all". Only "relevance" is supported for now.
* `api_key` : `str`, default `None`
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
## Building Custom Rerankers ## Building Custom Rerankers
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores. You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
@@ -201,7 +137,7 @@ class MyReranker(Reranker):
self.param1 = param1 self.param1 = param1
self.param2 = param2 self.param2 = param2
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table): def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
# Use the built-in merging function # Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results) combined_result = self.merge_results(vector_results, fts_results)
@@ -213,30 +149,24 @@ class MyReranker(Reranker):
``` ```
### Example of a Custom Reranker You can also accept additional arguments like a filter along with fts and vector search results
For the sake of simplicity let's build custom reranker that just enchances the Cohere Reranker by accepting a filter query, and accept other CohereReranker params as kwags.
```python ```python
from typing import List, Union from lancedb.rerankers import Reranker
import pandas as pd import pyarrow as pa
from lancedb.rerankers import CohereReranker
class MofidifiedCohereReranker(CohereReranker): class MyReranker(Reranker):
def __init__(self, filters: Union[str, List[str]], **kwargs): ...
super().__init__(**kwargs)
filters = filters if isinstance(filters, list) else [filters] def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
self.filters = filters # Use the built-in merging function
combined_result = self.merge_results(vector_results, fts_results)
# Do something with the combined results & filter
# ...
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table)-> pa.Table: # Return the combined results
combined_result = super().rerank_hybrid(query, vector_results, fts_results) return combined_result
df = combined_result.to_pandas()
for filter in self.filters:
df = df.query("not text.str.contains(@filter)")
return pa.Table.from_pandas(df)
``` ```
!!! tip
The `vector_results` and `fts_results` are pyarrow tables. You can convert them to pandas dataframes using `to_pandas()` method and perform any operations you want. After you are done, you can convert the dataframe back to pyarrow table using `pa.Table.from_pandas()` method and return it.

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,7 @@ excluded_globs = [
"../src/concepts/*.md", "../src/concepts/*.md",
"../src/ann_indexes.md", "../src/ann_indexes.md",
"../src/basic.md", "../src/basic.md",
"../src/hybrid_search/hybrid_search.md", "../src/hybrid_search.md",
] ]
python_prefix = "py" python_prefix = "py"

74
node/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.4.10", "version": "0.4.7",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "vectordb", "name": "vectordb",
"version": "0.4.10", "version": "0.4.7",
"cpu": [ "cpu": [
"x64", "x64",
"arm64" "arm64"
@@ -53,11 +53,11 @@
"uuid": "^9.0.0" "uuid": "^9.0.0"
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.10", "@lancedb/vectordb-darwin-arm64": "0.4.7",
"@lancedb/vectordb-darwin-x64": "0.4.10", "@lancedb/vectordb-darwin-x64": "0.4.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.10", "@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
"@lancedb/vectordb-linux-x64-gnu": "0.4.10", "@lancedb/vectordb-linux-x64-gnu": "0.4.7",
"@lancedb/vectordb-win32-x64-msvc": "0.4.10" "@lancedb/vectordb-win32-x64-msvc": "0.4.7"
} }
}, },
"node_modules/@75lb/deep-merge": { "node_modules/@75lb/deep-merge": {
@@ -328,6 +328,66 @@
"@jridgewell/sourcemap-codec": "^1.4.10" "@jridgewell/sourcemap-codec": "^1.4.10"
} }
}, },
"node_modules/@lancedb/vectordb-darwin-arm64": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.7.tgz",
"integrity": "sha512-kACOIytgjBfX8NRwjPKe311XRN3lbSN13B7avT5htMd3kYm3AnnMag9tZhlwoO7lIuvGaXhy7mApygJrjhfJ4g==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-darwin-x64": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.7.tgz",
"integrity": "sha512-vb74iK5uPWCwz5E60r3yWp/R/HSg54/Z9AZWYckYXqsPv4w/nfbkM5iZhfRqqR/9uE6JClWJKOtjbk7b8CFRFg==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"darwin"
]
},
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.7.tgz",
"integrity": "sha512-jHp7THm6S9sB8RaCxGoZXLAwGAUHnawUUilB1K3mvQsRdfB2bBs0f7wDehW+PDhr+Iog4LshaWbcnoQEUJWR+Q==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.7.tgz",
"integrity": "sha512-LKbVe6Wrp/AGqCCjKliNDmYoeTNgY/wfb2DTLjrx41Jko/04ywLrJ6xSEAn3XD5RDCO5u3fyUdXHHHv5a3VAAQ==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
]
},
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
"version": "0.4.7",
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.7.tgz",
"integrity": "sha512-C5ln4+wafeY1Sm4PeV0Ios9lUaQVVip5Mjl9XU7ngioSEMEuXI/XMVfIdVfDPppVNXPeQxg33wLA272uw88D1Q==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"win32"
]
},
"node_modules/@neon-rs/cli": { "node_modules/@neon-rs/cli": {
"version": "0.0.160", "version": "0.0.160",
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz", "resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",

View File

@@ -1,6 +1,6 @@
{ {
"name": "vectordb", "name": "vectordb",
"version": "0.4.10", "version": "0.4.7",
"description": " Serverless, low-latency vector database for AI applications", "description": " Serverless, low-latency vector database for AI applications",
"main": "dist/index.js", "main": "dist/index.js",
"types": "dist/index.d.ts", "types": "dist/index.d.ts",
@@ -85,10 +85,10 @@
} }
}, },
"optionalDependencies": { "optionalDependencies": {
"@lancedb/vectordb-darwin-arm64": "0.4.10", "@lancedb/vectordb-darwin-arm64": "0.4.7",
"@lancedb/vectordb-darwin-x64": "0.4.10", "@lancedb/vectordb-darwin-x64": "0.4.7",
"@lancedb/vectordb-linux-arm64-gnu": "0.4.10", "@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
"@lancedb/vectordb-linux-x64-gnu": "0.4.10", "@lancedb/vectordb-linux-x64-gnu": "0.4.7",
"@lancedb/vectordb-win32-x64-msvc": "0.4.10" "@lancedb/vectordb-win32-x64-msvc": "0.4.7"
} }
} }

View File

@@ -14,6 +14,8 @@
import { import {
Field, Field,
type FixedSizeListBuilder,
Float32,
makeBuilder, makeBuilder,
RecordBatchFileWriter, RecordBatchFileWriter,
Utf8, Utf8,
@@ -24,19 +26,14 @@ import {
Table as ArrowTable, Table as ArrowTable,
RecordBatchStreamWriter, RecordBatchStreamWriter,
List, List,
Float64,
RecordBatch, RecordBatch,
makeData, makeData,
Struct, Struct,
type Float, type Float
DataType,
Binary,
Float32
} from 'apache-arrow' } from 'apache-arrow'
import { type EmbeddingFunction } from './index' import { type EmbeddingFunction } from './index'
/*
* Options to control how a column should be converted to a vector array
*/
export class VectorColumnOptions { export class VectorColumnOptions {
/** Vector column type. */ /** Vector column type. */
type: Float = new Float32() type: Float = new Float32()
@@ -48,50 +45,14 @@ export class VectorColumnOptions {
/** Options to control the makeArrowTable call. */ /** Options to control the makeArrowTable call. */
export class MakeArrowTableOptions { export class MakeArrowTableOptions {
/* /** Provided schema. */
* Schema of the data.
*
* If this is not provided then the data type will be inferred from the
* JS type. Integer numbers will become int64, floating point numbers
* will become float64 and arrays will become variable sized lists with
* the data type inferred from the first element in the array.
*
* The schema must be specified if there are no records (e.g. to make
* an empty table)
*/
schema?: Schema schema?: Schema
/* /** Vector columns */
* Mapping from vector column name to expected type
*
* Lance expects vector columns to be fixed size list arrays (i.e. tensors)
* However, `makeArrowTable` will not infer this by default (it creates
* variable size list arrays). This field can be used to indicate that a column
* should be treated as a vector column and converted to a fixed size list.
*
* The keys should be the names of the vector columns. The value specifies the
* expected data type of the vector columns.
*
* If `schema` is provided then this field is ignored.
*
* By default, the column named "vector" will be assumed to be a float32
* vector column.
*/
vectorColumns: Record<string, VectorColumnOptions> = { vectorColumns: Record<string, VectorColumnOptions> = {
vector: new VectorColumnOptions() vector: new VectorColumnOptions()
} }
/**
* If true then string columns will be encoded with dictionary encoding
*
* Set this to true if your string columns tend to repeat the same values
* often. For more precise control use the `schema` property to specify the
* data type for individual columns.
*
* If `schema` is provided then this property is ignored.
*/
dictionaryEncodeStrings: boolean = false
constructor (values?: Partial<MakeArrowTableOptions>) { constructor (values?: Partial<MakeArrowTableOptions>) {
Object.assign(this, values) Object.assign(this, values)
} }
@@ -101,29 +62,8 @@ export class MakeArrowTableOptions {
* An enhanced version of the {@link makeTable} function from Apache Arrow * An enhanced version of the {@link makeTable} function from Apache Arrow
* that supports nested fields and embeddings columns. * that supports nested fields and embeddings columns.
* *
* This function converts an array of Record<String, any> (row-major JS objects)
* to an Arrow Table (a columnar structure)
*
* Note that it currently does not support nulls. * Note that it currently does not support nulls.
* *
* If a schema is provided then it will be used to determine the resulting array
* types. Fields will also be reordered to fit the order defined by the schema.
*
* If a schema is not provided then the types will be inferred and the field order
* will be controlled by the order of properties in the first record.
*
* If the input is empty then a schema must be provided to create an empty table.
*
* When a schema is not specified then data types will be inferred. The inference
* rules are as follows:
*
* - boolean => Bool
* - number => Float64
* - String => Utf8
* - Buffer => Binary
* - Record<String, any> => Struct
* - Array<any> => List
*
* @param data input data * @param data input data
* @param options options to control the makeArrowTable call. * @param options options to control the makeArrowTable call.
* *
@@ -146,10 +86,8 @@ export class MakeArrowTableOptions {
* ], { schema }); * ], { schema });
* ``` * ```
* *
* By default it assumes that the column named `vector` is a vector column * It guesses the vector columns if the schema is not provided. For example,
* and it will be converted into a fixed size list array of type float32. * by default it assumes that the column named `vector` is a vector column.
* The `vectorColumns` option can be used to support other vector column
* names and data types.
* *
* ```ts * ```ts
* *
@@ -196,304 +134,211 @@ export function makeArrowTable (
data: Array<Record<string, any>>, data: Array<Record<string, any>>,
options?: Partial<MakeArrowTableOptions> options?: Partial<MakeArrowTableOptions>
): ArrowTable { ): ArrowTable {
if (data.length === 0 && (options?.schema === undefined || options?.schema === null)) { if (data.length === 0) {
throw new Error('At least one record or a schema needs to be provided') throw new Error('At least one record needs to be provided')
} }
const opt = new MakeArrowTableOptions(options !== undefined ? options : {}) const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
const columns: Record<string, Vector> = {} const columns: Record<string, Vector> = {}
// TODO: sample dataset to find missing columns // TODO: sample dataset to find missing columns
// Prefer the field ordering of the schema, if present const columnNames = Object.keys(data[0])
const columnNames = ((options?.schema) != null) ? (options?.schema?.names as string[]) : Object.keys(data[0])
for (const colName of columnNames) { for (const colName of columnNames) {
if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) { const values = data.map((datum) => datum[colName])
// The field is present in the schema, but not in the data, skip it let vector: Vector
continue
}
// Extract a single column from the records (transpose from row-major to col-major)
let values = data.map((datum) => datum[colName])
// By default (type === undefined) arrow will infer the type from the JS type
let type
if (opt.schema !== undefined) { if (opt.schema !== undefined) {
// If there is a schema provided, then use that for the type instead // Explicit schema is provided, highest priority
type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type vector = vectorFromArray(
if (DataType.isInt(type) && type.bitWidth === 64) { values,
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051 opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
values = values.map((v) => { )
if (v === null) {
return v
}
return BigInt(v)
})
}
} else { } else {
// Otherwise, check to see if this column is one of the vector columns
// defined by opt.vectorColumns and, if so, use the fixed size list type
const vectorColumnOptions = opt.vectorColumns[colName] const vectorColumnOptions = opt.vectorColumns[colName]
if (vectorColumnOptions !== undefined) { if (vectorColumnOptions !== undefined) {
type = newVectorType(values[0].length, vectorColumnOptions.type) const fslType = new FixedSizeList(
} values[0].length,
} new Field('item', vectorColumnOptions.type, false)
)
try { vector = vectorFromArray(values, fslType)
// Convert an Array of JS values to an arrow vector
columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings)
} catch (error: unknown) {
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw Error(`Could not convert column "${colName}" to Arrow: ${error}`)
}
}
if (opt.schema != null) {
// `new ArrowTable(columns)` infers a schema which may sometimes have
// incorrect nullability (it assumes nullable=true if there are 0 rows)
//
// `new ArrowTable(schema, columns)` will also fail because it will create a
// batch with an inferred schema and then complain that the batch schema
// does not match the provided schema.
//
// To work around this we first create a table with the wrong schema and
// then patch the schema of the batches so we can use
// `new ArrowTable(schema, batches)` which does not do any schema inference
const firstTable = new ArrowTable(columns)
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const batchesFixed = firstTable.batches.map(batch => new RecordBatch(opt.schema!, batch.data))
return new ArrowTable(opt.schema, batchesFixed)
} else {
return new ArrowTable(columns)
}
}
/**
* Create an empty Arrow table with the provided schema
*/
export function makeEmptyTable (schema: Schema): ArrowTable {
return makeArrowTable([], { schema })
}
// Helper function to convert Array<Array<any>> to a variable sized list array
function makeListVector (lists: any[][]): Vector<any> {
if (lists.length === 0 || lists[0].length === 0) {
throw Error('Cannot infer list vector from empty array or empty list')
}
const sampleList = lists[0]
let inferredType
try {
const sampleVector = makeVector(sampleList)
inferredType = sampleVector.type
} catch (error: unknown) {
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`)
}
const listBuilder = makeBuilder({
type: new List(new Field('item', inferredType, true))
})
for (const list of lists) {
listBuilder.append(list)
}
return listBuilder.finish().toVector()
}
// Helper function to convert an Array of JS values to an Arrow Vector
function makeVector (values: any[], type?: DataType, stringAsDictionary?: boolean): Vector<any> {
if (type !== undefined) {
// No need for inference, let Arrow create it
return vectorFromArray(values, type)
}
if (values.length === 0) {
throw Error('makeVector requires at least one value or the type must be specfied')
}
const sampleValue = values.find(val => val !== null && val !== undefined)
if (sampleValue === undefined) {
throw Error('makeVector cannot infer the type if all values are null or undefined')
}
if (Array.isArray(sampleValue)) {
// Default Arrow inference doesn't handle list types
return makeListVector(values)
} else if (Buffer.isBuffer(sampleValue)) {
// Default Arrow inference doesn't handle Buffer
return vectorFromArray(values, new Binary())
} else if (!(stringAsDictionary ?? false) && (typeof sampleValue === 'string' || sampleValue instanceof String)) {
// If the type is string then don't use Arrow's default inference unless dictionaries are requested
// because it will always use dictionary encoding for strings
return vectorFromArray(values, new Utf8())
} else {
// Convert a JS array of values to an arrow vector
return vectorFromArray(values)
}
}
async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<ArrowTable> {
if (embeddings == null) {
return table
}
// Convert from ArrowTable to Record<String, Vector>
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
const name = table.schema.fields[idx].name
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const vec = table.getChildAt(idx)!
return [name, vec]
})
const newColumns = Object.fromEntries(colEntries)
const sourceColumn = newColumns[embeddings.sourceColumn]
const destColumn = embeddings.destColumn ?? 'vector'
const innerDestType = embeddings.embeddingDataType ?? new Float32()
if (sourceColumn === undefined) {
throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`)
}
if (table.numRows === 0) {
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
// We have an empty table and it already has the embedding column so no work needs to be done
// Note: we don't return an error like we did below because this is a common occurrence. For example,
// if we call convertToTable with 0 records and a schema that includes the embedding
return table
}
if (embeddings.embeddingDimension !== undefined) {
const destType = newVectorType(embeddings.embeddingDimension, innerDestType)
newColumns[destColumn] = makeVector([], destType)
} else if (schema != null) {
const destField = schema.fields.find(f => f.name === destColumn)
if (destField != null) {
newColumns[destColumn] = makeVector([], destField.type)
} else { } else {
throw new Error(`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`) // Normal case
vector = vectorFromArray(values)
} }
} else {
throw new Error('Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`')
} }
} else { columns[colName] = vector
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`)
}
if (table.batches.length > 1) {
throw new Error('Internal error: `makeArrowTable` unexpectedly created a table with more than one batch')
}
const values = sourceColumn.toArray()
const vectors = await embeddings.embed(values as T[])
if (vectors.length !== values.length) {
throw new Error('Embedding function did not return an embedding for each input element')
}
const destType = newVectorType(vectors[0].length, innerDestType)
newColumns[destColumn] = makeVector(vectors, destType)
} }
const newTable = new ArrowTable(newColumns) return new ArrowTable(columns)
if (schema != null) {
if (schema.fields.find(f => f.name === destColumn) === undefined) {
throw new Error(`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`)
}
return alignTable(newTable, schema)
}
return newTable
} }
/* // Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
* Convert an Array of records into an Arrow Table, optionally applying an
* embeddings function to it.
*
* This function calls `makeArrowTable` first to create the Arrow Table.
* Any provided `makeTableOptions` (e.g. a schema) will be passed on to
* that call.
*
* The embedding function will be passed a column of values (based on the
* `sourceColumn` of the embedding function) and expects to receive back
* number[][] which will be converted into a fixed size list column. By
* default this will be a fixed size list of Float32 but that can be
* customized by the `embeddingDataType` property of the embedding function.
*
* If a schema is provided in `makeTableOptions` then it should include the
* embedding columns. If no schema is provded then embedding columns will
* be placed at the end of the table, after all of the input columns.
*/
export async function convertToTable<T> ( export async function convertToTable<T> (
data: Array<Record<string, unknown>>, data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>
makeTableOptions?: Partial<MakeArrowTableOptions>
): Promise<ArrowTable> { ): Promise<ArrowTable> {
const table = makeArrowTable(data, makeTableOptions) if (data.length === 0) {
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema) throw new Error('At least one record needs to be provided')
}
const columns = Object.keys(data[0])
const records: Record<string, Vector> = {}
for (const columnsKey of columns) {
if (columnsKey === 'vector') {
const vectorSize = (data[0].vector as any[]).length
const listBuilder = newVectorBuilder(vectorSize)
for (const datum of data) {
if ((datum[columnsKey] as any[]).length !== vectorSize) {
throw new Error(`Invalid vector size, expected ${vectorSize}`)
}
listBuilder.append(datum[columnsKey])
}
records[columnsKey] = listBuilder.finish().toVector()
} else {
const values = []
for (const datum of data) {
values.push(datum[columnsKey])
}
if (columnsKey === embeddings?.sourceColumn) {
const vectors = await embeddings.embed(values as T[])
records.vector = vectorFromArray(
vectors,
newVectorType(vectors[0].length)
)
}
if (typeof values[0] === 'string') {
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
records[columnsKey] = vectorFromArray(values, new Utf8())
} else if (Array.isArray(values[0])) {
const elementType = getElementType(values[0])
let innerType
if (elementType === 'string') {
innerType = new Utf8()
} else if (elementType === 'number') {
innerType = new Float64()
} else {
// TODO: pass in schema if it exists, else keep going to the next element
throw new Error(`Unsupported array element type ${elementType}`)
}
const listBuilder = makeBuilder({
type: new List(new Field('item', innerType, true))
})
for (const value of values) {
listBuilder.append(value)
}
records[columnsKey] = listBuilder.finish().toVector()
} else {
// TODO if this is a struct field then recursively align the subfields
records[columnsKey] = vectorFromArray(values)
}
}
}
return new ArrowTable(records)
}
function getElementType (arr: any[]): string {
if (arr.length === 0) {
return 'undefined'
}
return typeof arr[0]
}
// Creates a new Arrow ListBuilder that stores a Vector column
function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
return makeBuilder({
type: newVectorType(dim)
})
} }
// Creates the Arrow Type for a Vector column with dimension `dim` // Creates the Arrow Type for a Vector column with dimension `dim`
function newVectorType <T extends Float> (dim: number, innerType: T): FixedSizeList<T> { function newVectorType (dim: number): FixedSizeList<Float32> {
// Somewhere we always default to have the elements nullable, so we need to set it to true // Somewhere we always default to have the elements nullable, so we need to set it to true
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements // otherwise we often get schema mismatches because the stored data always has schema with nullable elements
const children = new Field<T>('item', innerType, true) const children = new Field<Float32>('item', new Float32(), true)
return new FixedSizeList(dim, children) return new FixedSizeList(dim, children)
} }
/** // Converts an Array of records into Arrow IPC format
* Serialize an Array of records into a buffer using the Arrow IPC File serialization
*
* This function will call `convertToTable` and pass on `embeddings` and `schema`
*
* `schema` is required if data is empty
*/
export async function fromRecordsToBuffer<T> ( export async function fromRecordsToBuffer<T> (
data: Array<Record<string, unknown>>, data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>,
schema?: Schema schema?: Schema
): Promise<Buffer> { ): Promise<Buffer> {
const table = await convertToTable(data, embeddings, { schema }) let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchFileWriter.writeAll(table) const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
} }
/** // Converts an Array of records into Arrow IPC stream format
* Serialize an Array of records into a buffer using the Arrow IPC Stream serialization
*
* This function will call `convertToTable` and pass on `embeddings` and `schema`
*
* `schema` is required if data is empty
*/
export async function fromRecordsToStreamBuffer<T> ( export async function fromRecordsToStreamBuffer<T> (
data: Array<Record<string, unknown>>, data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>,
schema?: Schema schema?: Schema
): Promise<Buffer> { ): Promise<Buffer> {
const table = await convertToTable(data, embeddings, { schema }) let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchStreamWriter.writeAll(table) const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
} }
/** // Converts an Arrow Table into Arrow IPC format
* Serialize an Arrow Table into a buffer using the Arrow IPC File serialization
*
* This function will apply `embeddings` to the table in a manner similar to
* `convertToTable`.
*
* `schema` is required if the table is empty
*/
export async function fromTableToBuffer<T> ( export async function fromTableToBuffer<T> (
table: ArrowTable, table: ArrowTable,
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>,
schema?: Schema schema?: Schema
): Promise<Buffer> { ): Promise<Buffer> {
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema) if (embeddings !== undefined) {
const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings) const source = table.getChild(embeddings.sourceColumn)
if (source === null) {
throw new Error(
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
)
}
const vectors = await embeddings.embed(source.toArray() as T[])
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
table = table.assign(new ArrowTable({ vector: column }))
}
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
} }
/** // Converts an Arrow Table into Arrow IPC stream format
* Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
*
* This function will apply `embeddings` to the table in a manner similar to
* `convertToTable`.
*
* `schema` is required if the table is empty
*/
export async function fromTableToStreamBuffer<T> ( export async function fromTableToStreamBuffer<T> (
table: ArrowTable, table: ArrowTable,
embeddings?: EmbeddingFunction<T>, embeddings?: EmbeddingFunction<T>,
schema?: Schema schema?: Schema
): Promise<Buffer> { ): Promise<Buffer> {
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema) if (embeddings !== undefined) {
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings) const source = table.getChild(embeddings.sourceColumn)
if (source === null) {
throw new Error(
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
)
}
const vectors = await embeddings.embed(source.toArray() as T[])
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
table = table.assign(new ArrowTable({ vector: column }))
}
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array()) return Buffer.from(await writer.toUint8Array())
} }

View File

@@ -12,53 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
import { type Float } from 'apache-arrow'
/** /**
* An embedding function that automatically creates vector representation for a given column. * An embedding function that automatically creates vector representation for a given column.
*/ */
export interface EmbeddingFunction<T> { export interface EmbeddingFunction<T> {
/** /**
* The name of the column that will be used as input for the Embedding Function. * The name of the column that will be used as input for the Embedding Function.
*/ */
sourceColumn: string sourceColumn: string
/** /**
* The data type of the embedding * Creates a vector representation for the given values.
* */
* The embedding function should return `number`. This will be converted into
* an Arrow float array. By default this will be Float32 but this property can
* be used to control the conversion.
*/
embeddingDataType?: Float
/**
* The dimension of the embedding
*
* This is optional, normally this can be determined by looking at the results of
* `embed`. If this is not specified, and there is an attempt to apply the embedding
* to an empty table, then that process will fail.
*/
embeddingDimension?: number
/**
* The name of the column that will contain the embedding
*
* By default this is "vector"
*/
destColumn?: string
/**
* Should the source column be excluded from the resulting table
*
* By default the source column is included. Set this to true and
* only the embedding will be stored.
*/
excludeSource?: boolean
/**
* Creates a vector representation for the given values.
*/
embed: (data: T[]) => Promise<number[][]> embed: (data: T[]) => Promise<number[][]>
} }

View File

@@ -37,7 +37,6 @@ const {
tableCountRows, tableCountRows,
tableDelete, tableDelete,
tableUpdate, tableUpdate,
tableMergeInsert,
tableCleanupOldVersions, tableCleanupOldVersions,
tableCompactFiles, tableCompactFiles,
tableListIndices, tableListIndices,
@@ -49,7 +48,7 @@ const {
export { Query } export { Query }
export type { EmbeddingFunction } export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai' export { OpenAIEmbeddingFunction } from './embedding/openai'
export { convertToTable, makeArrowTable, type MakeArrowTableOptions } from './arrow' export { makeArrowTable, type MakeArrowTableOptions } from './arrow'
const defaultAwsRegion = 'us-west-2' const defaultAwsRegion = 'us-west-2'
@@ -372,7 +371,7 @@ export interface Table<T = number[]> {
/** /**
* Returns the number of rows in this table. * Returns the number of rows in this table.
*/ */
countRows: (filter?: string) => Promise<number> countRows: () => Promise<number>
/** /**
* Delete rows from this table. * Delete rows from this table.
@@ -441,38 +440,6 @@ export interface Table<T = number[]> {
*/ */
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void> update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
/**
* Runs a "merge insert" operation on the table
*
* This operation can add rows, update rows, and remove rows all in a single
* transaction. It is a very generic tool that can be used to create
* behaviors like "insert if not exists", "update or insert (i.e. upsert)",
* or even replace a portion of existing data with new data (e.g. replace
* all data where month="january")
*
* The merge insert operation works by combining new data from a
* **source table** with existing data in a **target table** by using a
* join. There are three categories of records.
*
* "Matched" records are records that exist in both the source table and
* the target table. "Not matched" records exist only in the source table
* (e.g. these are new data) "Not matched by source" records exist only
* in the target table (this is old data)
*
* The MergeInsertArgs can be used to customize what should happen for
* each category of data.
*
* Please note that the data may appear to be reordered as part of this
* operation. This is because updated rows will be deleted from the
* dataset and then reinserted at the end with the new values.
*
* @param on a column to join on. This is how records from the source
* table and target table are matched.
* @param data the new data to insert
* @param args parameters controlling how the operation should behave
*/
mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>
/** /**
* List the indicies on this table. * List the indicies on this table.
*/ */
@@ -516,47 +483,6 @@ export interface UpdateSqlArgs {
valuesSql: Record<string, string> valuesSql: Record<string, string>
} }
export interface MergeInsertArgs {
/**
* If true then rows that exist in both the source table (new data) and
* the target table (old data) will be updated, replacing the old row
* with the corresponding matching row.
*
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*
* Optionally, a filter can be specified. This should be an SQL
* filter where fields with the prefix "target." refer to fields
* in the target table (old data) and fields with the prefix
* "source." refer to fields in the source table (new data). For
* example, the filter "target.lastUpdated < source.lastUpdated" will
* only update matched rows when the incoming `lastUpdated` value is
* newer.
*
* Rows that do not match the filter will not be updated. Rows that
* do not match the filter do become "not matched" rows.
*/
whenMatchedUpdateAll?: string | boolean
/**
* If true then rows that exist only in the source table (new data)
* will be inserted into the target table.
*/
whenNotMatchedInsertAll?: boolean
/**
* If true then rows that exist only in the target table (old data)
* will be deleted.
*
* If this is a string then it will be treated as an SQL filter and
* only rows that both do not match any row in the source table and
* match the given filter will be deleted.
*
* This can be used to replace a selection of existing data with
* new data.
*/
whenNotMatchedBySourceDelete?: string | boolean
}
export interface VectorIndex { export interface VectorIndex {
columns: string[] columns: string[]
name: string name: string
@@ -851,8 +777,8 @@ export class LocalTable<T = number[]> implements Table<T> {
/** /**
* Returns the number of rows in this table. * Returns the number of rows in this table.
*/ */
async countRows (filter?: string): Promise<number> { async countRows (): Promise<number> {
return tableCountRows.call(this._tbl, filter) return tableCountRows.call(this._tbl)
} }
/** /**
@@ -895,46 +821,6 @@ export class LocalTable<T = number[]> implements Table<T> {
}) })
} }
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
let whenMatchedUpdateAll = false
let whenMatchedUpdateAllFilt = null
if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
whenMatchedUpdateAll = true
if (args.whenMatchedUpdateAll !== true) {
whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
}
}
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
let whenNotMatchedBySourceDelete = false
let whenNotMatchedBySourceDeleteFilt = null
if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) {
whenNotMatchedBySourceDelete = true
if (args.whenNotMatchedBySourceDelete !== true) {
whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete
}
}
const schema = await this.schema
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, { schema })
}
const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)
this._tbl = await tableMergeInsert.call(
this._tbl,
on,
whenMatchedUpdateAll,
whenMatchedUpdateAllFilt,
whenNotMatchedInsertAll,
whenNotMatchedBySourceDelete,
whenNotMatchedBySourceDeleteFilt,
buffer
)
}
/** /**
* Clean up old versions of the table, freeing disk space. * Clean up old versions of the table, freeing disk space.
* *

View File

@@ -24,8 +24,7 @@ import {
type IndexStats, type IndexStats,
type UpdateArgs, type UpdateArgs,
type UpdateSqlArgs, type UpdateSqlArgs,
makeArrowTable, makeArrowTable
type MergeInsertArgs
} from '../index' } from '../index'
import { Query } from '../query' import { Query } from '../query'
@@ -275,55 +274,6 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error('Not implemented') throw new Error('Not implemented')
} }
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, await this.schema)
}
const queryParams: any = {
on
}
if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
queryParams.when_matched_update_all = 'true'
if (typeof args.whenMatchedUpdateAll === 'string') {
queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
}
} else {
queryParams.when_matched_update_all = 'false'
}
if (args.whenNotMatchedInsertAll ?? false) {
queryParams.when_not_matched_insert_all = 'true'
} else {
queryParams.when_not_matched_insert_all = 'false'
}
if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) {
queryParams.when_not_matched_by_source_delete = 'true'
if (typeof args.whenNotMatchedBySourceDelete === 'string') {
queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete
}
} else {
queryParams.when_not_matched_by_source_delete = 'false'
}
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/merge_insert/`,
buffer,
queryParams,
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
}
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> { async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable let tbl: ArrowTable
if (data instanceof ArrowTable) { if (data instanceof ArrowTable) {

View File

@@ -13,10 +13,9 @@
// limitations under the License. // limitations under the License.
import { describe } from 'mocha' import { describe } from 'mocha'
import { assert, expect, use as chaiUse } from 'chai' import { assert } from 'chai'
import * as chaiAsPromised from 'chai-as-promised'
import { convertToTable, fromTableToBuffer, makeArrowTable, makeEmptyTable } from '../arrow' import { fromTableToBuffer, makeArrowTable } from '../arrow'
import { import {
Field, Field,
FixedSizeList, FixedSizeList,
@@ -25,79 +24,21 @@ import {
Int32, Int32,
tableFromIPC, tableFromIPC,
Schema, Schema,
Float64, Float64
type Table,
Binary,
Bool,
Utf8,
Struct,
List,
DataType,
Dictionary,
Int64
} from 'apache-arrow' } from 'apache-arrow'
import { type EmbeddingFunction } from '../embedding/embedding_function'
chaiUse(chaiAsPromised) describe('Apache Arrow tables', function () {
it('customized schema', async function () {
function sampleRecords (): Array<Record<string, any>> {
return [
{
binary: Buffer.alloc(5),
boolean: false,
number: 7,
string: 'hello',
struct: { x: 0, y: 0 },
list: ['anime', 'action', 'comedy']
}
]
}
// Helper method to verify various ways to create a table
async function checkTableCreation (tableCreationMethod: (records: any, recordsReversed: any, schema: Schema) => Promise<Table>): Promise<void> {
const records = sampleRecords()
const recordsReversed = [{
list: ['anime', 'action', 'comedy'],
struct: { x: 0, y: 0 },
string: 'hello',
number: 7,
boolean: false,
binary: Buffer.alloc(5)
}]
const schema = new Schema([
new Field('binary', new Binary(), false),
new Field('boolean', new Bool(), false),
new Field('number', new Float64(), false),
new Field('string', new Utf8(), false),
new Field('struct', new Struct([
new Field('x', new Float64(), false),
new Field('y', new Float64(), false)
])),
new Field('list', new List(new Field('item', new Utf8(), false)), false)
])
const table = await tableCreationMethod(records, recordsReversed, schema)
schema.fields.forEach((field, idx) => {
const actualField = table.schema.fields[idx]
assert.isFalse(actualField.nullable)
assert.equal(table.getChild(field.name)?.type.toString(), field.type.toString())
assert.equal(table.getChildAt(idx)?.type.toString(), field.type.toString())
})
}
describe('The function makeArrowTable', function () {
it('will use data types from a provided schema instead of inference', async function () {
const schema = new Schema([ const schema = new Schema([
new Field('a', new Int32()), new Field('a', new Int32()),
new Field('b', new Float32()), new Field('b', new Float32()),
new Field('c', new FixedSizeList(3, new Field('item', new Float16()))), new Field('c', new FixedSizeList(3, new Field('item', new Float16())))
new Field('d', new Int64())
]) ])
const table = makeArrowTable( const table = makeArrowTable(
[ [
{ a: 1, b: 2, c: [1, 2, 3], d: 9 }, { a: 1, b: 2, c: [1, 2, 3] },
{ a: 4, b: 5, c: [4, 5, 6], d: 10 }, { a: 4, b: 5, c: [4, 5, 6] },
{ a: 7, b: 8, c: [7, 8, 9], d: null } { a: 7, b: 8, c: [7, 8, 9] }
], ],
{ schema } { schema }
) )
@@ -111,13 +52,13 @@ describe('The function makeArrowTable', function () {
assert.deepEqual(actualSchema, schema) assert.deepEqual(actualSchema, schema)
}) })
it('will assume the column `vector` is FixedSizeList<Float32> by default', async function () { it('default vector column', async function () {
const schema = new Schema([ const schema = new Schema([
new Field('a', new Float64()), new Field('a', new Float64()),
new Field('b', new Float64()), new Field('b', new Float64()),
new Field( new Field(
'vector', 'vector',
new FixedSizeList(3, new Field('item', new Float32(), true)) new FixedSizeList(3, new Field('item', new Float32()))
) )
]) ])
const table = makeArrowTable([ const table = makeArrowTable([
@@ -135,12 +76,12 @@ describe('The function makeArrowTable', function () {
assert.deepEqual(actualSchema, schema) assert.deepEqual(actualSchema, schema)
}) })
it('can support multiple vector columns', async function () { it('2 vector columns', async function () {
const schema = new Schema([ const schema = new Schema([
new Field('a', new Float64()), new Field('a', new Float64()),
new Field('b', new Float64()), new Field('b', new Float64()),
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16(), true))), new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16(), true))) new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
]) ])
const table = makeArrowTable( const table = makeArrowTable(
[ [
@@ -164,157 +105,4 @@ describe('The function makeArrowTable', function () {
const actualSchema = actual.schema const actualSchema = actual.schema
assert.deepEqual(actualSchema, schema) assert.deepEqual(actualSchema, schema)
}) })
it('will allow different vector column types', async function () {
const table = makeArrowTable(
[
{ fp16: [1], fp32: [1], fp64: [1] }
],
{
vectorColumns: {
fp16: { type: new Float16() },
fp32: { type: new Float32() },
fp64: { type: new Float64() }
}
}
)
assert.equal(table.getChild('fp16')?.type.children[0].type.toString(), new Float16().toString())
assert.equal(table.getChild('fp32')?.type.children[0].type.toString(), new Float32().toString())
assert.equal(table.getChild('fp64')?.type.children[0].type.toString(), new Float64().toString())
})
it('will use dictionary encoded strings if asked', async function () {
const table = makeArrowTable([{ str: 'hello' }])
assert.isTrue(DataType.isUtf8(table.getChild('str')?.type))
const tableWithDict = makeArrowTable([{ str: 'hello' }], { dictionaryEncodeStrings: true })
assert.isTrue(DataType.isDictionary(tableWithDict.getChild('str')?.type))
const schema = new Schema([
new Field('str', new Dictionary(new Utf8(), new Int32()))
])
const tableWithDict2 = makeArrowTable([{ str: 'hello' }], { schema })
assert.isTrue(DataType.isDictionary(tableWithDict2.getChild('str')?.type))
})
it('will infer data types correctly', async function () {
await checkTableCreation(async (records) => makeArrowTable(records))
})
it('will allow a schema to be provided', async function () {
await checkTableCreation(async (records, _, schema) => makeArrowTable(records, { schema }))
})
it('will use the field order of any provided schema', async function () {
await checkTableCreation(async (_, recordsReversed, schema) => makeArrowTable(recordsReversed, { schema }))
})
it('will make an empty table', async function () {
await checkTableCreation(async (_, __, schema) => makeArrowTable([], { schema }))
})
})
class DummyEmbedding implements EmbeddingFunction<string> {
public readonly sourceColumn = 'string'
public readonly embeddingDimension = 2
public readonly embeddingDataType = new Float16()
async embed (data: string[]): Promise<number[][]> {
return data.map(
() => [0.0, 0.0]
)
}
}
class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
public readonly sourceColumn = 'string'
async embed (data: string[]): Promise<number[][]> {
return data.map(
() => [0.0, 0.0]
)
}
}
describe('convertToTable', function () {
it('will infer data types correctly', async function () {
await checkTableCreation(async (records) => await convertToTable(records))
})
it('will allow a schema to be provided', async function () {
await checkTableCreation(async (records, _, schema) => await convertToTable(records, undefined, { schema }))
})
it('will use the field order of any provided schema', async function () {
await checkTableCreation(async (_, recordsReversed, schema) => await convertToTable(recordsReversed, undefined, { schema }))
})
it('will make an empty table', async function () {
await checkTableCreation(async (_, __, schema) => await convertToTable([], undefined, { schema }))
})
it('will apply embeddings', async function () {
const records = sampleRecords()
const table = await convertToTable(records, new DummyEmbedding())
assert.isTrue(DataType.isFixedSizeList(table.getChild('vector')?.type))
assert.equal(table.getChild('vector')?.type.children[0].type.toString(), new Float16().toString())
})
it('will fail if missing the embedding source column', async function () {
return await expect(convertToTable([{ id: 1 }], new DummyEmbedding())).to.be.rejectedWith("'string' was not present")
})
it('use embeddingDimension if embedding missing from table', async function () {
const schema = new Schema([
new Field('string', new Utf8(), false)
])
// Simulate getting an empty Arrow table (minus embedding) from some other source
// In other words, we aren't starting with records
const table = makeEmptyTable(schema)
// If the embedding specifies the dimension we are fine
await fromTableToBuffer(table, new DummyEmbedding())
// We can also supply a schema and should be ok
const schemaWithEmbedding = new Schema([
new Field('string', new Utf8(), false),
new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false)
])
await fromTableToBuffer(table, new DummyEmbeddingWithNoDimension(), schemaWithEmbedding)
// Otherwise we will get an error
return await expect(fromTableToBuffer(table, new DummyEmbeddingWithNoDimension())).to.be.rejectedWith('does not specify `embeddingDimension`')
})
it('will apply embeddings to an empty table', async function () {
const schema = new Schema([
new Field('string', new Utf8(), false),
new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false)
])
const table = await convertToTable([], new DummyEmbedding(), { schema })
assert.isTrue(DataType.isFixedSizeList(table.getChild('vector')?.type))
assert.equal(table.getChild('vector')?.type.children[0].type.toString(), new Float16().toString())
})
it('will complain if embeddings present but schema missing embedding column', async function () {
const schema = new Schema([
new Field('string', new Utf8(), false)
])
return await expect(convertToTable([], new DummyEmbedding(), { schema })).to.be.rejectedWith('column vector was missing')
})
it('will provide a nice error if run twice', async function () {
const records = sampleRecords()
const table = await convertToTable(records, new DummyEmbedding())
// fromTableToBuffer will try and apply the embeddings again
return await expect(fromTableToBuffer(table, new DummyEmbedding())).to.be.rejectedWith('already existed')
})
})
describe('makeEmptyTable', function () {
it('will make an empty table', async function () {
await checkTableCreation(async (_, __, schema) => makeEmptyTable(schema))
})
}) })

View File

@@ -294,7 +294,6 @@ describe('LanceDB client', function () {
}) })
assert.equal(table.name, 'vectors') assert.equal(table.name, 'vectors')
assert.equal(await table.countRows(), 10) assert.equal(await table.countRows(), 10)
assert.equal(await table.countRows('vector IS NULL'), 0)
assert.deepEqual(await con.tableNames(), ['vectors']) assert.deepEqual(await con.tableNames(), ['vectors'])
}) })
@@ -370,7 +369,6 @@ describe('LanceDB client', function () {
const table = await con.createTable('f16', data) const table = await con.createTable('f16', data)
assert.equal(table.name, 'f16') assert.equal(table.name, 'f16')
assert.equal(await table.countRows(), total) assert.equal(await table.countRows(), total)
assert.equal(await table.countRows('id < 5'), 5)
assert.deepEqual(await con.tableNames(), ['f16']) assert.deepEqual(await con.tableNames(), ['f16'])
assert.deepEqual(await table.schema, schema) assert.deepEqual(await table.schema, schema)
@@ -533,54 +531,6 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2) assert.equal(await table.countRows(), 2)
}) })
it('can merge insert records into the table', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
const table = await con.createTable('my_table', data)
// insert if not exists
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true
})
assert.equal(await table.countRows(), 3)
assert.equal(await table.countRows('age = 2'), 1)
// conditional update
newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
await table.mergeInsert('id', newData, {
whenMatchedUpdateAll: 'target.age = 1'
})
assert.equal(await table.countRows(), 3)
assert.equal(await table.countRows('age = 1'), 1)
assert.equal(await table.countRows('age = 3'), 1)
newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true
})
assert.equal(await table.countRows(), 4)
assert.equal((await table.filter('age = 4').execute()).length, 2)
newData = [{ id: 5, age: 5 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: 'age < 4'
})
assert.equal(await table.countRows(), 3)
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: true
})
assert.equal(await table.countRows(), 1)
})
it('can update records in the table', async function () { it('can update records in the table', async function () {
const uri = await createTestDB() const uri = await createTestDB()
const con = await lancedb.connect(uri) const con = await lancedb.connect(uri)

View File

@@ -9,6 +9,6 @@
"declaration": true, "declaration": true,
"outDir": "./dist", "outDir": "./dist",
"strict": true, "strict": true,
"sourceMap": true, // "esModuleInterop": true,
} }
} }

View File

@@ -1,12 +1,9 @@
[package] [package]
name = "vectordb-nodejs" name = "vectordb-nodejs"
edition.workspace = true edition = "2021"
version = "0.0.0" version = "0.0.0"
license.workspace = true license.workspace = true
description.workspace = true
repository.workspace = true repository.workspace = true
keywords.workspace = true
categories.workspace = true
[lib] [lib]
crate-type = ["cdylib"] crate-type = ["cdylib"]
@@ -17,14 +14,15 @@ futures.workspace = true
lance-linalg.workspace = true lance-linalg.workspace = true
lance.workspace = true lance.workspace = true
vectordb = { path = "../rust/vectordb" } vectordb = { path = "../rust/vectordb" }
napi = { version = "2.15", default-features = false, features = [ napi = { version = "2.14", default-features = false, features = [
"napi7", "napi7",
"async" "async"
] } ] }
napi-derive = "2" napi-derive = "2.14"
# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }
[build-dependencies] [build-dependencies]
napi-build = "2.1" napi-build = "2.1"
[profile.release]
lto = true
strip = "symbols"

View File

@@ -14,7 +14,6 @@
import { makeArrowTable, toBuffer } from "../vectordb/arrow"; import { makeArrowTable, toBuffer } from "../vectordb/arrow";
import { import {
Int64,
Field, Field,
FixedSizeList, FixedSizeList,
Float16, Float16,
@@ -105,16 +104,3 @@ test("2 vector columns", function () {
const actualSchema = actual.schema; const actualSchema = actual.schema;
expect(actualSchema.toString()).toEqual(schema.toString()); expect(actualSchema.toString()).toEqual(schema.toString());
}); });
test("handles int64", function() {
// https://github.com/lancedb/lancedb/issues/960
const schema = new Schema([
new Field("x", new Int64(), true)
]);
const table = makeArrowTable([
{ x: 1 },
{ x: 2 },
{ x: 3 }
], { schema });
expect(table.schema).toEqual(schema);
})

View File

@@ -2,6 +2,4 @@
module.exports = { module.exports = {
preset: 'ts-jest', preset: 'ts-jest',
testEnvironment: 'node', testEnvironment: 'node',
moduleDirectories: ["node_modules", "./dist"], };
moduleFileExtensions: ["js", "ts"],
};

View File

@@ -57,8 +57,8 @@ impl Table {
} }
#[napi] #[napi]
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> { pub async fn count_rows(&self) -> napi::Result<usize> {
self.table.count_rows(filter).await.map_err(|e| { self.table.count_rows().await.map_err(|e| {
napi::Error::from_reason(format!( napi::Error::from_reason(format!(
"Failed to count rows in table {}: {}", "Failed to count rows in table {}: {}",
self.table, e self.table, e

View File

@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
import { import {
Int64,
Field, Field,
FixedSizeList, FixedSizeList,
Float, Float,
@@ -24,7 +23,6 @@ import {
Vector, Vector,
vectorFromArray, vectorFromArray,
tableToIPC, tableToIPC,
DataType,
} from "apache-arrow"; } from "apache-arrow";
/** Data type accepted by NodeJS SDK */ /** Data type accepted by NodeJS SDK */
@@ -139,18 +137,15 @@ export function makeArrowTable(
const columnNames = Object.keys(data[0]); const columnNames = Object.keys(data[0]);
for (const colName of columnNames) { for (const colName of columnNames) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-return // eslint-disable-next-line @typescript-eslint/no-unsafe-return
let values = data.map((datum) => datum[colName]); const values = data.map((datum) => datum[colName]);
let vector: Vector; let vector: Vector;
if (opt.schema !== undefined) { if (opt.schema !== undefined) {
// Explicit schema is provided, highest priority // Explicit schema is provided, highest priority
const fieldType: DataType | undefined = opt.schema.fields.filter((f) => f.name === colName)[0]?.type as DataType; vector = vectorFromArray(
if (fieldType instanceof Int64) { values,
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051 opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument );
values = values.map((v) => BigInt(v));
}
vector = vectorFromArray(values, fieldType);
} else { } else {
const vectorColumnOptions = opt.vectorColumns[colName]; const vectorColumnOptions = opt.vectorColumns[colName];
if (vectorColumnOptions !== undefined) { if (vectorColumnOptions !== undefined) {

View File

@@ -73,7 +73,7 @@ export class Table {
/** Return Schema as empty Arrow IPC file. */ /** Return Schema as empty Arrow IPC file. */
schema(): Buffer schema(): Buffer
add(buf: Buffer): Promise<void> add(buf: Buffer): Promise<void>
countRows(filter?: string): Promise<bigint> countRows(): Promise<bigint>
delete(predicate: string): Promise<void> delete(predicate: string): Promise<void>
createIndex(): IndexBuilder createIndex(): IndexBuilder
query(): Query query(): Query

View File

@@ -50,8 +50,8 @@ export class Table {
} }
/** Count the total number of rows in the dataset. */ /** Count the total number of rows in the dataset. */
async countRows(filter?: string): Promise<bigint> { async countRows(): Promise<bigint> {
return await this.inner.countRows(filter); return await this.inner.countRows();
} }
/** Delete the rows that satisfy the predicate. */ /** Delete the rows that satisfy the predicate. */

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.5.5 current_version = 0.5.1
commit = True commit = True
message = [python] Bump version: {current_version} → {new_version} message = [python] Bump version: {current_version} → {new_version}
tag = True tag = True

View File

@@ -42,12 +42,6 @@ To run the unit tests:
pytest pytest
``` ```
To run the doc tests:
```bash
pytest --doctest-modules lancedb
```
To run linter and automatically fix all errors: To run linter and automatically fix all errors:
```bash ```bash

View File

@@ -13,7 +13,6 @@
import importlib.metadata import importlib.metadata
import os import os
from datetime import timedelta
from typing import Optional from typing import Optional
__version__ = importlib.metadata.version("lancedb") __version__ = importlib.metadata.version("lancedb")
@@ -31,7 +30,6 @@ def connect(
api_key: Optional[str] = None, api_key: Optional[str] = None,
region: str = "us-east-1", region: str = "us-east-1",
host_override: Optional[str] = None, host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None,
) -> DBConnection: ) -> DBConnection:
"""Connect to a LanceDB database. """Connect to a LanceDB database.
@@ -47,18 +45,6 @@ def connect(
The region to use for LanceDB Cloud. The region to use for LanceDB Cloud.
host_override: str, optional host_override: str, optional
The override url for LanceDB Cloud. The override url for LanceDB Cloud.
read_consistency_interval: timedelta, default None
(For LanceDB OSS only)
The interval at which to check for updates to the table from other
processes. If None, then consistency is not checked. For performance
reasons, this is the default. For strong consistency, set this to
zero seconds. Then every read will check for updates from other
processes. As a compromise, you can set this to a non-zero timedelta
for eventual consistency. If more than that interval has passed since
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
Examples Examples
-------- --------
@@ -87,4 +73,4 @@ def connect(
if api_key is None: if api_key is None:
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}") raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
return RemoteDBConnection(uri, api_key, region, host_override) return RemoteDBConnection(uri, api_key, region, host_override)
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) return LanceDBConnection(uri)

View File

@@ -16,9 +16,9 @@ from typing import Iterable, List, Union
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
from .util import safe_import_pandas from .util import safe_import
pd = safe_import_pandas() pd = safe_import("pandas")
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]] DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray] VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]

View File

@@ -16,9 +16,9 @@ import deprecation
from . import __version__ from . import __version__
from .exceptions import MissingColumnError, MissingValueError from .exceptions import MissingColumnError, MissingValueError
from .util import safe_import_pandas from .util import safe_import
pd = safe_import_pandas() pd = safe_import("pandas")
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer: def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:

View File

@@ -26,8 +26,6 @@ from .table import LanceTable, Table
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import timedelta
from .common import DATA, URI from .common import DATA, URI
from .embeddings import EmbeddingFunctionConfig from .embeddings import EmbeddingFunctionConfig
from .pydantic import LanceModel from .pydantic import LanceModel
@@ -120,7 +118,7 @@ class DBConnection(EnforceOverrides):
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7}, >>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}] ... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
>>> db.create_table("my_table", data) >>> db.create_table("my_table", data)
LanceTable(connection=..., name="my_table") LanceTable(my_table)
>>> db["my_table"].head() >>> db["my_table"].head()
pyarrow.Table pyarrow.Table
vector: fixed_size_list<item: float>[2] vector: fixed_size_list<item: float>[2]
@@ -141,7 +139,7 @@ class DBConnection(EnforceOverrides):
... "long": [-122.7, -74.1] ... "long": [-122.7, -74.1]
... }) ... })
>>> db.create_table("table2", data) >>> db.create_table("table2", data)
LanceTable(connection=..., name="table2") LanceTable(table2)
>>> db["table2"].head() >>> db["table2"].head()
pyarrow.Table pyarrow.Table
vector: fixed_size_list<item: float>[2] vector: fixed_size_list<item: float>[2]
@@ -163,7 +161,7 @@ class DBConnection(EnforceOverrides):
... pa.field("long", pa.float32()) ... pa.field("long", pa.float32())
... ]) ... ])
>>> db.create_table("table3", data, schema = custom_schema) >>> db.create_table("table3", data, schema = custom_schema)
LanceTable(connection=..., name="table3") LanceTable(table3)
>>> db["table3"].head() >>> db["table3"].head()
pyarrow.Table pyarrow.Table
vector: fixed_size_list<item: float>[2] vector: fixed_size_list<item: float>[2]
@@ -197,7 +195,7 @@ class DBConnection(EnforceOverrides):
... pa.field("price", pa.float32()), ... pa.field("price", pa.float32()),
... ]) ... ])
>>> db.create_table("table4", make_batches(), schema=schema) >>> db.create_table("table4", make_batches(), schema=schema)
LanceTable(connection=..., name="table4") LanceTable(table4)
""" """
raise NotImplementedError raise NotImplementedError
@@ -245,16 +243,6 @@ class LanceDBConnection(DBConnection):
---------- ----------
uri: str or Path uri: str or Path
The root uri of the database. The root uri of the database.
read_consistency_interval: timedelta, default None
The interval at which to check for updates to the table from other
processes. If None, then consistency is not checked. For performance
reasons, this is the default. For strong consistency, set this to
zero seconds. Then every read will check for updates from other
processes. As a compromise, you can set this to a non-zero timedelta
for eventual consistency. If more than that interval has passed since
the last check, then the table will be checked for updates. Note: this
consistency only applies to read operations. Write operations are
always consistent.
Examples Examples
-------- --------
@@ -262,24 +250,22 @@ class LanceDBConnection(DBConnection):
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2}, >>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
... {"vector": [0.5, 1.3], "b": 4}]) ... {"vector": [0.5, 1.3], "b": 4}])
LanceTable(connection=..., name="my_table") LanceTable(my_table)
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}]) >>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
LanceTable(connection=..., name="another_table") LanceTable(another_table)
>>> sorted(db.table_names()) >>> sorted(db.table_names())
['another_table', 'my_table'] ['another_table', 'my_table']
>>> len(db) >>> len(db)
2 2
>>> db["my_table"] >>> db["my_table"]
LanceTable(connection=..., name="my_table") LanceTable(my_table)
>>> "my_table" in db >>> "my_table" in db
True True
>>> db.drop_table("my_table") >>> db.drop_table("my_table")
>>> db.drop_table("another_table") >>> db.drop_table("another_table")
""" """
def __init__( def __init__(self, uri: URI):
self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
):
if not isinstance(uri, Path): if not isinstance(uri, Path):
scheme = get_uri_scheme(uri) scheme = get_uri_scheme(uri)
is_local = isinstance(uri, Path) or scheme == "file" is_local = isinstance(uri, Path) or scheme == "file"
@@ -291,14 +277,6 @@ class LanceDBConnection(DBConnection):
self._uri = str(uri) self._uri = str(uri)
self._entered = False self._entered = False
self.read_consistency_interval = read_consistency_interval
def __repr__(self) -> str:
val = f"{self.__class__.__name__}({self._uri}"
if self.read_consistency_interval is not None:
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
val += ")"
return val
@property @property
def uri(self) -> str: def uri(self) -> str:

View File

@@ -10,6 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Union from typing import List, Union
@@ -90,6 +91,25 @@ class EmbeddingFunction(BaseModel, ABC):
texts = texts.combine_chunks().to_pylist() texts = texts.combine_chunks().to_pylist()
return texts return texts
@classmethod
def safe_import(cls, module: str, mitigation=None):
"""
Import the specified module. If the module is not installed,
raise an ImportError with a helpful message.
Parameters
----------
module : str
The name of the module to import
mitigation : Optional[str]
The package(s) to install to mitigate the error.
If not provided then the module name will be used.
"""
try:
return importlib.import_module(module)
except ImportError:
raise ImportError(f"Please install {mitigation or module}")
def safe_model_dump(self): def safe_model_dump(self):
from ..pydantic import PYDANTIC_VERSION from ..pydantic import PYDANTIC_VERSION

View File

@@ -19,7 +19,6 @@ import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION from lancedb.pydantic import PYDANTIC_VERSION
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import TEXT from .utils import TEXT
@@ -184,8 +183,8 @@ class BedRockText(TextEmbeddingFunction):
boto3.client boto3.client
The boto3 client for Amazon Bedrock service The boto3 client for Amazon Bedrock service
""" """
botocore = attempt_import_or_raise("botocore") botocore = self.safe_import("botocore")
boto3 = attempt_import_or_raise("boto3") boto3 = self.safe_import("boto3")
session_kwargs = {"region_name": self.region} session_kwargs = {"region_name": self.region}
client_kwargs = {**session_kwargs} client_kwargs = {**session_kwargs}

View File

@@ -16,7 +16,6 @@ from typing import ClassVar, List, Union
import numpy as np import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import api_key_not_found_help from .utils import api_key_not_found_help
@@ -85,7 +84,7 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
return [emb for emb in rs.embeddings] return [emb for emb in rs.embeddings]
def _init_client(self): def _init_client(self):
cohere = attempt_import_or_raise("cohere") cohere = self.safe_import("cohere")
if CohereEmbeddingFunction.client is None: if CohereEmbeddingFunction.client is None:
if os.environ.get("COHERE_API_KEY") is None: if os.environ.get("COHERE_API_KEY") is None:
api_key_not_found_help("cohere") api_key_not_found_help("cohere")

View File

@@ -19,7 +19,6 @@ import numpy as np
from lancedb.pydantic import PYDANTIC_VERSION from lancedb.pydantic import PYDANTIC_VERSION
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import TEXT, api_key_not_found_help from .utils import TEXT, api_key_not_found_help
@@ -135,7 +134,7 @@ class GeminiText(TextEmbeddingFunction):
@cached_property @cached_property
def client(self): def client(self):
genai = attempt_import_or_raise("google.generativeai", "google.generativeai") genai = self.safe_import("google.generativeai", "google.generativeai")
if not os.environ.get("GOOGLE_API_KEY"): if not os.environ.get("GOOGLE_API_KEY"):
api_key_not_found_help("google") api_key_not_found_help("google")

View File

@@ -14,7 +14,6 @@ from typing import List, Union
import numpy as np import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import weak_lru from .utils import weak_lru
@@ -123,7 +122,7 @@ class GteEmbeddings(TextEmbeddingFunction):
return Model() return Model()
else: else:
sentence_transformers = attempt_import_or_raise( sentence_transformers = self.safe_import(
"sentence_transformers", "sentence-transformers" "sentence_transformers", "sentence-transformers"
) )
return sentence_transformers.SentenceTransformer( return sentence_transformers.SentenceTransformer(

View File

@@ -14,7 +14,6 @@ from typing import List
import numpy as np import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import TEXT, weak_lru from .utils import TEXT, weak_lru
@@ -103,9 +102,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly # convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
source_instruction: str = "represent the document for retrieval" source_instruction: str = "represent the document for retrieval"
query_instruction: ( query_instruction: str = (
str "represent the document for retrieving the most similar documents"
) = "represent the document for retrieving the most similar documents" )
@weak_lru(maxsize=1) @weak_lru(maxsize=1)
def ndims(self): def ndims(self):
@@ -132,10 +131,10 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
@weak_lru(maxsize=1) @weak_lru(maxsize=1)
def get_model(self): def get_model(self):
instructor_embedding = attempt_import_or_raise( instructor_embedding = self.safe_import(
"InstructorEmbedding", "InstructorEmbedding" "InstructorEmbedding", "InstructorEmbedding"
) )
torch = attempt_import_or_raise("torch", "torch") torch = self.safe_import("torch", "torch")
model = instructor_embedding.INSTRUCTOR(self.name) model = instructor_embedding.INSTRUCTOR(self.name)
if self.quantize: if self.quantize:

View File

@@ -21,7 +21,6 @@ import pyarrow as pa
from pydantic import PrivateAttr from pydantic import PrivateAttr
from tqdm import tqdm from tqdm import tqdm
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction from .base import EmbeddingFunction
from .registry import register from .registry import register
from .utils import IMAGES, url_retrieve from .utils import IMAGES, url_retrieve
@@ -51,7 +50,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
open_clip = attempt_import_or_raise("open_clip", "open-clip") open_clip = self.safe_import("open_clip", "open-clip")
model, _, preprocess = open_clip.create_model_and_transforms( model, _, preprocess = open_clip.create_model_and_transforms(
self.name, pretrained=self.pretrained self.name, pretrained=self.pretrained
) )
@@ -79,14 +78,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
if isinstance(query, str): if isinstance(query, str):
return [self.generate_text_embeddings(query)] return [self.generate_text_embeddings(query)]
else: else:
PIL = attempt_import_or_raise("PIL", "pillow") PIL = self.safe_import("PIL", "pillow")
if isinstance(query, PIL.Image.Image): if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)] return [self.generate_image_embedding(query)]
else: else:
raise TypeError("OpenClip supports str or PIL Image as query") raise TypeError("OpenClip supports str or PIL Image as query")
def generate_text_embeddings(self, text: str) -> np.ndarray: def generate_text_embeddings(self, text: str) -> np.ndarray:
torch = attempt_import_or_raise("torch") torch = self.safe_import("torch")
text = self.sanitize_input(text) text = self.sanitize_input(text)
text = self._tokenizer(text) text = self._tokenizer(text)
text.to(self.device) text.to(self.device)
@@ -145,7 +144,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
The image to embed. If the image is a str, it is treated as a uri. The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes. If the image is bytes, it is treated as the raw image bytes.
""" """
torch = attempt_import_or_raise("torch") torch = self.safe_import("torch")
# TODO handle retry and errors for https # TODO handle retry and errors for https
image = self._to_pil(image) image = self._to_pil(image)
image = self._preprocess(image).unsqueeze(0) image = self._preprocess(image).unsqueeze(0)
@@ -153,7 +152,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
return self._encode_and_normalize_image(image) return self._encode_and_normalize_image(image)
def _to_pil(self, image: Union[str, bytes]): def _to_pil(self, image: Union[str, bytes]):
PIL = attempt_import_or_raise("PIL", "pillow") PIL = self.safe_import("PIL", "pillow")
if isinstance(image, bytes): if isinstance(image, bytes):
return PIL.Image.open(io.BytesIO(image)) return PIL.Image.open(io.BytesIO(image))
if isinstance(image, PIL.Image.Image): if isinstance(image, PIL.Image.Image):

View File

@@ -12,11 +12,10 @@
# limitations under the License. # limitations under the License.
import os import os
from functools import cached_property from functools import cached_property
from typing import List, Optional, Union from typing import List, Union
import numpy as np import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import api_key_not_found_help from .utils import api_key_not_found_help
@@ -31,21 +30,10 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
""" """
name: str = "text-embedding-ada-002" name: str = "text-embedding-ada-002"
dim: Optional[int] = None
def ndims(self): def ndims(self):
return self._ndims # TODO don't hardcode this
return 1536
@cached_property
def _ndims(self):
if self.name == "text-embedding-ada-002":
return 1536
elif self.name == "text-embedding-3-large":
return self.dim or 3072
elif self.name == "text-embedding-3-small":
return self.dim or 1536
else:
raise ValueError(f"Unknown model name {self.name}")
def generate_embeddings( def generate_embeddings(
self, texts: Union[List[str], np.ndarray] self, texts: Union[List[str], np.ndarray]
@@ -59,17 +47,12 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
The texts to embed The texts to embed
""" """
# TODO retry, rate limit, token limit # TODO retry, rate limit, token limit
if self.name == "text-embedding-ada-002": rs = self._openai_client.embeddings.create(input=texts, model=self.name)
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
else:
rs = self._openai_client.embeddings.create(
input=texts, model=self.name, dimensions=self.ndims()
)
return [v.embedding for v in rs.data] return [v.embedding for v in rs.data]
@cached_property @cached_property
def _openai_client(self): def _openai_client(self):
openai = attempt_import_or_raise("openai") openai = self.safe_import("openai")
if not os.environ.get("OPENAI_API_KEY"): if not os.environ.get("OPENAI_API_KEY"):
api_key_not_found_help("openai") api_key_not_found_help("openai")

View File

@@ -14,7 +14,6 @@ from typing import List, Union
import numpy as np import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction from .base import TextEmbeddingFunction
from .registry import register from .registry import register
from .utils import weak_lru from .utils import weak_lru
@@ -76,7 +75,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
TODO: use lru_cache instead with a reasonable/configurable maxsize TODO: use lru_cache instead with a reasonable/configurable maxsize
""" """
sentence_transformers = attempt_import_or_raise( sentence_transformers = self.safe_import(
"sentence_transformers", "sentence-transformers" "sentence_transformers", "sentence-transformers"
) )
return sentence_transformers.SentenceTransformer(self.name, device=self.device) return sentence_transformers.SentenceTransformer(self.name, device=self.device)

View File

@@ -26,10 +26,10 @@ import pyarrow as pa
from lance.vector import vec_to_table from lance.vector import vec_to_table
from retry import retry from retry import retry
from ..util import safe_import_pandas from ..util import safe_import
from ..utils.general import LOGGER from ..utils.general import LOGGER
pd = safe_import_pandas() pd = safe_import("pandas")
DATA = Union[pa.Table, "pd.DataFrame"] DATA = Union[pa.Table, "pd.DataFrame"]
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray] TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]

View File

@@ -12,7 +12,7 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Iterable, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from .common import DATA from .common import DATA
@@ -25,21 +25,18 @@ class LanceMergeInsertBuilder(object):
more context more context
""" """
def __init__(self, table: "Table", on: List[str]): # noqa: F821 def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821
# Do not put a docstring here. This method should be hidden # Do not put a docstring here. This method should be hidden
# from API docs. Users should use merge_insert to create # from API docs. Users should use merge_insert to create
# this object. # this object.
self._table = table self._table = table
self._on = on self._on = on
self._when_matched_update_all = False self._when_matched_update_all = False
self._when_matched_update_all_condition = None
self._when_not_matched_insert_all = False self._when_not_matched_insert_all = False
self._when_not_matched_by_source_delete = False self._when_not_matched_by_source_delete = False
self._when_not_matched_by_source_condition = None self._when_not_matched_by_source_condition = None
def when_matched_update_all( def when_matched_update_all(self) -> LanceMergeInsertBuilder:
self, *, where: Optional[str] = None
) -> LanceMergeInsertBuilder:
""" """
Rows that exist in both the source table (new data) and Rows that exist in both the source table (new data) and
the target table (old data) will be updated, replacing the target table (old data) will be updated, replacing
@@ -50,7 +47,6 @@ class LanceMergeInsertBuilder(object):
but that behavior is subject to change. but that behavior is subject to change.
""" """
self._when_matched_update_all = True self._when_matched_update_all = True
self._when_matched_update_all_condition = where
return self return self
def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder: def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
@@ -81,27 +77,10 @@ class LanceMergeInsertBuilder(object):
self._when_not_matched_by_source_condition = condition self._when_not_matched_by_source_condition = condition
return self return self
def execute( def execute(self, new_data: DATA):
self,
new_data: DATA,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
""" """
Executes the merge insert operation Executes the merge insert operation
Nothing is returned but the [`Table`][lancedb.table.Table] is updated Nothing is returned but the [`Table`][lancedb.table.Table] is updated
Parameters
----------
new_data: DATA
New records which will be matched against the existing records
to potentially insert or update into the table. This parameter
can be anything you use for [`add`][lancedb.table.Table.add]
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
""" """
self._table._do_merge(self, new_data, on_bad_vectors, fill_value) self._table._do_merge(self, new_data)

View File

@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
... name: str ... name: str
... vector: Vector(2) ... vector: Vector(2)
... ...
>>> db = lancedb.connect("./example") >>> db = lancedb.connect("/tmp")
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema()) >>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
>>> table.add([ >>> table.add([
... TestModel(name="test", vector=[1.0, 2.0]) ... TestModel(name="test", vector=[1.0, 2.0])

View File

@@ -24,10 +24,10 @@ import pyarrow as pa
import pydantic import pydantic
from . import __version__ from . import __version__
from .common import VEC from .common import VEC, VECTOR_COLUMN_NAME
from .rerankers.base import Reranker from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker from .rerankers.linear_combination import LinearCombinationReranker
from .util import safe_import_pandas from .util import safe_import
if TYPE_CHECKING: if TYPE_CHECKING:
import PIL import PIL
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
from .pydantic import LanceModel from .pydantic import LanceModel
from .table import Table from .table import Table
pd = safe_import_pandas() pd = safe_import("pandas")
class Query(pydantic.BaseModel): class Query(pydantic.BaseModel):
@@ -75,7 +75,7 @@ class Query(pydantic.BaseModel):
tuning advice. tuning advice.
""" """
vector_column: Optional[str] = None vector_column: str = VECTOR_COLUMN_NAME
# vector to search for # vector to search for
vector: Union[List[float], List[List[float]]] vector: Union[List[float], List[List[float]]]
@@ -403,7 +403,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self, self,
table: "Table", table: "Table",
query: Union[np.ndarray, list, "PIL.Image.Image"], query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str, vector_column: str = VECTOR_COLUMN_NAME,
): ):
super().__init__(table) super().__init__(table)
self._query = query self._query = query
@@ -626,6 +626,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "Table", query: str, vector_column: str): def __init__(self, table: "Table", query: str, vector_column: str):
super().__init__(table) super().__init__(table)
self._validate_fts_index() self._validate_fts_index()
self._query = query
vector_query, fts_query = self._validate_query(query) vector_query, fts_query = self._validate_query(query)
self._fts_query = LanceFtsQueryBuilder(table, fts_query) self._fts_query = LanceFtsQueryBuilder(table, fts_query)
vector_query = self._query_to_vector(table, vector_query, vector_column) vector_query = self._query_to_vector(table, vector_query, vector_column)
@@ -678,18 +679,12 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
# rerankers might need to preserve this score to support `return_score="all"` # rerankers might need to preserve this score to support `return_score="all"`
fts_results = self._normalize_scores(fts_results, "score") fts_results = self._normalize_scores(fts_results, "score")
results = self._reranker.rerank_hybrid( results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
self._fts_query._query, vector_results, fts_results
)
if not isinstance(results, pa.Table): # Enforce type if not isinstance(results, pa.Table): # Enforce type
raise TypeError( raise TypeError(
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}" f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
) )
# apply limit after reranking
results = results.slice(length=self._limit)
if not self._with_row_id: if not self._with_row_id:
results = results.drop(["_rowid"]) results = results.drop(["_rowid"])
return results return results
@@ -781,8 +776,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
""" """
self._vector_query.limit(limit) self._vector_query.limit(limit)
self._fts_query.limit(limit) self._fts_query.limit(limit)
self._limit = limit
return self return self
def select(self, columns: list) -> LanceHybridQueryBuilder: def select(self, columns: list) -> LanceHybridQueryBuilder:

View File

@@ -13,8 +13,6 @@
import functools import functools
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import urljoin from urllib.parse import urljoin
@@ -22,8 +20,6 @@ import attrs
import pyarrow as pa import pyarrow as pa
import requests import requests
from pydantic import BaseModel from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from lancedb.common import Credential from lancedb.common import Credential
from lancedb.remote import VectorQuery, VectorQueryResult from lancedb.remote import VectorQuery, VectorQueryResult
@@ -61,10 +57,6 @@ class RestfulLanceDBClient:
@functools.cached_property @functools.cached_property
def session(self) -> requests.Session: def session(self) -> requests.Session:
sess = requests.Session() sess = requests.Session()
retry_adapter_instance = retry_adapter(retry_adapter_options())
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
adapter_class = LanceDBClientHTTPAdapterFactory() adapter_class = LanceDBClientHTTPAdapterFactory()
sess.mount("https://", adapter_class()) sess.mount("https://", adapter_class())
return sess return sess
@@ -178,72 +170,3 @@ class RestfulLanceDBClient:
"""Query a table.""" """Query a table."""
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc) tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
return VectorQueryResult(tbl) return VectorQueryResult(tbl)
def mount_retry_adapter_for_table(self, table_name: str) -> None:
"""
Adds an http adapter to session that will retry retryable requests to the table.
"""
retry_options = retry_adapter_options(methods=["GET", "POST"])
retry_adapter_instance = retry_adapter(retry_options)
session = self.session
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
retry_adapter_instance,
)
session.mount(
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
retry_adapter_instance,
)
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
return {
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
"backoff_factor": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
),
"backoff_jitter": float(
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
),
"statuses": [
int(i.strip())
for i in os.environ.get(
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
).split(",")
],
"methods": methods,
}
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
total_retries = options["retries"]
connect_retries = options["connect_retries"]
read_retries = options["read_retries"]
backoff_factor = options["backoff_factor"]
backoff_jitter = options["backoff_jitter"]
statuses = options["statuses"]
methods = frozenset(options["methods"])
logging.debug(
f"Setting up retry adapter with {total_retries} retries," # noqa G003
+ f"connect retries {connect_retries}, read retries {read_retries},"
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
+ f"methods {methods}"
)
return HTTPAdapter(
max_retries=Retry(
total=total_retries,
connect=connect_retries,
read=read_retries,
backoff_factor=backoff_factor,
backoff_jitter=backoff_jitter,
status_forcelist=statuses,
allowed_methods=methods,
)
)

View File

@@ -95,8 +95,6 @@ class RemoteDBConnection(DBConnection):
""" """
from .table import RemoteTable from .table import RemoteTable
self._client.mount_retry_adapter_for_table(name)
# check if table exists # check if table exists
try: try:
self._client.post(f"/v1/table/{name}/describe/") self._client.post(f"/v1/table/{name}/describe/")
@@ -118,7 +116,6 @@ class RemoteDBConnection(DBConnection):
schema: Optional[Union[pa.Schema, LanceModel]] = None, schema: Optional[Union[pa.Schema, LanceModel]] = None,
on_bad_vectors: str = "error", on_bad_vectors: str = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
mode: Optional[str] = None,
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None, embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
) -> Table: ) -> Table:
"""Create a [Table][lancedb.table.Table] in the database. """Create a [Table][lancedb.table.Table] in the database.
@@ -216,13 +213,11 @@ class RemoteDBConnection(DBConnection):
if data is None and schema is None: if data is None and schema is None:
raise ValueError("Either data or schema must be provided.") raise ValueError("Either data or schema must be provided.")
if embedding_functions is not None: if embedding_functions is not None:
logging.warning( raise NotImplementedError(
"embedding_functions is not yet supported on LanceDB Cloud." "embedding_functions is not supported for remote databases."
"Please vote https://github.com/lancedb/lancedb/issues/626 " "Please vote https://github.com/lancedb/lancedb/issues/626 "
"for this feature." "for this feature."
) )
if mode is not None:
logging.warning("mode is not yet supported on LanceDB Cloud.")
if inspect.isclass(schema) and issubclass(schema, LanceModel): if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema # convert LanceModel to pyarrow schema

View File

@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import uuid import uuid
from functools import cached_property from functools import cached_property
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
@@ -20,11 +19,10 @@ import pyarrow as pa
from lance import json_to_schema from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from ..query import LanceVectorQueryBuilder from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
from ..util import inf_vector_column_query, value_to_sql from ..util import value_to_sql
from .arrow import to_ipc_binary from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE from .client import ARROW_STREAM_CONTENT_TYPE
from .db import RemoteDBConnection from .db import RemoteDBConnection
@@ -38,9 +36,6 @@ class RemoteTable(Table):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"RemoteTable({self._conn.db_name}.{self._name})" return f"RemoteTable({self._conn.db_name}.{self._name})"
def __len__(self) -> int:
self.count_rows(None)
@cached_property @cached_property
def schema(self) -> pa.Schema: def schema(self) -> pa.Schema:
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) """The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
@@ -58,17 +53,17 @@ class RemoteTable(Table):
return resp["version"] return resp["version"]
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
"""to_arrow() is not yet supported on LanceDB cloud.""" """to_arrow() is not supported on the LanceDB cloud"""
raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.") raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
def to_pandas(self): def to_pandas(self):
"""to_pandas() is not yet supported on LanceDB cloud.""" """to_pandas() is not supported on the LanceDB cloud"""
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.") return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
def create_scalar_index(self, *args, **kwargs): def create_scalar_index(self, *args, **kwargs):
"""Creates a scalar index""" """Creates a scalar index"""
return NotImplementedError( return NotImplementedError(
"create_scalar_index() is not yet supported on LanceDB cloud." "create_scalar_index() is not supported on the LanceDB cloud"
) )
def create_index( def create_index(
@@ -76,10 +71,6 @@ class RemoteTable(Table):
metric="L2", metric="L2",
vector_column_name: str = VECTOR_COLUMN_NAME, vector_column_name: str = VECTOR_COLUMN_NAME,
index_cache_size: Optional[int] = None, index_cache_size: Optional[int] = None,
num_partitions: Optional[int] = None,
num_sub_vectors: Optional[int] = None,
replace: Optional[bool] = None,
accelerator: Optional[str] = None,
): ):
"""Create an index on the table. """Create an index on the table.
Currently, the only parameters that matter are Currently, the only parameters that matter are
@@ -113,28 +104,6 @@ class RemoteTable(Table):
... ) ... )
>>> table.create_index("L2", "vector") # doctest: +SKIP >>> table.create_index("L2", "vector") # doctest: +SKIP
""" """
if num_partitions is not None:
logging.warning(
"num_partitions is not supported on LanceDB cloud."
"This parameter will be tuned automatically."
)
if num_sub_vectors is not None:
logging.warning(
"num_sub_vectors is not supported on LanceDB cloud."
"This parameter will be tuned automatically."
)
if accelerator is not None:
logging.warning(
"GPU accelerator is not yet supported on LanceDB cloud."
"If you have 100M+ vectors to index,"
"please contact us at contact@lancedb.com"
)
if replace is not None:
logging.warning(
"replace is not supported on LanceDB cloud."
"Existing indexes will always be replaced."
)
index_type = "vector" index_type = "vector"
data = { data = {
@@ -198,9 +167,7 @@ class RemoteTable(Table):
) )
def search( def search(
self, self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
query: Union[VEC, str],
vector_column_name: Optional[str] = None,
) -> LanceVectorQueryBuilder: ) -> LanceVectorQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search] of the given query vector. We currently support [vector search][search]
@@ -219,7 +186,7 @@ class RemoteTable(Table):
... ] ... ]
>>> table = db.create_table("my_table", data) # doctest: +SKIP >>> table = db.create_table("my_table", data) # doctest: +SKIP
>>> query = [0.4, 1.4, 2.4] >>> query = [0.4, 1.4, 2.4]
>>> (table.search(query) # doctest: +SKIP >>> (table.search(query, vector_column_name="vector") # doctest: +SKIP
... .where("original_width > 1000", prefilter=True) # doctest: +SKIP ... .where("original_width > 1000", prefilter=True) # doctest: +SKIP
... .select(["caption", "original_width"]) # doctest: +SKIP ... .select(["caption", "original_width"]) # doctest: +SKIP
... .limit(2) # doctest: +SKIP ... .limit(2) # doctest: +SKIP
@@ -238,14 +205,9 @@ class RemoteTable(Table):
- If None then the select/where/limit clauses are applied to filter - If None then the select/where/limit clauses are applied to filter
the table the table
vector_column_name: str, optional vector_column_name: str
The name of the vector column to search. The name of the vector column to search.
*default "vector"*
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
Returns Returns
------- -------
@@ -260,8 +222,6 @@ class RemoteTable(Table):
- and also the "_distance" column which is the distance between the query - and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
if vector_column_name is None:
vector_column_name = inf_vector_column_query(self.schema)
return LanceVectorQueryBuilder(self, query, vector_column_name) return LanceVectorQueryBuilder(self, query, vector_column_name)
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
@@ -284,50 +244,9 @@ class RemoteTable(Table):
result = self._conn._client.query(self._name, query) result = self._conn._client.query(self._name, query)
return result.to_arrow() return result.to_arrow()
def _do_merge( def _do_merge(self, *_args):
self, """_do_merge() is not supported on the LanceDB cloud yet"""
merge: LanceMergeInsertBuilder, return NotImplementedError("_do_merge() is not supported on the LanceDB cloud")
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
data = _sanitize_data(
new_data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)
params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
if merge._when_matched_update_all_condition is not None:
params[
"when_matched_update_all_filt"
] = merge._when_matched_update_all_condition
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params[
"when_not_matched_by_source_delete_filt"
] = merge._when_not_matched_by_source_condition
self._conn._client.post(
f"/v1/table/{self._name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
)
def delete(self, predicate: str): def delete(self, predicate: str):
"""Delete rows from the table. """Delete rows from the table.
@@ -440,25 +359,6 @@ class RemoteTable(Table):
payload = {"predicate": where, "updates": updates} payload = {"predicate": where, "updates": updates}
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload) self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
def cleanup_old_versions(self, *_):
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
raise NotImplementedError(
"cleanup_old_versions() is not supported on the LanceDB cloud"
)
def compact_files(self, *_):
"""compact_files() is not supported on the LanceDB cloud"""
raise NotImplementedError(
"compact_files() is not supported on the LanceDB cloud"
)
def count_rows(self, filter: Optional[str] = None) -> int:
# payload = {"filter": filter}
# self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
return NotImplementedError(
"count_rows() is not yet supported on the LanceDB cloud"
)
def add_index(tbl: pa.Table, i: int) -> pa.Table: def add_index(tbl: pa.Table, i: int) -> pa.Table:
return tbl.add_column( return tbl.add_column(

View File

@@ -1,15 +1,11 @@
from .base import Reranker from .base import Reranker
from .cohere import CohereReranker from .cohere import CohereReranker
from .colbert import ColbertReranker
from .cross_encoder import CrossEncoderReranker from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
__all__ = [ __all__ = [
"Reranker", "Reranker",
"CrossEncoderReranker", "CrossEncoderReranker",
"CohereReranker", "CohereReranker",
"LinearCombinationReranker", "LinearCombinationReranker",
"OpenaiReranker",
"ColbertReranker",
] ]

View File

@@ -1,8 +1,12 @@
import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
if typing.TYPE_CHECKING:
import lancedb
class Reranker(ABC): class Reranker(ABC):
def __init__(self, return_score: str = "relevance"): def __init__(self, return_score: str = "relevance"):
@@ -26,7 +30,7 @@ class Reranker(ABC):
@abstractmethod @abstractmethod
def rerank_hybrid( def rerank_hybrid(
query: str, query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table, vector_results: pa.Table,
fts_results: pa.Table, fts_results: pa.Table,
): ):
@@ -37,8 +41,8 @@ class Reranker(ABC):
Parameters Parameters
---------- ----------
query : str query_builder : "lancedb.HybridQueryBuilder"
The input query The query builder object that was used to generate the results
vector_results : pa.Table vector_results : pa.Table
The results from the vector search The results from the vector search
fts_results : pa.Table fts_results : pa.Table
@@ -46,6 +50,36 @@ class Reranker(ABC):
""" """
pass pass
def rerank_vector(
query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
):
"""
Rerank function receives the individual results from the vector search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.VectorQueryBuilder"
The query builder object that was used to generate the results
vector_results : pa.Table
The results from the vector search
"""
raise NotImplementedError("Vector Reranking is not implemented")
def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
"""
Rerank function receives the individual results from the FTS search.
This isn't mandatory to implement
Parameters
----------
query_builder : "lancedb.FTSQueryBuilder"
The query builder object that was used to generate the results
fts_results : pa.Table
The results from the FTS search
"""
raise NotImplementedError("FTS Reranking is not implemented")
def merge_results(self, vector_results: pa.Table, fts_results: pa.Table): def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
""" """
Merge the results from the vector and FTS search. This is a vanilla merging Merge the results from the vector and FTS search. This is a vanilla merging

View File

@@ -1,12 +1,16 @@
import os import os
import typing
from functools import cached_property from functools import cached_property
from typing import Union from typing import Union
import pyarrow as pa import pyarrow as pa
from ..util import attempt_import_or_raise from ..util import safe_import
from .base import Reranker from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CohereReranker(Reranker): class CohereReranker(Reranker):
""" """
@@ -41,7 +45,7 @@ class CohereReranker(Reranker):
@cached_property @cached_property
def _client(self): def _client(self):
cohere = attempt_import_or_raise("cohere") cohere = safe_import("cohere")
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None: if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
raise ValueError( raise ValueError(
"COHERE_API_KEY not set. Either set it in your environment or \ "COHERE_API_KEY not set. Either set it in your environment or \
@@ -51,14 +55,14 @@ class CohereReranker(Reranker):
def rerank_hybrid( def rerank_hybrid(
self, self,
query: str, query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table, vector_results: pa.Table,
fts_results: pa.Table, fts_results: pa.Table,
): ):
combined_results = self.merge_results(vector_results, fts_results) combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist() docs = combined_results[self.column].to_pylist()
results = self._client.rerank( results = self._client.rerank(
query=query, query=query_builder._query,
documents=docs, documents=docs,
top_n=self.top_n, top_n=self.top_n,
model=self.model_name, model=self.model_name,

View File

@@ -1,109 +0,0 @@
from functools import cached_property
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import Reranker
class ColbertReranker(Reranker):
"""
Reranks the results using the ColBERT model.
Parameters
----------
model_name : str, default "colbert-ir/colbertv2.0"
The name of the cross encoder model to use.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
"""
def __init__(
self,
model_name: str = "colbert-ir/colbertv2.0",
column: str = "text",
return_score="relevance",
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.torch = attempt_import_or_raise(
"torch"
) # import here for faster ops later
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
tokenizer, model = self._model
# Encode the query
query_encoding = tokenizer(query, return_tensors="pt")
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
scores = []
# Get score for each document
for document in docs:
document_encoding = tokenizer(
document, return_tensors="pt", truncation=True, max_length=512
)
document_embedding = model(**document_encoding).last_hidden_state
# Calculate MaxSim score
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
scores.append(score.item())
# replace the self.column column with the docs
combined_results = combined_results.drop(self.column)
combined_results = combined_results.append_column(
self.column, pa.array(docs, type=pa.string())
)
# add the scores
combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
@cached_property
def _model(self):
transformers = attempt_import_or_raise("transformers")
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
model = transformers.AutoModel.from_pretrained(self.model_name)
return tokenizer, model
def maxsim(self, query_embedding, document_embedding):
# Expand dimensions for broadcasting
# Query: [batch, length, size] -> [batch, query, 1, size]
# Document: [batch, length, size] -> [batch, 1, length, size]
expanded_query = query_embedding.unsqueeze(2)
expanded_doc = document_embedding.unsqueeze(1)
# Compute cosine similarity across the embedding dimension
sim_matrix = self.torch.nn.functional.cosine_similarity(
expanded_query, expanded_doc, dim=-1
)
# Take the maximum similarity for each query token (across all document tokens)
# sim_matrix shape: [batch_size, query_length, doc_length]
max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
# Average these maximum scores across all query tokens
avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
return avg_max_sim

View File

@@ -1,11 +1,15 @@
import typing
from functools import cached_property from functools import cached_property
from typing import Union from typing import Union
import pyarrow as pa import pyarrow as pa
from ..util import attempt_import_or_raise from ..util import safe_import
from .base import Reranker from .base import Reranker
if typing.TYPE_CHECKING:
import lancedb
class CrossEncoderReranker(Reranker): class CrossEncoderReranker(Reranker):
""" """
@@ -32,7 +36,7 @@ class CrossEncoderReranker(Reranker):
return_score="relevance", return_score="relevance",
): ):
super().__init__(return_score) super().__init__(return_score)
torch = attempt_import_or_raise("torch") torch = safe_import("torch")
self.model_name = model_name self.model_name = model_name
self.column = column self.column = column
self.device = device self.device = device
@@ -41,20 +45,20 @@ class CrossEncoderReranker(Reranker):
@cached_property @cached_property
def model(self): def model(self):
sbert = attempt_import_or_raise("sentence_transformers") sbert = safe_import("sentence_transformers")
cross_encoder = sbert.CrossEncoder(self.model_name) cross_encoder = sbert.CrossEncoder(self.model_name)
return cross_encoder return cross_encoder
def rerank_hybrid( def rerank_hybrid(
self, self,
query: str, query_builder: "lancedb.HybridQueryBuilder",
vector_results: pa.Table, vector_results: pa.Table,
fts_results: pa.Table, fts_results: pa.Table,
): ):
combined_results = self.merge_results(vector_results, fts_results) combined_results = self.merge_results(vector_results, fts_results)
passages = combined_results[self.column].to_pylist() passages = combined_results[self.column].to_pylist()
cross_inp = [[query, passage] for passage in passages] cross_inp = [[query_builder._query, passage] for passage in passages]
cross_scores = self.model.predict(cross_inp) cross_scores = self.model.predict(cross_inp)
combined_results = combined_results.append_column( combined_results = combined_results.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32()) "_relevance_score", pa.array(cross_scores, type=pa.float32())

View File

@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):
def rerank_hybrid( def rerank_hybrid(
self, self,
query: str, # noqa: F821 query_builder: "lancedb.HybridQueryBuilder", # noqa: F821
vector_results: pa.Table, vector_results: pa.Table,
fts_results: pa.Table, fts_results: pa.Table,
): ):

View File

@@ -1,104 +0,0 @@
import json
import os
from functools import cached_property
from typing import Optional
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import Reranker
class OpenaiReranker(Reranker):
"""
Reranks the results using the OpenAI API.
WARNING: This is a prompt based reranker that uses chat model that is
not a dedicated reranker API. This should be treated as experimental.
Parameters
----------
model_name : str, default "gpt-4-turbo-preview"
The name of the cross encoder model to use.
column : str, default "text"
The name of the column to use as input to the cross encoder model.
return_score : str, default "relevance"
options are "relevance" or "all". Only "relevance" is supported for now.
api_key : str, default None
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
"""
def __init__(
self,
model_name: str = "gpt-4-turbo-preview",
column: str = "text",
return_score="relevance",
api_key: Optional[str] = None,
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.api_key = api_key
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
docs = combined_results[self.column].to_pylist()
response = self._client.chat.completions.create(
model=self.model_name,
response_format={"type": "json_object"},
temperature=0,
messages=[
{
"role": "system",
"content": "You are an expert relevance ranker. Given a list of\
documents and a query, your job is to determine the relevance\
each document is for answering the query. Your output is JSON,\
which is a list of documents. Each document has two fields,\
content and relevance_score. relevance_score is from 0.0 to\
1.0 indicating the relevance of the text to the given query.\
Make sure to include all documents in the response.",
},
{"role": "user", "content": f"Query: {query} Docs: {docs}"},
],
)
results = json.loads(response.choices[0].message.content)["documents"]
docs, scores = list(
zip(*[(result["content"], result["relevance_score"]) for result in results])
) # tuples
# replace the self.column column with the docs
combined_results = combined_results.drop(self.column)
combined_results = combined_results.append_column(
self.column, pa.array(docs, type=pa.string())
)
# add the scores
combined_results = combined_results.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"OpenAI Reranker does not support score='all' yet"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
@cached_property
def _client(self):
openai = attempt_import_or_raise(
"openai"
) # TODO: force version or handle versions < 1.0
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
raise ValueError(
"OPENAI_API_KEY not set. Either set it in your environment or \
pass it as `api_key` argument to the CohereReranker."
)
return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)

View File

@@ -14,10 +14,7 @@
from __future__ import annotations from __future__ import annotations
import inspect import inspect
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import timedelta
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -36,23 +33,23 @@ from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query from .query import LanceQueryBuilder, Query
from .util import ( from .util import (
fs_from_uri, fs_from_uri,
inf_vector_column_query,
join_uri, join_uri,
safe_import_pandas, safe_import,
safe_import_polars,
value_to_sql, value_to_sql,
) )
from .utils.events import register_event from .utils.events import register_event
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import timedelta
import PIL import PIL
from lance.dataset import CleanupStats, ReaderLike from lance.dataset import CleanupStats, ReaderLike
from .db import LanceDBConnection from .db import LanceDBConnection
pd = safe_import_pandas() pd = safe_import("pandas")
pl = safe_import_polars() pl = safe_import("polars")
def _sanitize_data( def _sanitize_data(
@@ -178,18 +175,6 @@ class Table(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
Parameters
----------
filter: str, optional
A SQL where clause to filter the rows to count.
"""
raise NotImplementedError
def to_pandas(self) -> "pd.DataFrame": def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame. """Return the table as a pandas DataFrame.
@@ -313,7 +298,7 @@ class Table(ABC):
import lance import lance
dataset = lance.dataset("./images.lance") dataset = lance.dataset("/tmp/images.lance")
dataset.create_scalar_index("category") dataset.create_scalar_index("category")
""" """
raise NotImplementedError raise NotImplementedError
@@ -406,15 +391,13 @@ class Table(ABC):
2 3 y 2 3 y
3 4 z 3 4 z
""" """
on = [on] if isinstance(on, str) else list(on.iter())
return LanceMergeInsertBuilder(self, on) return LanceMergeInsertBuilder(self, on)
@abstractmethod @abstractmethod
def search( def search(
self, self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None, vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto", query_type: str = "auto",
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
@@ -434,7 +417,7 @@ class Table(ABC):
... ] ... ]
>>> table = db.create_table("my_table", data) >>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4] >>> query = [0.4, 1.4, 2.4]
>>> (table.search(query) >>> (table.search(query, vector_column_name="vector")
... .where("original_width > 1000", prefilter=True) ... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"]) ... .select(["caption", "original_width"])
... .limit(2) ... .limit(2)
@@ -453,19 +436,12 @@ class Table(ABC):
- If None then the select/where/limit clauses are applied to filter - If None then the select/where/limit clauses are applied to filter
the table the table
vector_column_name: str, optional vector_column_name: str
The name of the vector column to search. The name of the vector column to search.
*default "vector"*
The vector column needs to be a pyarrow fixed size list type
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
query_type: str query_type: str
*default "auto"*. *default "auto"*.
Acceptable types are: "vector", "fts", "hybrid", or "auto" Acceptable types are: "vector", "fts", or "auto"
- If "auto" then the query type is inferred from the query; - If "auto" then the query type is inferred from the query;
@@ -502,8 +478,8 @@ class Table(ABC):
self, self,
merge: LanceMergeInsertBuilder, merge: LanceMergeInsertBuilder,
new_data: DATA, new_data: DATA,
on_bad_vectors: str, *,
fill_value: float, schema: Optional[pa.Schema] = None,
): ):
pass pass
@@ -614,192 +590,24 @@ class Table(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def cleanup_old_versions(
self,
older_than: Optional[timedelta] = None,
*,
delete_unverified: bool = False,
) -> CleanupStats:
"""
Clean up old versions of the table, freeing disk space.
Note: This function is not available in LanceDb Cloud (since LanceDb
Cloud manages cleanup for you automatically)
Parameters
----------
older_than: timedelta, default None
The minimum age of the version to delete. If None, then this defaults
to two weeks.
delete_unverified: bool, default False
Because they may be part of an in-progress transaction, files newer
than 7 days old are not deleted by default. If you are sure that
there are no in-progress transactions, then you can set this to True
to delete all files older than `older_than`.
Returns
-------
CleanupStats
The stats of the cleanup operation, including how many bytes were
freed.
"""
@abstractmethod
def compact_files(self, *args, **kwargs):
"""
Run the compaction process on the table.
Note: This function is not available in LanceDb Cloud (since LanceDb
Cloud manages compaction for you automatically)
This can be run after making several small appends to optimize the table
for faster reads.
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
For most cases, the default should be fine.
"""
class _LanceDatasetRef(ABC):
@property
@abstractmethod
def dataset(self) -> LanceDataset:
pass
@property
@abstractmethod
def dataset_mut(self) -> LanceDataset:
pass
@dataclass
class _LanceLatestDatasetRef(_LanceDatasetRef):
"""Reference to the latest version of a LanceDataset."""
uri: str
read_consistency_interval: Optional[timedelta] = None
last_consistency_check: Optional[float] = None
_dataset: Optional[LanceDataset] = None
@property
def dataset(self) -> LanceDataset:
if not self._dataset:
self._dataset = lance.dataset(self.uri)
self.last_consistency_check = time.monotonic()
elif self.read_consistency_interval is not None:
now = time.monotonic()
diff = timedelta(seconds=now - self.last_consistency_check)
if (
self.last_consistency_check is None
or diff > self.read_consistency_interval
):
self._dataset = self._dataset.checkout_version(
self._dataset.latest_version
)
self.last_consistency_check = time.monotonic()
return self._dataset
@dataset.setter
def dataset(self, value: LanceDataset):
self._dataset = value
self.last_consistency_check = time.monotonic()
@property
def dataset_mut(self) -> LanceDataset:
return self.dataset
@dataclass
class _LanceTimeTravelRef(_LanceDatasetRef):
uri: str
version: int
_dataset: Optional[LanceDataset] = None
@property
def dataset(self) -> LanceDataset:
if not self._dataset:
self._dataset = lance.dataset(self.uri, version=self.version)
return self._dataset
@dataset.setter
def dataset(self, value: LanceDataset):
self._dataset = value
self.version = value.version
@property
def dataset_mut(self) -> LanceDataset:
raise ValueError(
"Cannot mutate table reference fixed at version "
f"{self.version}. Call checkout_latest() to get a mutable "
"table reference."
)
class LanceTable(Table): class LanceTable(Table):
""" """
A table in a LanceDB database. A table in a LanceDB database.
This can be opened in two modes: standard and time-travel.
Standard mode is the default. In this mode, the table is mutable and tracks
the latest version of the table. The level of read consistency is controlled
by the `read_consistency_interval` parameter on the connection.
Time-travel mode is activated by specifying a version number. In this mode,
the table is immutable and fixed to a specific version. This is useful for
querying historical versions of the table.
""" """
def __init__( def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
self,
connection: "LanceDBConnection",
name: str,
version: Optional[int] = None,
):
self._conn = connection self._conn = connection
self.name = name self.name = name
self._version = version
if version is not None: def _reset_dataset(self, version=None):
self._ref = _LanceTimeTravelRef( try:
uri=self._dataset_uri, if "_dataset" in self.__dict__:
version=version, del self.__dict__["_dataset"]
) self._version = version
else: except AttributeError:
self._ref = _LanceLatestDatasetRef( pass
uri=self._dataset_uri,
read_consistency_interval=connection.read_consistency_interval,
)
@classmethod
def open(cls, db, name, **kwargs):
tbl = cls(db, name, **kwargs)
fs, path = fs_from_uri(tbl._dataset_uri)
file_info = fs.get_file_info(path)
if file_info.type != pa.fs.FileType.Directory:
raise FileNotFoundError(
f"Table {name} does not exist."
f"Please first call db.create_table({name}, data)"
)
register_event("open_table")
return tbl
@property
def _dataset_uri(self) -> str:
return join_uri(self._conn.uri, f"{self.name}.lance")
@property
def _dataset(self) -> LanceDataset:
return self._ref.dataset
@property
def _dataset_mut(self) -> LanceDataset:
return self._ref.dataset_mut
def to_lance(self) -> LanceDataset:
"""Return the LanceDataset backing this table."""
return self._dataset
@property @property
def schema(self) -> pa.Schema: def schema(self) -> pa.Schema:
@@ -827,9 +635,6 @@ class LanceTable(Table):
keep writing to the dataset starting from an old version, then use keep writing to the dataset starting from an old version, then use
the `restore` function. the `restore` function.
Calling this method will set the table into time-travel mode. If you
wish to return to standard mode, call `checkout_latest`.
Parameters Parameters
---------- ----------
version : int version : int
@@ -854,13 +659,15 @@ class LanceTable(Table):
vector type vector type
0 [1.1, 0.9] vector 0 [1.1, 0.9] vector
""" """
max_ver = self._dataset.latest_version max_ver = max([v["version"] for v in self._dataset.versions()])
if version < 1 or version > max_ver: if version < 1 or version > max_ver:
raise ValueError(f"Invalid version {version}") raise ValueError(f"Invalid version {version}")
self._reset_dataset(version=version)
try: try:
ds = self._dataset.checkout_version(version) # Accessing the property updates the cached value
except IOError as e: _ = self._dataset
except Exception as e:
if "not found" in str(e): if "not found" in str(e):
raise ValueError( raise ValueError(
f"Version {version} no longer exists. Was it cleaned up?" f"Version {version} no longer exists. Was it cleaned up?"
@@ -868,27 +675,6 @@ class LanceTable(Table):
else: else:
raise e raise e
self._ref = _LanceTimeTravelRef(
uri=self._dataset_uri,
version=version,
)
# We've already loaded the version so we can populate it directly.
self._ref.dataset = ds
def checkout_latest(self):
"""Checkout the latest version of the table. This is an in-place operation.
The table will be set back into standard mode, and will track the latest
version of the table.
"""
self.checkout(self._dataset.latest_version)
ds = self._ref.dataset
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=self._conn.read_consistency_interval,
)
self._ref.dataset = ds
def restore(self, version: int = None): def restore(self, version: int = None):
"""Restore a version of the table. This is an in-place operation. """Restore a version of the table. This is an in-place operation.
@@ -923,7 +709,7 @@ class LanceTable(Table):
>>> len(table.list_versions()) >>> len(table.list_versions())
4 4
""" """
max_ver = self._dataset.latest_version max_ver = max([v["version"] for v in self._dataset.versions()])
if version is None: if version is None:
version = self.version version = self.version
elif version < 1 or version > max_ver: elif version < 1 or version > max_ver:
@@ -931,30 +717,29 @@ class LanceTable(Table):
else: else:
self.checkout(version) self.checkout(version)
ds = self._dataset if version == max_ver:
# no-op if restoring the latest version
return
# no-op if restoring the latest version self._dataset.restore()
if version != max_ver: self._reset_dataset()
ds.restore()
self._ref = _LanceLatestDatasetRef(
uri=self._dataset_uri,
read_consistency_interval=self._conn.read_consistency_interval,
)
self._ref.dataset = ds
def count_rows(self, filter: Optional[str] = None) -> int: def count_rows(self, filter: Optional[str] = None) -> int:
"""
Count the number of rows in the table.
Parameters
----------
filter: str, optional
A SQL where clause to filter the rows to count.
"""
return self._dataset.count_rows(filter) return self._dataset.count_rows(filter)
def __len__(self): def __len__(self):
return self.count_rows() return self.count_rows()
def __repr__(self) -> str: def __repr__(self) -> str:
val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"' return f"LanceTable({self.name})"
if isinstance(self._ref, _LanceTimeTravelRef):
val += f", version={self._ref.version}"
val += ")"
return val
def __str__(self) -> str: def __str__(self) -> str:
return self.__repr__() return self.__repr__()
@@ -1004,6 +789,10 @@ class LanceTable(Table):
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
) )
@property
def _dataset_uri(self) -> str:
return join_uri(self._conn.uri, f"{self.name}.lance")
def create_index( def create_index(
self, self,
metric="L2", metric="L2",
@@ -1015,7 +804,7 @@ class LanceTable(Table):
index_cache_size: Optional[int] = None, index_cache_size: Optional[int] = None,
): ):
"""Create an index on the table.""" """Create an index on the table."""
self._dataset_mut.create_index( self._dataset.create_index(
column=vector_column_name, column=vector_column_name,
index_type="IVF_PQ", index_type="IVF_PQ",
metric=metric, metric=metric,
@@ -1025,12 +814,11 @@ class LanceTable(Table):
accelerator=accelerator, accelerator=accelerator,
index_cache_size=index_cache_size, index_cache_size=index_cache_size,
) )
self._reset_dataset()
register_event("create_index") register_event("create_index")
def create_scalar_index(self, column: str, *, replace: bool = True): def create_scalar_index(self, column: str, *, replace: bool = True):
self._dataset_mut.create_scalar_index( self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
column, index_type="BTREE", replace=replace
)
def create_fts_index( def create_fts_index(
self, self,
@@ -1073,6 +861,14 @@ class LanceTable(Table):
def _get_fts_index_path(self): def _get_fts_index_path(self):
return join_uri(self._dataset_uri, "_indices", "tantivy") return join_uri(self._dataset_uri, "_indices", "tantivy")
@cached_property
def _dataset(self) -> LanceDataset:
return lance.dataset(self._dataset_uri, version=self._version)
def to_lance(self) -> LanceDataset:
"""Return the LanceDataset backing this table."""
return self._dataset
def add( def add(
self, self,
data: DATA, data: DATA,
@@ -1111,11 +907,8 @@ class LanceTable(Table):
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
) )
# Access the dataset_mut property to ensure that the dataset is mutable. lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
self._ref.dataset_mut self._reset_dataset()
self._ref.dataset = lance.write_dataset(
data, self._dataset_uri, schema=self.schema, mode=mode
)
register_event("add") register_event("add")
def merge( def merge(
@@ -1176,9 +969,10 @@ class LanceTable(Table):
other_table = other_table.to_lance() other_table = other_table.to_lance()
if isinstance(other_table, LanceDataset): if isinstance(other_table, LanceDataset):
other_table = other_table.to_table() other_table = other_table.to_table()
self._ref.dataset = self._dataset_mut.merge( self._dataset.merge(
other_table, left_on=left_on, right_on=right_on, schema=schema other_table, left_on=left_on, right_on=right_on, schema=schema
) )
self._reset_dataset()
register_event("merge") register_event("merge")
@cached_property @cached_property
@@ -1199,7 +993,7 @@ class LanceTable(Table):
def search( def search(
self, self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None, query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None, vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto", query_type: str = "auto",
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
@@ -1217,7 +1011,7 @@ class LanceTable(Table):
... ] ... ]
>>> table = db.create_table("my_table", data) >>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4] >>> query = [0.4, 1.4, 2.4]
>>> (table.search(query) >>> (table.search(query, vector_column_name="vector")
... .where("original_width > 1000", prefilter=True) ... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width"]) ... .select(["caption", "original_width"])
... .limit(2) ... .limit(2)
@@ -1236,17 +1030,8 @@ class LanceTable(Table):
- If None then the select/[where][sql]/limit clauses are applied - If None then the select/[where][sql]/limit clauses are applied
to filter the table to filter the table
vector_column_name: str, optional vector_column_name: str, default "vector"
The name of the vector column to search. The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
*default "vector"*
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
query_type: str, default "auto" query_type: str, default "auto"
"vector", "fts", or "auto" "vector", "fts", or "auto"
If "auto" then the query type is inferred from the query; If "auto" then the query type is inferred from the query;
@@ -1264,8 +1049,6 @@ class LanceTable(Table):
and also the "_distance" column which is the distance between the query and also the "_distance" column which is the distance between the query
vector and the returned vector. vector and the returned vector.
""" """
if vector_column_name is None and query is not None:
vector_column_name = inf_vector_column_query(self.schema)
register_event("search_table") register_event("search_table")
return LanceQueryBuilder.create( return LanceQueryBuilder.create(
self, query, query_type, vector_column_name=vector_column_name self, query, query_type, vector_column_name=vector_column_name
@@ -1392,8 +1175,22 @@ class LanceTable(Table):
register_event("create_table") register_event("create_table")
return new_table return new_table
@classmethod
def open(cls, db, name):
tbl = cls(db, name)
fs, path = fs_from_uri(tbl._dataset_uri)
file_info = fs.get_file_info(path)
if file_info.type != pa.fs.FileType.Directory:
raise FileNotFoundError(
f"Table {name} does not exist."
f"Please first call db.create_table({name}, data)"
)
register_event("open_table")
return tbl
def delete(self, where: str): def delete(self, where: str):
self._dataset_mut.delete(where) self._dataset.delete(where)
def update( def update(
self, self,
@@ -1447,12 +1244,12 @@ class LanceTable(Table):
if values is not None: if values is not None:
values_sql = {k: value_to_sql(v) for k, v in values.items()} values_sql = {k: value_to_sql(v) for k, v in values.items()}
self._dataset_mut.update(values_sql, where) self.to_lance().update(values_sql, where)
self._reset_dataset()
register_event("update") register_event("update")
def _execute_query(self, query: Query) -> pa.Table: def _execute_query(self, query: Query) -> pa.Table:
ds = self.to_lance() ds = self.to_lance()
return ds.to_table( return ds.to_table(
columns=query.columns, columns=query.columns,
filter=query.filter, filter=query.filter,
@@ -1468,30 +1265,17 @@ class LanceTable(Table):
with_row_id=query.with_row_id, with_row_id=query.with_row_id,
) )
def _do_merge( def _do_merge(self, merge: LanceMergeInsertBuilder, new_data: DATA, *, schema=None):
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
new_data = _sanitize_data(
new_data,
self.schema,
metadata=self.schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
ds = self.to_lance() ds = self.to_lance()
builder = ds.merge_insert(merge._on) builder = ds.merge_insert(merge._on)
if merge._when_matched_update_all: if merge._when_matched_update_all:
builder.when_matched_update_all(merge._when_matched_update_all_condition) builder.when_matched_update_all()
if merge._when_not_matched_insert_all: if merge._when_not_matched_insert_all:
builder.when_not_matched_insert_all() builder.when_not_matched_insert_all()
if merge._when_not_matched_by_source_delete: if merge._when_not_matched_by_source_delete:
cond = merge._when_not_matched_by_source_condition cond = merge._when_not_matched_by_source_condition
builder.when_not_matched_by_source_delete(cond) builder.when_not_matched_by_source_delete(cond)
builder.execute(new_data) builder.execute(new_data, schema=schema)
def cleanup_old_versions( def cleanup_old_versions(
self, self,
@@ -1530,9 +1314,8 @@ class LanceTable(Table):
This can be run after making several small appends to optimize the table This can be run after making several small appends to optimize the table
for faster reads. for faster reads.
Arguments are passed onto `lance.dataset.DatasetOptimizer.compact_files`. Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
(see Lance documentation for more details) For most cases, the default For most cases, the default should be fine.
should be fine.
""" """
return self.to_lance().optimize.compact_files(*args, **kwargs) return self.to_lance().optimize.compact_files(*args, **kwargs)

View File

@@ -20,7 +20,6 @@ from typing import Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
import pyarrow as pa
import pyarrow.fs as pa_fs import pyarrow.fs as pa_fs
@@ -116,7 +115,7 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
return "/".join([p.rstrip("/") for p in [base, *parts]]) return "/".join([p.rstrip("/") for p in [base, *parts]])
def attempt_import_or_raise(module: str, mitigation=None): def safe_import(module: str, mitigation=None):
""" """
Import the specified module. If the module is not installed, Import the specified module. If the module is not installed,
raise an ImportError with a helpful message. raise an ImportError with a helpful message.
@@ -135,62 +134,6 @@ def attempt_import_or_raise(module: str, mitigation=None):
raise ImportError(f"Please install {mitigation or module}") raise ImportError(f"Please install {mitigation or module}")
def safe_import_pandas():
try:
import pandas as pd
return pd
except ImportError:
return None
def safe_import_polars():
try:
import polars as pl
return pl
except ImportError:
return None
def inf_vector_column_query(schema: pa.Schema) -> str:
"""
Get the vector column name
Parameters
----------
schema : pa.Schema
The schema of the vector column.
Returns
-------
str: the vector column name.
"""
vector_col_name = ""
vector_col_count = 0
for field_name in schema.names:
field = schema.field(field_name)
if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
field.type.value_type
):
vector_col_count += 1
if vector_col_count > 1:
raise ValueError(
"Schema has more than one vector column. "
"Please specify the vector column name "
"for vector search"
)
break
elif vector_col_count == 1:
vector_col_name = field_name
if vector_col_count == 0:
raise ValueError(
"There is no vector column in the data. "
"Please specify the vector column name for vector search"
)
return vector_col_name
@singledispatch @singledispatch
def value_to_sql(value): def value_to_sql(value):
raise NotImplementedError("SQL conversion is not implemented for this type") raise NotImplementedError("SQL conversion is not implemented for this type")

View File

@@ -1,9 +1,9 @@
[project] [project]
name = "lancedb" name = "lancedb"
version = "0.5.5" version = "0.5.1"
dependencies = [ dependencies = [
"deprecation", "deprecation",
"pylance==0.9.15", "pylance==0.9.11",
"ratelimiter~=1.0", "ratelimiter~=1.0",
"retry>=0.9.2", "retry>=0.9.2",
"tqdm>=4.27.0", "tqdm>=4.27.0",
@@ -48,7 +48,7 @@ classifiers = [
repository = "https://github.com/lancedb/lancedb" repository = "https://github.com/lancedb/lancedb"
[project.optional-dependencies] [project.optional-dependencies]
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"] tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
dev = ["ruff", "pre-commit"] dev = ["ruff", "pre-commit"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"] clip = ["torch", "pillow", "open-clip"]

View File

@@ -88,7 +88,6 @@ def test_embedding_function(tmp_path):
assert np.allclose(actual, expected) assert np.allclose(actual, expected)
@pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path): def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model): def _get_schema_from_model(model):
class Schema(LanceModel): class Schema(LanceModel):

View File

@@ -23,6 +23,11 @@ import lancedb
from lancedb.embeddings import get_registry from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
# These are integration tests for embedding functions. # These are integration tests for embedding functions.
# They are slow because they require downloading models # They are slow because they require downloading models
# or connection to external api # or connection to external api
@@ -69,14 +74,10 @@ def test_basic_text_embeddings(alias, tmp_path):
) )
query = "greetings" query = "greetings"
actual = ( actual = table.search(query).limit(1).to_pydantic(Words)[0]
table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
)
vec = func.compute_query_embeddings(query)[0] vec = func.compute_query_embeddings(query)[0]
expected = ( expected = table.search(vec).limit(1).to_pydantic(Words)[0]
table.search(vec, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
)
assert actual.text == expected.text assert actual.text == expected.text
assert actual.text == "hello world" assert actual.text == "hello world"
assert not np.allclose(actual.vector, actual.vector2) assert not np.allclose(actual.vector, actual.vector2)
@@ -120,11 +121,7 @@ def test_openclip(tmp_path):
) )
# text search # text search
actual = ( actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
table.search("man's best friend", vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog" assert actual.label == "dog"
frombytes = ( frombytes = (
table.search("man's best friend", vector_column_name="vec_from_bytes") table.search("man's best friend", vector_column_name="vec_from_bytes")
@@ -138,11 +135,7 @@ def test_openclip(tmp_path):
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg" query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
image_bytes = requests.get(query_image_uri).content image_bytes = requests.get(query_image_uri).content
query_image = Image.open(io.BytesIO(image_bytes)) query_image = Image.open(io.BytesIO(image_bytes))
actual = ( actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
table.search(query_image, vector_column_name="vector")
.limit(1)
.to_pydantic(Images)[0]
)
assert actual.label == "dog" assert actual.label == "dog"
other = ( other = (
table.search(query_image, vector_column_name="vec_from_bytes") table.search(query_image, vector_column_name="vec_from_bytes")
@@ -217,13 +210,6 @@ def test_gemini_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world" assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
try:
if importlib.util.find_spec("mlx.core") is not None:
_mlx = True
except ImportError:
_mlx = None
@pytest.mark.skipif( @pytest.mark.skipif(
_mlx is None, _mlx is None,
reason="mlx tests only required for apple users.", reason="mlx tests only required for apple users.",
@@ -280,49 +266,3 @@ def test_bedrock_embedding(tmp_path):
tbl.add(df) tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims() assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_embedding(tmp_path):
def _get_table(model):
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
return tbl
model = get_registry().get("openai").create(max_retries=0)
tbl = _get_table(model)
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large")
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
model = (
get_registry()
.get("openai")
.create(max_retries=0, name="text-embedding-3-large", dim=1024)
)
tbl = _get_table(model)
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"

View File

@@ -29,14 +29,10 @@ class FakeLanceDBClient:
def post(self, path: str): def post(self, path: str):
pass pass
def mount_retry_adapter_for_table(self, table_name: str):
pass
def test_remote_db(): def test_remote_db():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake") conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient()) setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"] table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas() table.search([1.0, 2.0]).to_pandas()

View File

@@ -7,12 +7,7 @@ import lancedb
from lancedb.conftest import MockTextEmbeddingFunction # noqa from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import ( from lancedb.rerankers import CohereReranker, CrossEncoderReranker
CohereReranker,
ColbertReranker,
CrossEncoderReranker,
OpenaiReranker,
)
from lancedb.table import LanceTable from lancedb.table import LanceTable
@@ -80,6 +75,7 @@ def get_test_table(tmp_path):
return table, MyTable return table, MyTable
## These tests are pretty loose, we should also check for correctness
def test_linear_combination(tmp_path): def test_linear_combination(tmp_path):
table, schema = get_test_table(tmp_path) table, schema = get_test_table(tmp_path)
# The default reranker # The default reranker
@@ -99,19 +95,14 @@ def test_linear_combination(tmp_path):
assert result1 == result3 # 2 & 3 should be the same as they use score as score assert result1 == result3 # 2 & 3 should be the same as they use score as score
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = ( result = (
table.search((query_vector, query)) table.search("Our father who art in heaven", query_type="hybrid")
.limit(30) .limit(50)
.rerank(normalize="score") .rerank(normalize="score")
.to_arrow() .to_arrow()
) )
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker " "The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should " "represents the relevance of the result to the query & should "
"be descending." "be descending."
) )
@@ -131,24 +122,19 @@ def test_cohere_reranker(tmp_path):
) )
result2 = ( result2 = (
table.search("Our father who art in heaven", query_type="hybrid") table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=CohereReranker()) .rerank(normalize="rank", reranker=CohereReranker())
.to_pydantic(schema) .to_pydantic(schema)
) )
assert result1 == result2 assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = ( result = (
table.search((query_vector, query)) table.search("Our father who art in heaven", query_type="hybrid")
.limit(30) .limit(50)
.rerank(reranker=CohereReranker()) .rerank(reranker=CohereReranker())
.to_arrow() .to_arrow()
) )
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker " "The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should " "represents the relevance of the result to the query & should "
"be descending." "be descending."
) )
@@ -164,96 +150,19 @@ def test_cross_encoder_reranker(tmp_path):
) )
result2 = ( result2 = (
table.search("Our father who art in heaven", query_type="hybrid") table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=CrossEncoderReranker()) .rerank(normalize="rank", reranker=CrossEncoderReranker())
.to_pydantic(schema) .to_pydantic(schema)
) )
assert result1 == result2 assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = ( result = (
table.search((query_vector, query), query_type="hybrid") table.search("Our father who art in heaven", query_type="hybrid")
.limit(30) .limit(50)
.rerank(reranker=CrossEncoderReranker()) .rerank(reranker=CrossEncoderReranker())
.to_arrow() .to_arrow()
) )
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), ( assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker " "The _score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=ColbertReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=ColbertReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=ColbertReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=OpenaiReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=OpenaiReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=OpenaiReranker())
.to_arrow()
)
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should " "represents the relevance of the result to the query & should "
"be descending." "be descending."
) )

View File

@@ -12,10 +12,8 @@
# limitations under the License. # limitations under the License.
import functools import functools
from copy import copy
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from pathlib import Path from pathlib import Path
from time import sleep
from typing import List from typing import List
from unittest.mock import PropertyMock, patch from unittest.mock import PropertyMock, patch
@@ -27,7 +25,6 @@ import pyarrow as pa
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
@@ -38,7 +35,6 @@ from lancedb.table import LanceTable
class MockDB: class MockDB:
def __init__(self, uri: Path): def __init__(self, uri: Path):
self.uri = uri self.uri = uri
self.read_consistency_interval = None
@functools.cached_property @functools.cached_property
def is_managed_remote(self) -> bool: def is_managed_remote(self) -> bool:
@@ -271,38 +267,39 @@ def test_versioning(db):
def test_create_index_method(): def test_create_index_method():
with patch.object( with patch.object(LanceTable, "_reset_dataset", return_value=None):
LanceTable, "_dataset_mut", new_callable=PropertyMock with patch.object(
) as mock_dataset: LanceTable, "_dataset", new_callable=PropertyMock
# Setup mock responses ) as mock_dataset:
mock_dataset.return_value.create_index.return_value = None # Setup mock responses
mock_dataset.return_value.create_index.return_value = None
# Create a LanceTable object # Create a LanceTable object
connection = LanceDBConnection(uri="mock.uri") connection = LanceDBConnection(uri="mock.uri")
table = LanceTable(connection, "test_table") table = LanceTable(connection, "test_table")
# Call the create_index method # Call the create_index method
table.create_index( table.create_index(
metric="L2", metric="L2",
num_partitions=256, num_partitions=256,
num_sub_vectors=96, num_sub_vectors=96,
vector_column_name="vector", vector_column_name="vector",
replace=True, replace=True,
index_cache_size=256, index_cache_size=256,
) )
# Check that the _dataset.create_index method was called # Check that the _dataset.create_index method was called
# with the right parameters # with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with( mock_dataset.return_value.create_index.assert_called_once_with(
column="vector", column="vector",
index_type="IVF_PQ", index_type="IVF_PQ",
metric="L2", metric="L2",
num_partitions=256, num_partitions=256,
num_sub_vectors=96, num_sub_vectors=96,
replace=True, replace=True,
accelerator=None, accelerator=None,
index_cache_size=256, index_cache_size=256,
) )
def test_add_with_nans(db): def test_add_with_nans(db):
@@ -513,15 +510,8 @@ def test_merge_insert(db):
).when_matched_update_all().when_not_matched_insert_all().execute(new_data) ).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]}) expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
assert table.to_arrow().sort_by("a") == expected # These `sort_by` calls can be removed once lance#1892
# is merged (it fixes the ordering)
table.restore(version)
# conditional update
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
new_data
)
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
assert table.to_arrow().sort_by("a") == expected assert table.to_arrow().sort_by("a") == expected
table.restore(version) table.restore(version)
@@ -710,59 +700,6 @@ def test_empty_query(db):
assert len(df) == 100 assert len(df) == 100
def test_search_with_schema_inf_single_vector(db):
class MyTable(LanceModel):
text: str
vector_col: Vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector_col": v1, "text": "foo"},
{"vector_col": v2, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
result2 = table.search(q).limit(1).to_pandas()
assert result1["text"].iloc[0] == result2["text"].iloc[0]
def test_search_with_schema_inf_multiple_vector(db):
class MyTable(LanceModel):
text: str
vector1: Vector(10)
vector2: Vector(10)
table = LanceTable.create(
db,
"my_table",
schema=MyTable,
)
v1 = np.random.randn(10)
v2 = np.random.randn(10)
data = [
{"vector1": v1, "vector2": v2, "text": "foo"},
{"vector1": v2, "vector2": v1, "text": "bar"},
]
df = pd.DataFrame(data)
table.add(df)
q = np.random.randn(10)
with pytest.raises(ValueError):
table.search(q).limit(1).to_pandas()
def test_compact_cleanup(db): def test_compact_cleanup(db):
table = LanceTable.create( table = LanceTable.create(
db, db,
@@ -855,48 +792,3 @@ def test_hybrid_search(db):
"Our father who art in heaven", query_type="hybrid" "Our father who art in heaven", query_type="hybrid"
).to_pydantic(MyTable) ).to_pydantic(MyTable)
assert result1 == result3 assert result1 == result3
@pytest.mark.parametrize(
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
)
def test_consistency(tmp_path, consistency_interval):
db = lancedb.connect(tmp_path)
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
table2 = db2.open_table("my_table")
assert table2.version == table.version
table.add([{"id": 1}])
if consistency_interval is None:
assert table2.version == table.version - 1
table2.checkout_latest()
assert table2.version == table.version
elif consistency_interval == timedelta(seconds=0):
assert table2.version == table.version
else:
# (consistency_interval == timedelta(seconds=0.1)
assert table2.version == table.version - 1
sleep(0.1)
assert table2.version == table.version
def test_restore_consistency(tmp_path):
db = lancedb.connect(tmp_path)
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
table2 = db2.open_table("my_table")
assert table2.version == table.version
# If we call checkout, it should lose consistency
table_fixed = copy(table2)
table_fixed.checkout(table.version)
# But if we call checkout_latest, it should be consistent again
table_ref_latest = copy(table_fixed)
table_ref_latest.checkout_latest()
table.add([{"id": 2}])
assert table_fixed.version == table.version - 1
assert table_ref_latest.version == table.version

View File

@@ -1,12 +1,9 @@
[package] [package]
name = "vectordb-node" name = "vectordb-node"
version = "0.4.10" version = "0.4.7"
description = "Serverless, low-latency vector database for AI applications" description = "Serverless, low-latency vector database for AI applications"
license.workspace = true license = "Apache-2.0"
edition.workspace = true edition = "2018"
repository.workspace = true
keywords.workspace = true
categories.workspace = true
exclude = ["index.node"] exclude = ["index.node"]
[lib] [lib]
@@ -31,6 +28,3 @@ object_store = { workspace = true, features = ["aws"] }
snafu = { workspace = true } snafu = { workspace = true }
async-trait = "0" async-trait = "0"
env_logger = "0" env_logger = "0"
# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }

View File

@@ -22,7 +22,7 @@ use arrow_schema::SchemaRef;
use crate::error::Result; use crate::error::Result;
pub fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> { pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
let mut batches: Vec<RecordBatch> = Vec::new(); let mut batches: Vec<RecordBatch> = Vec::new();
let file_reader = FileReader::try_new(Cursor::new(slice), None)?; let file_reader = FileReader::try_new(Cursor::new(slice), None)?;
let schema = file_reader.schema(); let schema = file_reader.schema();
@@ -33,7 +33,7 @@ pub fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, S
Ok((batches, schema)) Ok((batches, schema))
} }
pub fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> { pub(crate) fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
if batches.is_empty() { if batches.is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }

View File

@@ -17,7 +17,10 @@ use neon::types::buffer::TypedArray;
use crate::error::ResultExt; use crate::error::ResultExt;
pub fn vec_str_to_array<'a, C: Context<'a>>(vec: &[String], cx: &mut C) -> JsResult<'a, JsArray> { pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
vec: &Vec<String>,
cx: &mut C,
) -> JsResult<'a, JsArray> {
let a = JsArray::new(cx, vec.len() as u32); let a = JsArray::new(cx, vec.len() as u32);
for (i, s) in vec.iter().enumerate() { for (i, s) in vec.iter().enumerate() {
let v = cx.string(s); let v = cx.string(s);
@@ -26,7 +29,7 @@ pub fn vec_str_to_array<'a, C: Context<'a>>(vec: &[String], cx: &mut C) -> JsRes
Ok(a) Ok(a)
} }
pub fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> { pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
let mut query_vec: Vec<f32> = Vec::new(); let mut query_vec: Vec<f32> = Vec::new();
for i in 0..array.len(cx) { for i in 0..array.len(cx) {
let entry: Handle<JsNumber> = array.get(cx, i).unwrap(); let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
@@ -36,7 +39,7 @@ pub fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
} }
// Creates a new JsBuffer from a rust buffer with a special logic for electron // Creates a new JsBuffer from a rust buffer with a special logic for electron
pub fn new_js_buffer<'a>( pub(crate) fn new_js_buffer<'a>(
buffer: Vec<u8>, buffer: Vec<u8>,
cx: &mut TaskContext<'a>, cx: &mut TaskContext<'a>,
is_electron: bool, is_electron: bool,

View File

@@ -18,6 +18,7 @@ use neon::prelude::NeonResult;
use snafu::Snafu; use snafu::Snafu;
#[derive(Debug, Snafu)] #[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum Error { pub enum Error {
#[snafu(display("column '{name}' is missing"))] #[snafu(display("column '{name}' is missing"))]
MissingColumn { name: String }, MissingColumn { name: String },

View File

@@ -21,7 +21,7 @@ use neon::{
use crate::{error::ResultExt, runtime, table::JsTable}; use crate::{error::ResultExt, runtime, table::JsTable};
use vectordb::Table; use vectordb::Table;
pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let column = cx.argument::<JsString>(0)?.value(&mut cx); let column = cx.argument::<JsString>(0)?.value(&mut cx);
let replace = cx.argument::<JsBoolean>(1)?.value(&mut cx); let replace = cx.argument::<JsBoolean>(1)?.value(&mut cx);

View File

@@ -24,7 +24,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
use crate::runtime; use crate::runtime;
use crate::table::JsTable; use crate::table::JsTable;
pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let index_params = cx.argument::<JsObject>(0)?; let index_params = cx.argument::<JsObject>(0)?;

View File

@@ -260,7 +260,6 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("tableCountRows", JsTable::js_count_rows)?; cx.export_function("tableCountRows", JsTable::js_count_rows)?;
cx.export_function("tableDelete", JsTable::js_delete)?; cx.export_function("tableDelete", JsTable::js_delete)?;
cx.export_function("tableUpdate", JsTable::js_update)?; cx.export_function("tableUpdate", JsTable::js_update)?;
cx.export_function("tableMergeInsert", JsTable::js_merge_insert)?;
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?; cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
cx.export_function("tableCompactFiles", JsTable::js_compact)?; cx.export_function("tableCompactFiles", JsTable::js_compact)?;
cx.export_function("tableListIndices", JsTable::js_list_indices)?; cx.export_function("tableListIndices", JsTable::js_list_indices)?;

View File

@@ -13,7 +13,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
use crate::table::JsTable; use crate::table::JsTable;
use crate::{convert, runtime}; use crate::{convert, runtime};
pub struct JsQuery {} pub(crate) struct JsQuery {}
impl JsQuery { impl JsQuery {
pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult<JsPromise> {

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::ops::Deref;
use arrow_array::{RecordBatch, RecordBatchIterator}; use arrow_array::{RecordBatch, RecordBatchIterator};
use lance::dataset::optimize::CompactionOptions; use lance::dataset::optimize::CompactionOptions;
use lance::dataset::{WriteMode, WriteParams}; use lance::dataset::{WriteMode, WriteParams};
@@ -28,7 +26,7 @@ use vectordb::TableRef;
use crate::error::ResultExt; use crate::error::ResultExt;
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase}; use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
pub struct JsTable { pub(crate) struct JsTable {
pub table: TableRef, pub table: TableRef,
} }
@@ -36,7 +34,7 @@ impl Finalize for JsTable {}
impl From<TableRef> for JsTable { impl From<TableRef> for JsTable {
fn from(table: TableRef) -> Self { fn from(table: TableRef) -> Self {
Self { table } JsTable { table }
} }
} }
@@ -85,14 +83,14 @@ impl JsTable {
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
let table = table_rst.or_throw(&mut cx)?; let table = table_rst.or_throw(&mut cx)?;
Ok(cx.boxed(Self::from(table))) Ok(cx.boxed(JsTable::from(table)))
}); });
}); });
Ok(promise) Ok(promise)
} }
pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let buffer = cx.argument::<JsBuffer>(0)?; let buffer = cx.argument::<JsBuffer>(0)?;
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx); let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
let (batches, schema) = let (batches, schema) =
@@ -125,34 +123,21 @@ impl JsTable {
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
add_result.or_throw(&mut cx)?; add_result.or_throw(&mut cx)?;
Ok(cx.boxed(Self::from(table))) Ok(cx.boxed(JsTable::from(table)))
}); });
}); });
Ok(promise) Ok(promise)
} }
pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let filter = cx
.argument_opt(0)
.and_then(|filt| {
if filt.is_a::<JsUndefined, _>(&mut cx) || filt.is_a::<JsNull, _>(&mut cx) {
None
} else {
Some(
filt.downcast_or_throw::<JsString, _>(&mut cx)
.map(|js_filt| js_filt.deref().value(&mut cx)),
)
}
})
.transpose()?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let channel = cx.channel(); let channel = cx.channel();
let table = js_table.table.clone(); let table = js_table.table.clone();
rt.spawn(async move { rt.spawn(async move {
let num_rows_result = table.count_rows(filter).await; let num_rows_result = table.count_rows().await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
let num_rows = num_rows_result.or_throw(&mut cx)?; let num_rows = num_rows_result.or_throw(&mut cx)?;
@@ -163,7 +148,7 @@ impl JsTable {
} }
pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let predicate = cx.argument::<JsString>(0)?.value(&mut cx); let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -175,67 +160,14 @@ impl JsTable {
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
delete_result.or_throw(&mut cx)?; delete_result.or_throw(&mut cx)?;
Ok(cx.boxed(Self::from(table))) Ok(cx.boxed(JsTable::from(table)))
})
});
Ok(promise)
}
pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise();
let channel = cx.channel();
let table = js_table.table.clone();
let key = cx.argument::<JsString>(0)?.value(&mut cx);
let mut builder = table.merge_insert(&[&key]);
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
let filter = cx.argument_opt(2).unwrap();
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_matched_update_all(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_matched_update_all(Some(filter));
}
}
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
builder.when_not_matched_insert_all();
}
if cx.argument::<JsBoolean>(4)?.value(&mut cx) {
let filter = cx.argument_opt(5).unwrap();
if filter.is_a::<JsNull, _>(&mut cx) {
builder.when_not_matched_by_source_delete(None);
} else {
let filter = filter
.downcast_or_throw::<JsString, _>(&mut cx)?
.deref()
.value(&mut cx);
builder.when_not_matched_by_source_delete(Some(filter));
}
}
let buffer = cx.argument::<JsBuffer>(6)?;
let (batches, schema) =
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
rt.spawn(async move {
let new_data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let merge_insert_result = builder.execute(Box::new(new_data)).await;
deferred.settle_with(&channel, move |mut cx| {
merge_insert_result.or_throw(&mut cx)?;
Ok(cx.boxed(Self::from(table)))
}) })
}); });
Ok(promise) Ok(promise)
} }
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let table = js_table.table.clone(); let table = js_table.table.clone();
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
@@ -294,7 +226,7 @@ impl JsTable {
.await; .await;
deferred.settle_with(&channel, move |mut cx| { deferred.settle_with(&channel, move |mut cx| {
update_result.or_throw(&mut cx)?; update_result.or_throw(&mut cx)?;
Ok(cx.boxed(Self::from(table))) Ok(cx.boxed(JsTable::from(table)))
}) })
}); });
@@ -302,7 +234,7 @@ impl JsTable {
} }
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let table = js_table.table.clone(); let table = js_table.table.clone();
@@ -340,7 +272,7 @@ impl JsTable {
let old_versions = cx.number(prune_stats.old_versions as f64); let old_versions = cx.number(prune_stats.old_versions as f64);
output_metrics.set(&mut cx, "oldVersions", old_versions)?; output_metrics.set(&mut cx, "oldVersions", old_versions)?;
let output_table = cx.boxed(Self::from(table)); let output_table = cx.boxed(JsTable::from(table));
let output = JsObject::new(&mut cx); let output = JsObject::new(&mut cx);
output.set(&mut cx, "metrics", output_metrics)?; output.set(&mut cx, "metrics", output_metrics)?;
@@ -353,7 +285,7 @@ impl JsTable {
} }
pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let table = js_table.table.clone(); let table = js_table.table.clone();
@@ -412,7 +344,7 @@ impl JsTable {
let files_added = cx.number(stats.files_added as f64); let files_added = cx.number(stats.files_added as f64);
output_metrics.set(&mut cx, "filesAdded", files_added)?; output_metrics.set(&mut cx, "filesAdded", files_added)?;
let output_table = cx.boxed(Self::from(table)); let output_table = cx.boxed(JsTable::from(table));
let output = JsObject::new(&mut cx); let output = JsObject::new(&mut cx);
output.set(&mut cx, "metrics", output_metrics)?; output.set(&mut cx, "metrics", output_metrics)?;
@@ -425,7 +357,7 @@ impl JsTable {
} }
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx); // let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -464,7 +396,7 @@ impl JsTable {
} }
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx); let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
@@ -512,7 +444,7 @@ impl JsTable {
} }
pub(crate) fn js_schema(mut cx: FunctionContext) -> JsResult<JsPromise> { pub(crate) fn js_schema(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?; let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let rt = runtime(&mut cx)?; let rt = runtime(&mut cx)?;
let (deferred, promise) = cx.promise(); let (deferred, promise) = cx.promise();
let channel = cx.channel(); let channel = cx.channel();

View File

@@ -1,12 +1,12 @@
[package] [package]
name = "vectordb" name = "vectordb"
version = "0.4.10" version = "0.4.7"
edition.workspace = true edition = "2021"
description = "LanceDB: A serverless, low-latency vector database for AI applications" description = "LanceDB: A serverless, low-latency vector database for AI applications"
license.workspace = true license = "Apache-2.0"
repository.workspace = true repository = "https://github.com/lancedb/lancedb"
keywords.workspace = true keywords = ["lancedb", "lance", "database", "search"]
categories.workspace = true categories = ["database-implementations"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]

View File

@@ -188,12 +188,12 @@ impl Database {
/// # Returns /// # Returns
/// ///
/// * A [Database] object. /// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Self> { pub async fn connect(uri: &str) -> Result<Database> {
let options = ConnectOptions::new(uri); let options = ConnectOptions::new(uri);
Self::connect_with_options(&options).await Self::connect_with_options(&options).await
} }
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> { pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
let uri = &options.uri; let uri = &options.uri;
let parse_res = url::Url::parse(uri); let parse_res = url::Url::parse(uri);
@@ -276,7 +276,7 @@ impl Database {
None => None, None => None,
}; };
Ok(Self { Ok(Database {
uri: table_base_uri, uri: table_base_uri,
query_string, query_string,
base_path, base_path,
@@ -288,7 +288,7 @@ impl Database {
} }
} }
async fn open_path(path: &str) -> Result<Self> { async fn open_path(path: &str) -> Result<Database> {
let (object_store, base_path) = ObjectStore::from_uri(path).await?; let (object_store, base_path) = ObjectStore::from_uri(path).await?;
if object_store.is_local() { if object_store.is_local() {
Self::try_create_dir(path).context(CreateDirSnafu { path })?; Self::try_create_dir(path).context(CreateDirSnafu { path })?;
@@ -422,11 +422,13 @@ mod tests {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap(); let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap();
let mut relative_anacestors = vec![];
let current_dir = std::env::current_dir().unwrap(); let current_dir = std::env::current_dir().unwrap();
let ancestors = current_dir.ancestors(); let mut ancestors = current_dir.ancestors();
let relative_ancestors = vec![".."; ancestors.count()]; while let Some(_) = ancestors.next() {
relative_anacestors.push("..");
let relative_root = std::path::PathBuf::from(relative_ancestors.join("/")); }
let relative_root = std::path::PathBuf::from(relative_anacestors.join("/"));
let relative_uri = relative_root.join(&uri); let relative_uri = relative_root.join(&uri);
let db = Database::connect(relative_uri.to_str().unwrap()) let db = Database::connect(relative_uri.to_str().unwrap())

View File

@@ -69,7 +69,7 @@ pub struct IndexBuilder {
impl IndexBuilder { impl IndexBuilder {
pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self { pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
Self { IndexBuilder {
table, table,
columns: columns.iter().map(|c| c.to_string()).collect(), columns: columns.iter().map(|c| c.to_string()).collect(),
name: None, name: None,
@@ -197,7 +197,7 @@ impl IndexBuilder {
let num_partitions = if let Some(n) = self.num_partitions { let num_partitions = if let Some(n) = self.num_partitions {
n n
} else { } else {
suggested_num_partitions(self.table.count_rows(None).await?) suggested_num_partitions(self.table.count_rows().await?)
}; };
let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors { let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
n n

View File

@@ -23,13 +23,13 @@ pub struct VectorIndex {
} }
impl VectorIndex { impl VectorIndex {
pub fn new_from_format(manifest: &Manifest, index: &Index) -> Self { pub fn new_from_format(manifest: &Manifest, index: &Index) -> VectorIndex {
let fields = index let fields = index
.fields .fields
.iter() .iter()
.map(|i| manifest.schema.fields[*i as usize].name.clone()) .map(|i| manifest.schema.fields[*i as usize].name.clone())
.collect(); .collect();
Self { VectorIndex {
columns: fields, columns: fields,
index_name: index.name.clone(), index_name: index.name.clone(),
index_uuid: index.uuid.to_string(), index_uuid: index.uuid.to_string(),

View File

@@ -357,14 +357,12 @@ mod test {
let db = Database::connect(dir1.to_str().unwrap()).await.unwrap(); let db = Database::connect(dir1.to_str().unwrap()).await.unwrap();
let mut param = WriteParams::default(); let mut param = WriteParams::default();
let store_params = ObjectStoreParams { let mut store_params = ObjectStoreParams::default();
object_store_wrapper: Some(object_store_wrapper), store_params.object_store_wrapper = Some(object_store_wrapper);
..Default::default()
};
param.store_params = Some(store_params); param.store_params = Some(store_params);
let mut datagen = BatchGenerator::new(); let mut datagen = BatchGenerator::new();
datagen = datagen.col(Box::<IncrementingInt32>::default()); datagen = datagen.col(Box::new(IncrementingInt32::default()));
datagen = datagen.col(Box::new(RandomVector::default().named("vector".into()))); datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));
let res = db let res = db
@@ -374,7 +372,7 @@ mod test {
// leave this here for easy debugging // leave this here for easy debugging
let t = res.unwrap(); let t = res.unwrap();
assert_eq!(t.count_rows(None).await.unwrap(), 100); assert_eq!(t.count_rows().await.unwrap(), 100);
let q = t let q = t
.search(&[0.1, 0.1, 0.1, 0.1]) .search(&[0.1, 0.1, 0.1, 0.1])

View File

@@ -62,7 +62,7 @@ impl Query {
/// * `dataset` - Lance dataset. /// * `dataset` - Lance dataset.
/// ///
pub(crate) fn new(dataset: Arc<Dataset>) -> Self { pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
Self { Query {
dataset, dataset,
query_vector: None, query_vector: None,
column: None, column: None,
@@ -257,7 +257,7 @@ mod tests {
assert_eq!(query.query_vector.unwrap(), new_vector); assert_eq!(query.query_vector.unwrap(), new_vector);
assert_eq!(query.limit.unwrap(), 100); assert_eq!(query.limit.unwrap(), 100);
assert_eq!(query.nprobes, 1000); assert_eq!(query.nprobes, 1000);
assert!(query.use_index); assert_eq!(query.use_index, true);
assert_eq!(query.metric_type, Some(MetricType::Cosine)); assert_eq!(query.metric_type, Some(MetricType::Cosine));
assert_eq!(query.refine_factor, Some(999)); assert_eq!(query.refine_factor, Some(999));
} }

View File

@@ -19,7 +19,6 @@ use std::sync::{Arc, Mutex};
use arrow_array::RecordBatchReader; use arrow_array::RecordBatchReader;
use arrow_schema::{Schema, SchemaRef}; use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use chrono::Duration; use chrono::Duration;
use lance::dataset::builder::DatasetBuilder; use lance::dataset::builder::DatasetBuilder;
use lance::dataset::cleanup::RemovalStats; use lance::dataset::cleanup::RemovalStats;
@@ -27,8 +26,7 @@ use lance::dataset::optimize::{
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions, compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
}; };
pub use lance::dataset::ReadParams; pub use lance::dataset::ReadParams;
use lance::dataset::{Dataset, UpdateBuilder, WhenMatched, WriteParams}; use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
use lance::dataset::{MergeInsertBuilder as LanceMergeInsertBuilder, WhenNotMatchedBySource};
use lance::io::WrappingObjectStore; use lance::io::WrappingObjectStore;
use lance_index::{optimize::OptimizeOptions, DatasetIndexExt}; use lance_index::{optimize::OptimizeOptions, DatasetIndexExt};
use log::info; use log::info;
@@ -40,10 +38,6 @@ use crate::query::Query;
use crate::utils::{PatchReadParam, PatchWriteParam}; use crate::utils::{PatchReadParam, PatchWriteParam};
use crate::WriteMode; use crate::WriteMode;
use self::merge::{MergeInsert, MergeInsertBuilder};
pub mod merge;
/// Optimize the dataset. /// Optimize the dataset.
/// ///
/// Similar to `VACUUM` in PostgreSQL, it offers different options to /// Similar to `VACUUM` in PostgreSQL, it offers different options to
@@ -102,11 +96,7 @@ pub trait Table: std::fmt::Display + Send + Sync {
fn schema(&self) -> SchemaRef; fn schema(&self) -> SchemaRef;
/// Count the number of rows in this dataset. /// Count the number of rows in this dataset.
/// async fn count_rows(&self) -> Result<usize>;
/// # Arguments
///
/// * `filter` if present, only count rows matching the filter
async fn count_rows(&self, filter: Option<String>) -> Result<usize>;
/// Insert new records into this Table /// Insert new records into this Table
/// ///
@@ -180,71 +170,6 @@ pub trait Table: std::fmt::Display + Send + Sync {
/// ``` /// ```
fn create_index(&self, column: &[&str]) -> IndexBuilder; fn create_index(&self, column: &[&str]) -> IndexBuilder;
/// Create a builder for a merge insert operation
///
/// This operation can add rows, update rows, and remove rows all in a single
/// transaction. It is a very generic tool that can be used to create
/// behaviors like "insert if not exists", "update or insert (i.e. upsert)",
/// or even replace a portion of existing data with new data (e.g. replace
/// all data where month="january")
///
/// The merge insert operation works by combining new data from a
/// **source table** with existing data in a **target table** by using a
/// join. There are three categories of records.
///
/// "Matched" records are records that exist in both the source table and
/// the target table. "Not matched" records exist only in the source table
/// (e.g. these are new data) "Not matched by source" records exist only
/// in the target table (this is old data)
///
/// The builder returned by this method can be used to customize what
/// should happen for each category of data.
///
/// Please note that the data may appear to be reordered as part of this
/// operation. This is because updated rows will be deleted from the
/// dataset and then reinserted at the end with the new values.
///
/// # Arguments
///
/// * `on` One or more columns to join on. This is how records from the
/// source table and target table are matched. Typically this is some
/// kind of key or id column.
///
/// # Examples
///
/// ```no_run
/// # use std::sync::Arc;
/// # use vectordb::connection::{Database, Connection};
/// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
/// # RecordBatchIterator, Int32Array};
/// # use arrow_schema::{Schema, Field, DataType};
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let tmpdir = tempfile::tempdir().unwrap();
/// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
/// # let tbl = db.open_table("idx_test").await.unwrap();
/// # let schema = Arc::new(Schema::new(vec![
/// # Field::new("id", DataType::Int32, false),
/// # Field::new("vector", DataType::FixedSizeList(
/// # Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
/// # ]));
/// let new_data = RecordBatchIterator::new(vec![
/// RecordBatch::try_new(schema.clone(),
/// vec![
/// Arc::new(Int32Array::from_iter_values(0..10)),
/// Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
/// (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
/// ]).unwrap()
/// ].into_iter().map(Ok),
/// schema.clone());
/// // Perform an upsert operation
/// let mut merge_insert = tbl.merge_insert(&["id"]);
/// merge_insert.when_matched_update_all(None)
/// .when_not_matched_insert_all();
/// merge_insert.execute(Box::new(new_data)).await.unwrap();
/// # });
/// ```
fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder;
/// Search the table with a given query vector. /// Search the table with a given query vector.
/// ///
/// This is a convenience method for preparing an ANN query. /// This is a convenience method for preparing an ANN query.
@@ -389,7 +314,7 @@ impl NativeTable {
message: e.to_string(), message: e.to_string(),
}, },
})?; })?;
Ok(Self { Ok(NativeTable {
name: name.to_string(), name: name.to_string(),
uri: uri.to_string(), uri: uri.to_string(),
dataset: Arc::new(Mutex::new(dataset)), dataset: Arc::new(Mutex::new(dataset)),
@@ -431,7 +356,7 @@ impl NativeTable {
message: e.to_string(), message: e.to_string(),
}, },
})?; })?;
Ok(Self { Ok(NativeTable {
name: name.to_string(), name: name.to_string(),
uri: uri.to_string(), uri: uri.to_string(),
dataset: Arc::new(Mutex::new(dataset)), dataset: Arc::new(Mutex::new(dataset)),
@@ -505,7 +430,7 @@ impl NativeTable {
message: e.to_string(), message: e.to_string(),
}, },
})?; })?;
Ok(Self { Ok(NativeTable {
name: name.to_string(), name: name.to_string(),
uri: uri.to_string(), uri: uri.to_string(),
dataset: Arc::new(Mutex::new(dataset)), dataset: Arc::new(Mutex::new(dataset)),
@@ -668,45 +593,6 @@ impl NativeTable {
} }
} }
#[async_trait]
impl MergeInsert for NativeTable {
async fn do_merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()> {
let dataset = Arc::new(self.clone_inner_dataset());
let mut builder = LanceMergeInsertBuilder::try_new(dataset.clone(), params.on)?;
match (
params.when_matched_update_all,
params.when_matched_update_all_filt,
) {
(false, _) => builder.when_matched(WhenMatched::DoNothing),
(true, None) => builder.when_matched(WhenMatched::UpdateAll),
(true, Some(filt)) => builder.when_matched(WhenMatched::update_if(&dataset, &filt)?),
};
if params.when_not_matched_insert_all {
builder.when_not_matched(lance::dataset::WhenNotMatched::InsertAll);
} else {
builder.when_not_matched(lance::dataset::WhenNotMatched::DoNothing);
}
if params.when_not_matched_by_source_delete {
let behavior = if let Some(filter) = params.when_not_matched_by_source_delete_filt {
WhenNotMatchedBySource::delete_if(dataset.as_ref(), &filter)?
} else {
WhenNotMatchedBySource::Delete
};
builder.when_not_matched_by_source(behavior);
} else {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Keep);
}
let job = builder.try_build()?;
let new_dataset = job.execute_reader(new_data).await?;
self.reset_dataset((*new_dataset).clone());
Ok(())
}
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl Table for NativeTable { impl Table for NativeTable {
fn as_any(&self) -> &dyn std::any::Any { fn as_any(&self) -> &dyn std::any::Any {
@@ -726,15 +612,9 @@ impl Table for NativeTable {
Arc::new(Schema::from(&lance_schema)) Arc::new(Schema::from(&lance_schema))
} }
async fn count_rows(&self, filter: Option<String>) -> Result<usize> { async fn count_rows(&self) -> Result<usize> {
let dataset = { self.dataset.lock().expect("lock poison").clone() }; let dataset = { self.dataset.lock().expect("lock poison").clone() };
if let Some(filter) = filter { Ok(dataset.count_rows().await?)
let mut scanner = dataset.scan();
scanner.filter(&filter)?;
Ok(scanner.count_rows().await? as usize)
} else {
Ok(dataset.count_rows().await?)
}
} }
async fn add( async fn add(
@@ -757,11 +637,6 @@ impl Table for NativeTable {
Ok(()) Ok(())
} }
fn merge_insert(&self, on: &[&str]) -> MergeInsertBuilder {
let on = Vec::from_iter(on.iter().map(|key| key.to_string()));
MergeInsertBuilder::new(Arc::new(self.clone()), on)
}
fn create_index(&self, columns: &[&str]) -> IndexBuilder { fn create_index(&self, columns: &[&str]) -> IndexBuilder {
IndexBuilder::new(Arc::new(self.clone()), columns) IndexBuilder::new(Arc::new(self.clone()), columns)
} }
@@ -827,7 +702,6 @@ impl Table for NativeTable {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::iter;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
@@ -888,35 +762,18 @@ mod tests {
let batches = make_test_batches(); let batches = make_test_batches();
let _ = batches.schema().clone(); let _ = batches.schema().clone();
NativeTable::create(uri, "test", batches, None, None) NativeTable::create(&uri, "test", batches, None, None)
.await .await
.unwrap(); .unwrap();
let batches = make_test_batches(); let batches = make_test_batches();
let result = NativeTable::create(uri, "test", batches, None, None).await; let result = NativeTable::create(&uri, "test", batches, None, None).await;
assert!(matches!( assert!(matches!(
result.unwrap_err(), result.unwrap_err(),
Error::TableAlreadyExists { .. } Error::TableAlreadyExists { .. }
)); ));
} }
#[tokio::test]
async fn test_count_rows() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches = make_test_batches();
let table = NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
assert_eq!(
table.count_rows(Some("i >= 5".to_string())).await.unwrap(),
5
);
}
#[tokio::test] #[tokio::test]
async fn test_add() { async fn test_add() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
@@ -924,10 +781,10 @@ mod tests {
let batches = make_test_batches(); let batches = make_test_batches();
let schema = batches.schema().clone(); let schema = batches.schema().clone();
let table = NativeTable::create(uri, "test", batches, None, None) let table = NativeTable::create(&uri, "test", batches, None, None)
.await .await
.unwrap(); .unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10); assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches = RecordBatchIterator::new( let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new( vec![RecordBatch::try_new(
@@ -941,56 +798,10 @@ mod tests {
); );
table.add(Box::new(new_batches), None).await.unwrap(); table.add(Box::new(new_batches), None).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 20); assert_eq!(table.count_rows().await.unwrap(), 20);
assert_eq!(table.name, "test"); assert_eq!(table.name, "test");
} }
#[tokio::test]
async fn test_merge_insert() {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
// Create a dataset with i=0..10
let batches = merge_insert_test_batches(0, 0);
let table = NativeTable::create(uri, "test", batches, None, None)
.await
.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10);
// Create new data with i=5..15
let new_batches = Box::new(merge_insert_test_batches(5, 1));
// Perform a "insert if not exists"
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_not_matched_insert_all();
merge_insert_builder.execute(new_batches).await.unwrap();
// Only 5 rows should actually be inserted
assert_eq!(table.count_rows(None).await.unwrap(), 15);
// Create new data with i=15..25 (no id matches)
let new_batches = Box::new(merge_insert_test_batches(15, 2));
// Perform a "bulk update" (should not affect anything)
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all(None);
merge_insert_builder.execute(new_batches).await.unwrap();
// No new rows should have been inserted
assert_eq!(table.count_rows(None).await.unwrap(), 15);
assert_eq!(
table.count_rows(Some("age = 2".to_string())).await.unwrap(),
0
);
// Conditional update that only replaces the age=0 data
let new_batches = Box::new(merge_insert_test_batches(5, 3));
let mut merge_insert_builder = table.merge_insert(&["i"]);
merge_insert_builder.when_matched_update_all(Some("target.age = 0".to_string()));
merge_insert_builder.execute(new_batches).await.unwrap();
assert_eq!(
table.count_rows(Some("age = 3".to_string())).await.unwrap(),
5
);
}
#[tokio::test] #[tokio::test]
async fn test_add_overwrite() { async fn test_add_overwrite() {
let tmp_dir = tempdir().unwrap(); let tmp_dir = tempdir().unwrap();
@@ -1001,7 +812,7 @@ mod tests {
let table = NativeTable::create(uri, "test", batches, None, None) let table = NativeTable::create(uri, "test", batches, None, None)
.await .await
.unwrap(); .unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10); assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches = RecordBatchIterator::new( let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new( vec![RecordBatch::try_new(
@@ -1020,7 +831,7 @@ mod tests {
}; };
table.add(Box::new(new_batches), Some(param)).await.unwrap(); table.add(Box::new(new_batches), Some(param)).await.unwrap();
assert_eq!(table.count_rows(None).await.unwrap(), 10); assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.name, "test"); assert_eq!(table.name, "test");
} }
@@ -1149,8 +960,12 @@ mod tests {
Arc::new(LargeStringArray::from_iter_values(vec![ Arc::new(LargeStringArray::from_iter_values(vec![
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
])), ])),
Arc::new(Float32Array::from_iter_values((0..10).map(|i| i as f32))), Arc::new(Float32Array::from_iter_values(
Arc::new(Float64Array::from_iter_values((0..10).map(|i| i as f64))), (0..10).into_iter().map(|i| i as f32),
)),
Arc::new(Float64Array::from_iter_values(
(0..10).into_iter().map(|i| i as f64),
)),
Arc::new(Into::<BooleanArray>::into(vec![ Arc::new(Into::<BooleanArray>::into(vec![
true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false,
])), ])),
@@ -1159,14 +974,14 @@ mod tests {
Arc::new(TimestampMillisecondArray::from_iter_values(0..10)), Arc::new(TimestampMillisecondArray::from_iter_values(0..10)),
Arc::new( Arc::new(
create_fixed_size_list( create_fixed_size_list(
Float32Array::from_iter_values((0..20).map(|i| i as f32)), Float32Array::from_iter_values((0..20).into_iter().map(|i| i as f32)),
2, 2,
) )
.unwrap(), .unwrap(),
), ),
Arc::new( Arc::new(
create_fixed_size_list( create_fixed_size_list(
Float64Array::from_iter_values((0..20).map(|i| i as f64)), Float64Array::from_iter_values((0..20).into_iter().map(|i| i as f64)),
2, 2,
) )
.unwrap(), .unwrap(),
@@ -1303,7 +1118,7 @@ mod tests {
original: Arc<dyn object_store::ObjectStore>, original: Arc<dyn object_store::ObjectStore>,
) -> Arc<dyn object_store::ObjectStore> { ) -> Arc<dyn object_store::ObjectStore> {
self.called.store(true, Ordering::Relaxed); self.called.store(true, Ordering::Relaxed);
original return original;
} }
} }
@@ -1320,10 +1135,8 @@ mod tests {
let wrapper = Arc::new(NoOpCacheWrapper::default()); let wrapper = Arc::new(NoOpCacheWrapper::default());
let object_store_params = ObjectStoreParams { let mut object_store_params = ObjectStoreParams::default();
object_store_wrapper: Some(wrapper.clone()), object_store_params.object_store_wrapper = Some(wrapper.clone());
..Default::default()
};
let param = ReadParams { let param = ReadParams {
store_options: Some(object_store_params), store_options: Some(object_store_params),
..Default::default() ..Default::default()
@@ -1335,26 +1148,6 @@ mod tests {
assert!(wrapper.called()); assert!(wrapper.called());
} }
fn merge_insert_test_batches(
offset: i32,
age: i32,
) -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![
Field::new("i", DataType::Int32, false),
Field::new("age", DataType::Int32, false),
]));
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from_iter_values(offset..(offset + 10))),
Arc::new(Int32Array::from_iter_values(iter::repeat(age).take(10))),
],
)],
schema,
)
}
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
RecordBatchIterator::new( RecordBatchIterator::new(
@@ -1420,7 +1213,7 @@ mod tests {
.unwrap(); .unwrap();
assert_eq!(table.load_indices().await.unwrap().len(), 1); assert_eq!(table.load_indices().await.unwrap().len(), 1);
assert_eq!(table.count_rows(None).await.unwrap(), 512); assert_eq!(table.count_rows().await.unwrap(), 512);
assert_eq!(table.name, "test"); assert_eq!(table.name, "test");
let indices = table.load_indices().await.unwrap(); let indices = table.load_indices().await.unwrap();

View File

@@ -1,111 +0,0 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow_array::RecordBatchReader;
use async_trait::async_trait;
use crate::Result;
#[async_trait]
pub(super) trait MergeInsert: Send + Sync {
async fn do_merge_insert(
&self,
params: MergeInsertBuilder,
new_data: Box<dyn RecordBatchReader + Send>,
) -> Result<()>;
}
/// A builder used to create and run a merge insert operation
///
/// See [`super::Table::merge_insert`] for more context
pub struct MergeInsertBuilder {
table: Arc<dyn MergeInsert>,
pub(super) on: Vec<String>,
pub(super) when_matched_update_all: bool,
pub(super) when_matched_update_all_filt: Option<String>,
pub(super) when_not_matched_insert_all: bool,
pub(super) when_not_matched_by_source_delete: bool,
pub(super) when_not_matched_by_source_delete_filt: Option<String>,
}
impl MergeInsertBuilder {
pub(super) fn new(table: Arc<dyn MergeInsert>, on: Vec<String>) -> Self {
Self {
table,
on,
when_matched_update_all: false,
when_matched_update_all_filt: None,
when_not_matched_insert_all: false,
when_not_matched_by_source_delete: false,
when_not_matched_by_source_delete_filt: None,
}
}
/// Rows that exist in both the source table (new data) and
/// the target table (old data) will be updated, replacing
/// the old row with the corresponding matching row.
///
/// If there are multiple matches then the behavior is undefined.
/// Currently this causes multiple copies of the row to be created
/// but that behavior is subject to change.
///
/// An optional condition may be specified. If it is, then only
/// matched rows that satisfy the condtion will be updated. Any
/// rows that do not satisfy the condition will be left as they
/// are. Failing to satisfy the condition does not cause a
/// "matched row" to become a "not matched" row.
///
/// The condition should be an SQL string. Use the prefix
/// target. to refer to rows in the target table (old data)
/// and the prefix source. to refer to rows in the source
/// table (new data).
///
/// For example, "target.last_update < source.last_update"
pub fn when_matched_update_all(&mut self, condition: Option<String>) -> &mut Self {
self.when_matched_update_all = true;
self.when_matched_update_all_filt = condition;
self
}
/// Rows that exist only in the source table (new data) should
/// be inserted into the target table.
pub fn when_not_matched_insert_all(&mut self) -> &mut Self {
self.when_not_matched_insert_all = true;
self
}
/// Rows that exist only in the target table (old data) will be
/// deleted. An optional condition can be provided to limit what
/// data is deleted.
///
/// # Arguments
///
/// * `condition` - If None then all such rows will be deleted.
/// Otherwise the condition will be used as an SQL filter to
/// limit what rows are deleted.
pub fn when_not_matched_by_source_delete(&mut self, filter: Option<String>) -> &mut Self {
self.when_not_matched_by_source_delete = true;
self.when_not_matched_by_source_delete_filt = filter;
self
}
/// Executes the merge insert operation
///
/// Nothing is returned but the [`super::Table`] is updated
pub async fn execute(self, new_data: Box<dyn RecordBatchReader + Send>) -> Result<()> {
self.table.clone().do_merge_insert(self, new_data).await
}
}