mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00
Compare commits
57 Commits
v0.4.7
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62c5117def | ||
|
|
22c196b3e3 | ||
|
|
1f4ac71fa3 | ||
|
|
b5aad2d856 | ||
|
|
ca6f55b160 | ||
|
|
6f8cf1e068 | ||
|
|
e0277383a5 | ||
|
|
d6b408e26f | ||
|
|
2447372c1f | ||
|
|
f0298d8372 | ||
|
|
54693e6bec | ||
|
|
73b2977bff | ||
|
|
aec85f7875 | ||
|
|
51f92ecb3d | ||
|
|
5b60412d66 | ||
|
|
53d63966a9 | ||
|
|
5ba87575e7 | ||
|
|
cc5f2136a6 | ||
|
|
78e5fb5451 | ||
|
|
8104c5c18e | ||
|
|
4fbabdeec3 | ||
|
|
eb31d95fef | ||
|
|
3169c36525 | ||
|
|
1b990983b3 | ||
|
|
0c21f91c16 | ||
|
|
7e50c239eb | ||
|
|
24e8043150 | ||
|
|
990440385d | ||
|
|
a693a9d897 | ||
|
|
82936c77ef | ||
|
|
dddcddcaf9 | ||
|
|
a9727eb318 | ||
|
|
48d55bf952 | ||
|
|
d2e71c8b08 | ||
|
|
f53aace89c | ||
|
|
d982ee934a | ||
|
|
57605a2d86 | ||
|
|
738511c5f2 | ||
|
|
0b0f42537e | ||
|
|
e412194008 | ||
|
|
a9088224c5 | ||
|
|
688c57a0d8 | ||
|
|
12a98deded | ||
|
|
e4bb042918 | ||
|
|
04e1662681 | ||
|
|
ce2242e06d | ||
|
|
778339388a | ||
|
|
7f8637a0b4 | ||
|
|
09cd08222d | ||
|
|
a248d7feec | ||
|
|
cc9473a94a | ||
|
|
d77e95a4f4 | ||
|
|
62f053ac92 | ||
|
|
34e10caad2 | ||
|
|
f5726e2d0c | ||
|
|
12b4fb42fc | ||
|
|
1328cd46f1 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.4.7
|
||||
current_version = 0.4.10
|
||||
commit = True
|
||||
message = Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
40
.cargo/config.toml
Normal file
40
.cargo/config.toml
Normal file
@@ -0,0 +1,40 @@
|
||||
[profile.release]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
debug = true
|
||||
# Prioritize compile time over runtime performance
|
||||
codegen-units = 16
|
||||
lto = "thin"
|
||||
|
||||
[target.'cfg(all())']
|
||||
rustflags = [
|
||||
"-Wclippy::all",
|
||||
"-Wclippy::style",
|
||||
"-Wclippy::fallible_impl_from",
|
||||
"-Wclippy::manual_let_else",
|
||||
"-Wclippy::redundant_pub_crate",
|
||||
"-Wclippy::string_add_assign",
|
||||
"-Wclippy::string_add",
|
||||
"-Wclippy::string_lit_as_bytes",
|
||||
"-Wclippy::string_to_string",
|
||||
"-Wclippy::use_self",
|
||||
"-Dclippy::cargo",
|
||||
"-Dclippy::dbg_macro",
|
||||
# not too much we can do to avoid multiple crate versions
|
||||
"-Aclippy::multiple-crate-versions",
|
||||
"-Aclippy::wildcard_dependencies",
|
||||
]
|
||||
|
||||
[target.x86_64-unknown-linux-gnu]
|
||||
rustflags = ["-C", "target-cpu=haswell", "-C", "target-feature=+avx2,+fma,+f16c"]
|
||||
|
||||
[target.aarch64-apple-darwin]
|
||||
rustflags = ["-C", "target-cpu=apple-m1", "-C", "target-feature=+neon,+fp16,+fhm,+dotprod"]
|
||||
|
||||
# Not all Windows systems have the C runtime installed, so this avoids library
|
||||
# not found errors on systems that are missing it.
|
||||
[target.x86_64-pc-windows-msvc]
|
||||
rustflags = ["-Ctarget-feature=+crt-static"]
|
||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: "pip"
|
||||
|
||||
11
.github/workflows/docs_test.yml
vendored
11
.github/workflows/docs_test.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.11
|
||||
cache: "pip"
|
||||
@@ -49,6 +49,9 @@ jobs:
|
||||
test-node:
|
||||
name: Test doc nodejs code
|
||||
runs-on: "ubuntu-latest"
|
||||
timeout-minutes: 45
|
||||
strategy:
|
||||
fail-fast: false
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -66,6 +69,12 @@ jobs:
|
||||
uses: swatinem/rust-cache@v2
|
||||
- name: Install node dependencies
|
||||
run: |
|
||||
sudo swapoff -a
|
||||
sudo fallocate -l 8G /swapfile
|
||||
sudo chmod 600 /swapfile
|
||||
sudo mkswap /swapfile
|
||||
sudo swapon /swapfile
|
||||
sudo swapon --show
|
||||
cd node
|
||||
npm ci
|
||||
npm run build-release
|
||||
|
||||
6
.github/workflows/make-release-commit.yml
vendored
6
.github/workflows/make-release-commit.yml
vendored
@@ -37,10 +37,10 @@ jobs:
|
||||
run: |
|
||||
git config user.name 'Lance Release'
|
||||
git config user.email 'lance-dev@lancedb.com'
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v4
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
- name: Bump version, create tag and commit
|
||||
run: |
|
||||
pip install bump2version
|
||||
|
||||
17
.github/workflows/npm-publish.yml
vendored
17
.github/workflows/npm-publish.yml
vendored
@@ -80,10 +80,25 @@ jobs:
|
||||
- arch: x86_64
|
||||
runner: ubuntu-latest
|
||||
- arch: aarch64
|
||||
runner: buildjet-4vcpu-ubuntu-2204-arm
|
||||
# For successful fat LTO builds, we need a large runner to avoid OOM errors.
|
||||
runner: buildjet-16vcpu-ubuntu-2204-arm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
# Buildjet aarch64 runners have only 1.5 GB RAM per core, vs 3.5 GB per core for
|
||||
# x86_64 runners. To avoid OOM errors on ARM, we create a swap file.
|
||||
- name: Configure aarch64 build
|
||||
if: ${{ matrix.config.arch == 'aarch64' }}
|
||||
run: |
|
||||
free -h
|
||||
sudo fallocate -l 16G /swapfile
|
||||
sudo chmod 600 /swapfile
|
||||
sudo mkswap /swapfile
|
||||
sudo swapon /swapfile
|
||||
echo "/swapfile swap swap defaults 0 0" >> sudo /etc/fstab
|
||||
# print info
|
||||
swapon --show
|
||||
free -h
|
||||
- name: Build Linux Artifacts
|
||||
run: |
|
||||
bash ci/build_linux_artifacts.sh ${{ matrix.config.arch }}
|
||||
|
||||
2
.github/workflows/pypi-publish.yml
vendored
2
.github/workflows/pypi-publish.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Build distribution
|
||||
|
||||
@@ -37,10 +37,10 @@ jobs:
|
||||
run: |
|
||||
git config user.name 'Lance Release'
|
||||
git config user.email 'lance-dev@lancedb.com'
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
- name: Bump version, create tag and commit
|
||||
working-directory: python
|
||||
run: |
|
||||
|
||||
6
.github/workflows/python.yml
vendored
6
.github/workflows/python.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.${{ matrix.python-minor-version }}
|
||||
- name: Install lancedb
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install lancedb
|
||||
@@ -92,7 +92,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
lfs: true
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.9
|
||||
- name: Install lancedb
|
||||
|
||||
13
Cargo.toml
13
Cargo.toml
@@ -6,15 +6,18 @@ resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
edition = "2021"
|
||||
authors = ["Lance Devs <dev@lancedb.com>"]
|
||||
authors = ["LanceDB Devs <dev@lancedb.com>"]
|
||||
license = "Apache-2.0"
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.9.10", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.9.10" }
|
||||
lance-linalg = { "version" = "=0.9.10" }
|
||||
lance-testing = { "version" = "=0.9.10" }
|
||||
lance = { "version" = "=0.9.18", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.9.18" }
|
||||
lance-linalg = { "version" = "=0.9.18" }
|
||||
lance-testing = { "version" = "=0.9.18" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "50.0", optional = false }
|
||||
arrow-array = "50.0"
|
||||
|
||||
@@ -13,7 +13,9 @@ docker build \
|
||||
.
|
||||
popd
|
||||
|
||||
# We turn on memory swap to avoid OOM killer
|
||||
docker run \
|
||||
-v $(pwd):/io -w /io \
|
||||
--memory-swap=-1 \
|
||||
lancedb-node-manylinux \
|
||||
bash ci/manylinux_node/build.sh $ARCH
|
||||
|
||||
@@ -90,16 +90,18 @@ nav:
|
||||
- Building an ANN index: ann_indexes.md
|
||||
- Vector Search: search.md
|
||||
- Full-text search: fts.md
|
||||
- Hybrid search: hybrid_search.md
|
||||
- Hybrid search:
|
||||
- Overview: hybrid_search/hybrid_search.md
|
||||
- Comparing Rerankers: hybrid_search/eval.md
|
||||
- Airbnb financial data example: notebooks/hybrid_search.ipynb
|
||||
- Filtering: sql.md
|
||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||
- Configuring Storage: guides/storage.md
|
||||
- 🧬 Managing embeddings:
|
||||
- Overview: embeddings/index.md
|
||||
- Explicit management: embeddings/embedding_explicit.md
|
||||
- Implicit management: embeddings/embedding_functions.md
|
||||
- Available Functions: embeddings/default_embedding_functions.md
|
||||
- Custom Embedding Functions: embeddings/api.md
|
||||
- Embedding functions: embeddings/embedding_functions.md
|
||||
- Available models: embeddings/default_embedding_functions.md
|
||||
- User-defined embedding functions: embeddings/custom_embedding_function.md
|
||||
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
|
||||
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- 🔌 Integrations:
|
||||
@@ -152,16 +154,18 @@ nav:
|
||||
- Building an ANN index: ann_indexes.md
|
||||
- Vector Search: search.md
|
||||
- Full-text search: fts.md
|
||||
- Hybrid search: hybrid_search.md
|
||||
- Hybrid search:
|
||||
- Overview: hybrid_search/hybrid_search.md
|
||||
- Comparing Rerankers: hybrid_search/eval.md
|
||||
- Airbnb financial data example: notebooks/hybrid_search.ipynb
|
||||
- Filtering: sql.md
|
||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||
- Configuring Storage: guides/storage.md
|
||||
- Managing Embeddings:
|
||||
- Overview: embeddings/index.md
|
||||
- Explicit management: embeddings/embedding_explicit.md
|
||||
- Implicit management: embeddings/embedding_functions.md
|
||||
- Available Functions: embeddings/default_embedding_functions.md
|
||||
- Custom Embedding Functions: embeddings/api.md
|
||||
- Embedding functions: embeddings/embedding_functions.md
|
||||
- Available models: embeddings/default_embedding_functions.md
|
||||
- User-defined embedding functions: embeddings/custom_embedding_function.md
|
||||
- "Example: Multi-lingual semantic search": notebooks/multi_lingual_example.ipynb
|
||||
- "Example: MultiModal CLIP Embeddings": notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- Integrations:
|
||||
@@ -202,6 +206,7 @@ extra_css:
|
||||
|
||||
extra_javascript:
|
||||
- "extra_js/init_ask_ai_widget.js"
|
||||
- "extra_js/meta_tag.js"
|
||||
|
||||
extra:
|
||||
analytics:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// --8<-- [start:import]
|
||||
import * as lancedb from "vectordb";
|
||||
import { Schema, Field, Float32, FixedSizeList, Int32 } from "apache-arrow";
|
||||
import { Schema, Field, Float32, FixedSizeList, Int32, Float16 } from "apache-arrow";
|
||||
// --8<-- [end:import]
|
||||
import * as fs from "fs";
|
||||
import { Table as ArrowTable, Utf8 } from "apache-arrow";
|
||||
@@ -8,6 +8,7 @@ import { Table as ArrowTable, Utf8 } from "apache-arrow";
|
||||
const example = async () => {
|
||||
fs.rmSync("data/sample-lancedb", { recursive: true, force: true });
|
||||
// --8<-- [start:open_db]
|
||||
const lancedb = require("vectordb");
|
||||
const uri = "data/sample-lancedb";
|
||||
const db = await lancedb.connect(uri);
|
||||
// --8<-- [end:open_db]
|
||||
@@ -48,6 +49,27 @@ const example = async () => {
|
||||
const empty_tbl = await db.createTable({ name: "empty_table", schema });
|
||||
// --8<-- [end:create_empty_table]
|
||||
|
||||
// --8<-- [start:create_f16_table]
|
||||
const dim = 16
|
||||
const total = 10
|
||||
const f16_schema = new Schema([
|
||||
new Field('id', new Int32()),
|
||||
new Field(
|
||||
'vector',
|
||||
new FixedSizeList(dim, new Field('item', new Float16(), true)),
|
||||
false
|
||||
)
|
||||
])
|
||||
const data = lancedb.makeArrowTable(
|
||||
Array.from(Array(total), (_, i) => ({
|
||||
id: i,
|
||||
vector: Array.from(Array(dim), Math.random)
|
||||
})),
|
||||
{ f16_schema }
|
||||
)
|
||||
const table = await db.createTable('f16_tbl', data)
|
||||
// --8<-- [end:create_f16_table]
|
||||
|
||||
// --8<-- [start:search]
|
||||
const query = await tbl.search([100, 100]).limit(2).execute();
|
||||
// --8<-- [end:search]
|
||||
|
||||
@@ -17,6 +17,7 @@ Let's implement `SentenceTransformerEmbeddings` class. All you need to do is imp
|
||||
|
||||
```python
|
||||
from lancedb.embeddings import register
|
||||
from lancedb.util import attempt_import_or_raise
|
||||
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
@@ -81,7 +82,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
|
||||
open_clip = attempt_import_or_raise("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
@@ -109,14 +110,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
@@ -175,7 +176,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
@@ -183,7 +184,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
@@ -9,6 +9,9 @@ Contains the text embedding functions registered by default.
|
||||
### Sentence transformers
|
||||
Allows you to set parameters when registering a `sentence-transformers` object.
|
||||
|
||||
!!! info
|
||||
Sentence transformer embeddings are normalized by default. It is recommended to use normalized embeddings for similarity search.
|
||||
|
||||
| Parameter | Type | Default Value | Description |
|
||||
|---|---|---|---|
|
||||
| `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
In this workflow, you define your own embedding function and pass it as a callable to LanceDB, invoking it in your code to generate the embeddings. Let's look at some examples.
|
||||
|
||||
### Hugging Face
|
||||
|
||||
!!! note
|
||||
Currently, the Hugging Face method is only supported in the Python SDK.
|
||||
|
||||
=== "Python"
|
||||
The most popular open source option is to use the [sentence-transformers](https://www.sbert.net/)
|
||||
library, which can be installed via pip.
|
||||
|
||||
```bash
|
||||
pip install sentence-transformers
|
||||
```
|
||||
|
||||
The example below shows how to use the `paraphrase-albert-small-v2` model to generate embeddings
|
||||
for a given document.
|
||||
|
||||
```python
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
name="paraphrase-albert-small-v2"
|
||||
model = SentenceTransformer(name)
|
||||
|
||||
# used for both training and querying
|
||||
def embed_func(batch):
|
||||
return [model.encode(sentence) for sentence in batch]
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
Another popular alternative is to use an external API like OpenAI's [embeddings API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
|
||||
|
||||
=== "Python"
|
||||
```python
|
||||
import openai
|
||||
import os
|
||||
|
||||
# Configuring the environment variable OPENAI_API_KEY
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
# OR set the key here as a variable
|
||||
openai.api_key = "sk-..."
|
||||
|
||||
# verify that the API key is working
|
||||
assert len(openai.Model.list()["data"]) > 0
|
||||
|
||||
def embed_func(c):
|
||||
rs = openai.Embedding.create(input=c, engine="text-embedding-ada-002")
|
||||
return [record["embedding"] for record in rs["data"]]
|
||||
```
|
||||
|
||||
=== "JavaScript"
|
||||
```javascript
|
||||
const lancedb = require("vectordb");
|
||||
|
||||
// You need to provide an OpenAI API key
|
||||
const apiKey = "sk-..."
|
||||
// The embedding function will create embeddings for the 'text' column
|
||||
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
|
||||
```
|
||||
|
||||
## Applying an embedding function to data
|
||||
|
||||
=== "Python"
|
||||
Using an embedding function, you can apply it to raw data
|
||||
to generate embeddings for each record.
|
||||
|
||||
Say you have a pandas DataFrame with a `text` column that you want embedded,
|
||||
you can use the `with_embeddings` function to generate embeddings and add them to
|
||||
an existing table.
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
from lancedb.embeddings import with_embeddings
|
||||
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
{"text": "pepperoni"},
|
||||
{"text": "pineapple"}
|
||||
]
|
||||
)
|
||||
data = with_embeddings(embed_func, df)
|
||||
|
||||
# The output is used to create / append to a table
|
||||
# db.create_table("my_table", data=data)
|
||||
```
|
||||
|
||||
If your data is in a different column, you can specify the `column` kwarg to `with_embeddings`.
|
||||
|
||||
By default, LanceDB calls the function with batches of 1000 rows. This can be configured
|
||||
using the `batch_size` parameter to `with_embeddings`.
|
||||
|
||||
LanceDB automatically wraps the function with retry and rate-limit logic to ensure the OpenAI
|
||||
API call is reliable.
|
||||
|
||||
=== "JavaScript"
|
||||
Using an embedding function, you can apply it to raw data
|
||||
to generate embeddings for each record.
|
||||
|
||||
Simply pass the embedding function created above and LanceDB will use it to generate
|
||||
embeddings for your data.
|
||||
|
||||
```javascript
|
||||
const db = await lancedb.connect("data/sample-lancedb");
|
||||
const data = [
|
||||
{ text: "pepperoni"},
|
||||
{ text: "pineapple"}
|
||||
]
|
||||
|
||||
const table = await db.createTable("vectors", data, embedding)
|
||||
```
|
||||
|
||||
## Querying using an embedding function
|
||||
|
||||
!!! warning
|
||||
At query time, you **must** use the same embedding function you used to vectorize your data.
|
||||
If you use a different embedding function, the embeddings will not reside in the same vector
|
||||
space and the results will be nonsensical.
|
||||
|
||||
=== "Python"
|
||||
```python
|
||||
query = "What's the best pizza topping?"
|
||||
query_vector = embed_func([query])[0]
|
||||
results = (
|
||||
tbl.search(query_vector)
|
||||
.limit(10)
|
||||
.to_pandas()
|
||||
)
|
||||
```
|
||||
|
||||
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||
|
||||
=== "JavaScript"
|
||||
```javascript
|
||||
const results = await table
|
||||
.search("What's the best pizza topping?")
|
||||
.limit(10)
|
||||
.execute()
|
||||
```
|
||||
|
||||
The above snippet returns an array of records with the top 10 nearest neighbors to the query.
|
||||
@@ -3,61 +3,126 @@ Representing multi-modal data as vector embeddings is becoming a standard practi
|
||||
For this purpose, LanceDB introduces an **embedding functions API**, that allow you simply set up once, during the configuration stage of your project. After this, the table remembers it, effectively making the embedding functions *disappear in the background* so you don't have to worry about manually passing callables, and instead, simply focus on the rest of your data engineering pipeline.
|
||||
|
||||
!!! warning
|
||||
Using the implicit embeddings management approach means that you can forget about the manually passing around embedding
|
||||
functions in your code, as long as you don't intend to change it at a later time. If your embedding function changes,
|
||||
you'll have to re-configure your table with the new embedding function and regenerate the embeddings.
|
||||
Using the embedding function registry means that you don't have to explicitly generate the embeddings yourself.
|
||||
However, if your embedding function changes, you'll have to re-configure your table with the new embedding function
|
||||
and regenerate the embeddings. In the future, we plan to support the ability to change the embedding function via
|
||||
table metadata and have LanceDB automatically take care of regenerating the embeddings.
|
||||
|
||||
|
||||
## 1. Define the embedding function
|
||||
We have some pre-defined embedding functions in the global registry, with more coming soon. Here's let's an implementation of CLIP as example.
|
||||
```
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
clip = registry.get("open-clip").create()
|
||||
|
||||
```
|
||||
You can also define your own embedding function by implementing the `EmbeddingFunction` abstract base interface. It subclasses Pydantic Model which can be utilized to write complex schemas simply as we'll see next!
|
||||
=== "Python"
|
||||
In the LanceDB python SDK, we define a global embedding function registry with
|
||||
many different embedding models and even more coming soon.
|
||||
Here's let's an implementation of CLIP as example.
|
||||
|
||||
```python
|
||||
from lancedb.embeddings import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
clip = registry.get("open-clip").create()
|
||||
```
|
||||
|
||||
You can also define your own embedding function by implementing the `EmbeddingFunction`
|
||||
abstract base interface. It subclasses Pydantic Model which can be utilized to write complex schemas simply as we'll see next!
|
||||
|
||||
=== "JavaScript""
|
||||
In the TypeScript SDK, the choices are more limited. For now, only the OpenAI
|
||||
embedding function is available.
|
||||
|
||||
```javascript
|
||||
const lancedb = require("vectordb");
|
||||
|
||||
// You need to provide an OpenAI API key
|
||||
const apiKey = "sk-..."
|
||||
// The embedding function will create embeddings for the 'text' column
|
||||
const embedding = new lancedb.OpenAIEmbeddingFunction('text', apiKey)
|
||||
```
|
||||
|
||||
## 2. Define the data model or schema
|
||||
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how:
|
||||
|
||||
```python
|
||||
class Pets(LanceModel):
|
||||
=== "Python"
|
||||
The embedding function defined above abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how:
|
||||
|
||||
```python
|
||||
class Pets(LanceModel):
|
||||
vector: Vector(clip.ndims) = clip.VectorField()
|
||||
image_uri: str = clip.SourceField()
|
||||
```
|
||||
```
|
||||
|
||||
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
|
||||
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for the `vector` column and `SourceField` ensures that when adding data, we automatically use the specified embedding function to encode `image_uri`.
|
||||
|
||||
## 3. Create LanceDB table
|
||||
Now that we have chosen/defined our embedding function and the schema, we can create the table:
|
||||
=== "JavaScript"
|
||||
|
||||
```python
|
||||
db = lancedb.connect("~/lancedb")
|
||||
table = db.create_table("pets", schema=Pets)
|
||||
For the TypeScript SDK, a schema can be inferred from input data, or an explicit
|
||||
Arrow schema can be provided.
|
||||
|
||||
```
|
||||
## 3. Create table and add data
|
||||
|
||||
That's it! We've provided all the information needed to embed the source and query inputs. We can now forget about the model and dimension details and start to build our VectorDB pipeline.
|
||||
Now that we have chosen/defined our embedding function and the schema,
|
||||
we can create the table and ingest data without needing to explicitly generate
|
||||
the embeddings at all:
|
||||
|
||||
## 4. Ingest lots of data and query your table
|
||||
Any new or incoming data can just be added and it'll be vectorized automatically.
|
||||
=== "Python"
|
||||
```python
|
||||
db = lancedb.connect("~/lancedb")
|
||||
table = db.create_table("pets", schema=Pets)
|
||||
|
||||
```python
|
||||
table.add([{"image_uri": u} for u in uris])
|
||||
```
|
||||
table.add([{"image_uri": u} for u in uris])
|
||||
```
|
||||
|
||||
Our OpenCLIP query embedding function supports querying via both text and images:
|
||||
=== "JavaScript"
|
||||
|
||||
```python
|
||||
result = table.search("dog")
|
||||
```
|
||||
```javascript
|
||||
const db = await lancedb.connect("data/sample-lancedb");
|
||||
const data = [
|
||||
{ text: "pepperoni"},
|
||||
{ text: "pineapple"}
|
||||
]
|
||||
|
||||
Let's query an image:
|
||||
const table = await db.createTable("vectors", data, embedding)
|
||||
```
|
||||
|
||||
```python
|
||||
p = Path("path/to/images/samoyed_100.jpg")
|
||||
query_image = Image.open(p)
|
||||
table.search(query_image)
|
||||
```
|
||||
## 4. Querying your table
|
||||
Not only can you forget about the embeddings during ingestion, you also don't
|
||||
need to worry about it when you query the table:
|
||||
|
||||
=== "Python"
|
||||
|
||||
Our OpenCLIP query embedding function supports querying via both text and images:
|
||||
|
||||
```python
|
||||
results = (
|
||||
table.search("dog")
|
||||
.limit(10)
|
||||
.to_pandas()
|
||||
)
|
||||
```
|
||||
|
||||
Or we can search using an image:
|
||||
|
||||
```python
|
||||
p = Path("path/to/images/samoyed_100.jpg")
|
||||
query_image = Image.open(p)
|
||||
results = (
|
||||
table.search(query_image)
|
||||
.limit(10)
|
||||
.to_pandas()
|
||||
)
|
||||
```
|
||||
|
||||
Both of the above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||
|
||||
=== "JavaScript"
|
||||
|
||||
```javascript
|
||||
const results = await table
|
||||
.search("What's the best pizza topping?")
|
||||
.limit(10)
|
||||
.execute()
|
||||
```
|
||||
|
||||
The above snippet returns an array of records with the top 10 nearest neighbors to the query.
|
||||
|
||||
---
|
||||
|
||||
@@ -100,4 +165,5 @@ rs[2].image
|
||||
|
||||

|
||||
|
||||
Now that you have the basic idea about implicit management via embedding functions, let's dive deeper into a [custom API](./api.md) that you can use to implement your own embedding functions.
|
||||
Now that you have the basic idea about LanceDB embedding functions and the embedding function registry,
|
||||
let's dive deeper into defining your own [custom functions](./custom_embedding_function.md).
|
||||
@@ -1,8 +1,14 @@
|
||||
Due to the nature of vector embeddings, they can be used to represent any kind of data, from text to images to audio. This makes them a very powerful tool for machine learning practitioners. However, there's no one-size-fits-all solution for generating embeddings - there are many different libraries and APIs (both commercial and open source) that can be used to generate embeddings from structured/unstructured data.
|
||||
Due to the nature of vector embeddings, they can be used to represent any kind of data, from text to images to audio.
|
||||
This makes them a very powerful tool for machine learning practitioners.
|
||||
However, there's no one-size-fits-all solution for generating embeddings - there are many different libraries and APIs
|
||||
(both commercial and open source) that can be used to generate embeddings from structured/unstructured data.
|
||||
|
||||
LanceDB supports 2 methods of vectorizing your raw data into embeddings.
|
||||
LanceDB supports 3 methods of working with embeddings.
|
||||
|
||||
1. **Explicit**: By manually calling LanceDB's `with_embedding` function to vectorize your data via an `embed_func` of your choice
|
||||
2. **Implicit**: Allow LanceDB to embed the data and queries in the background as they come in, by using the table's `EmbeddingRegistry` information
|
||||
1. You can manually generate embeddings for the data and queries. This is done outside of LanceDB.
|
||||
2. You can use the built-in [embedding functions](./embedding_functions.md) to embed the data and queries in the background.
|
||||
3. For python users, you can define your own [custom embedding function](./custom_embedding_function.md)
|
||||
that extends the default embedding functions.
|
||||
|
||||
See the [explicit](embedding_explicit.md) and [implicit](embedding_functions.md) embedding sections for more details.
|
||||
For python users, there is also a legacy [with_embeddings API](./legacy.md).
|
||||
It is retained for compatibility and will be removed in a future version.
|
||||
99
docs/src/embeddings/legacy.md
Normal file
99
docs/src/embeddings/legacy.md
Normal file
@@ -0,0 +1,99 @@
|
||||
The legacy `with_embeddings` API is for Python only and is deprecated.
|
||||
|
||||
### Hugging Face
|
||||
|
||||
The most popular open source option is to use the [sentence-transformers](https://www.sbert.net/)
|
||||
library, which can be installed via pip.
|
||||
|
||||
```bash
|
||||
pip install sentence-transformers
|
||||
```
|
||||
|
||||
The example below shows how to use the `paraphrase-albert-small-v2` model to generate embeddings
|
||||
for a given document.
|
||||
|
||||
```python
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
name="paraphrase-albert-small-v2"
|
||||
model = SentenceTransformer(name)
|
||||
|
||||
# used for both training and querying
|
||||
def embed_func(batch):
|
||||
return [model.encode(sentence) for sentence in batch]
|
||||
```
|
||||
|
||||
|
||||
### OpenAI
|
||||
|
||||
Another popular alternative is to use an external API like OpenAI's [embeddings API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
|
||||
|
||||
```python
|
||||
import openai
|
||||
import os
|
||||
|
||||
# Configuring the environment variable OPENAI_API_KEY
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
# OR set the key here as a variable
|
||||
openai.api_key = "sk-..."
|
||||
|
||||
client = openai.OpenAI()
|
||||
|
||||
def embed_func(c):
|
||||
rs = client.embeddings.create(input=c, model="text-embedding-ada-002")
|
||||
return [record.embedding for record in rs["data"]]
|
||||
```
|
||||
|
||||
|
||||
## Applying an embedding function to data
|
||||
|
||||
Using an embedding function, you can apply it to raw data
|
||||
to generate embeddings for each record.
|
||||
|
||||
Say you have a pandas DataFrame with a `text` column that you want embedded,
|
||||
you can use the `with_embeddings` function to generate embeddings and add them to
|
||||
an existing table.
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
from lancedb.embeddings import with_embeddings
|
||||
|
||||
df = pd.DataFrame(
|
||||
[
|
||||
{"text": "pepperoni"},
|
||||
{"text": "pineapple"}
|
||||
]
|
||||
)
|
||||
data = with_embeddings(embed_func, df)
|
||||
|
||||
# The output is used to create / append to a table
|
||||
tbl = db.create_table("my_table", data=data)
|
||||
```
|
||||
|
||||
If your data is in a different column, you can specify the `column` kwarg to `with_embeddings`.
|
||||
|
||||
By default, LanceDB calls the function with batches of 1000 rows. This can be configured
|
||||
using the `batch_size` parameter to `with_embeddings`.
|
||||
|
||||
LanceDB automatically wraps the function with retry and rate-limit logic to ensure the OpenAI
|
||||
API call is reliable.
|
||||
|
||||
## Querying using an embedding function
|
||||
|
||||
!!! warning
|
||||
At query time, you **must** use the same embedding function you used to vectorize your data.
|
||||
If you use a different embedding function, the embeddings will not reside in the same vector
|
||||
space and the results will be nonsensical.
|
||||
|
||||
=== "Python"
|
||||
```python
|
||||
query = "What's the best pizza topping?"
|
||||
query_vector = embed_func([query])[0]
|
||||
results = (
|
||||
tbl.search(query_vector)
|
||||
.limit(10)
|
||||
.to_pandas()
|
||||
)
|
||||
```
|
||||
|
||||
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||
@@ -1,6 +1,5 @@
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
6
docs/src/extra_js/meta_tag.js
Normal file
6
docs/src/extra_js/meta_tag.js
Normal file
@@ -0,0 +1,6 @@
|
||||
window.addEventListener('load', function() {
|
||||
var meta = document.createElement('meta');
|
||||
meta.setAttribute('property', 'og:image');
|
||||
meta.setAttribute('content', '/assets/lancedb_and_lance.png');
|
||||
document.head.appendChild(meta);
|
||||
});
|
||||
@@ -69,3 +69,19 @@ MinIO supports an S3 compatible API. In order to connect to a MinIO instance, yo
|
||||
- Set the envvar `AWS_ENDPOINT` to the URL of your MinIO API
|
||||
- Set the envvars `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` with your MinIO credential
|
||||
- Call `lancedb.connect("s3://minio_bucket_name")`
|
||||
|
||||
### Where can I find benchmarks for LanceDB?
|
||||
|
||||
Refer to this [post](https://blog.lancedb.com/benchmarking-lancedb-92b01032874a) for recent benchmarks.
|
||||
|
||||
### How much data can LanceDB practically manage without effecting performance?
|
||||
|
||||
We target good performance on ~10-50 billion rows and ~10-30 TB of data.
|
||||
|
||||
### Does LanceDB support concurrent operations?
|
||||
|
||||
LanceDB can handle concurrent reads very well, and can scale horizontally. The main constraint is how well the [storage layer](https://lancedb.github.io/lancedb/concepts/storage/) you've chosen scales. For writes, we support concurrent writing, though too many concurrent writers can lead to failing writes as there is a limited number of times a writer retries a commit
|
||||
|
||||
!!! info "Multiprocessing with LanceDB"
|
||||
|
||||
For multiprocessing you should probably not use ```fork``` as lance is multi-threaded internally and ```fork``` and multi-thread do not work well.[Refer to this discussion](https://discuss.python.org/t/concerns-regarding-deprecation-of-fork-with-alive-threads/33555)
|
||||
|
||||
@@ -16,9 +16,22 @@ This guide will show how to create tables, insert data into them, and update the
|
||||
db = lancedb.connect("./.lancedb")
|
||||
```
|
||||
|
||||
=== "Javascript"
|
||||
|
||||
Initialize a VectorDB connection and create a table using one of the many methods listed below.
|
||||
|
||||
```javascript
|
||||
const lancedb = require("vectordb");
|
||||
|
||||
const uri = "data/sample-lancedb";
|
||||
const db = await lancedb.connect(uri);
|
||||
```
|
||||
|
||||
LanceDB allows ingesting data from various sources - `dict`, `list[dict]`, `pd.DataFrame`, `pa.Table` or a `Iterator[pa.RecordBatch]`. Let's take a look at some of the these.
|
||||
|
||||
### From list of tuples or dictionaries
|
||||
### From list of tuples or dictionaries
|
||||
|
||||
=== "Python"
|
||||
|
||||
```python
|
||||
import lancedb
|
||||
@@ -32,7 +45,6 @@ This guide will show how to create tables, insert data into them, and update the
|
||||
|
||||
db["my_table"].head()
|
||||
```
|
||||
|
||||
!!! info "Note"
|
||||
If the table already exists, LanceDB will raise an error by default.
|
||||
|
||||
@@ -51,6 +63,27 @@ This guide will show how to create tables, insert data into them, and update the
|
||||
db.create_table("name", data, mode="overwrite")
|
||||
```
|
||||
|
||||
=== "Javascript"
|
||||
You can create a LanceDB table in JavaScript using an array of JSON records as follows.
|
||||
|
||||
```javascript
|
||||
const tb = await db.createTable("my_table", [{
|
||||
"vector": [3.1, 4.1],
|
||||
"item": "foo",
|
||||
"price": 10.0
|
||||
}, {
|
||||
"vector": [5.9, 26.5],
|
||||
"item": "bar",
|
||||
"price": 20.0
|
||||
}]);
|
||||
```
|
||||
!!! info "Note"
|
||||
If the table already exists, LanceDB will raise an error by default. If you want to overwrite the table, you need to specify the `WriteMode` in the createTable function.
|
||||
|
||||
```javascript
|
||||
const table = await con.createTable(tableName, data, { writeMode: WriteMode.Overwrite })
|
||||
```
|
||||
|
||||
### From a Pandas DataFrame
|
||||
|
||||
```python
|
||||
@@ -69,6 +102,8 @@ This guide will show how to create tables, insert data into them, and update the
|
||||
!!! info "Note"
|
||||
Data is converted to Arrow before being written to disk. For maximum control over how data is saved, either provide the PyArrow schema to convert to or else provide a PyArrow Table directly.
|
||||
|
||||
The **`vector`** column needs to be a [Vector](../python/pydantic.md#vector-field) (defined as [pyarrow.FixedSizeList](https://arrow.apache.org/docs/python/generated/pyarrow.list_.html)) type.
|
||||
|
||||
```python
|
||||
custom_schema = pa.schema([
|
||||
pa.field("vector", pa.list_(pa.float32(), 4)),
|
||||
@@ -79,7 +114,7 @@ This guide will show how to create tables, insert data into them, and update the
|
||||
table = db.create_table("my_table", data, schema=custom_schema)
|
||||
```
|
||||
|
||||
### From a Polars DataFrame
|
||||
### From a Polars DataFrame
|
||||
|
||||
LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library
|
||||
written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow
|
||||
@@ -97,26 +132,44 @@ This guide will show how to create tables, insert data into them, and update the
|
||||
table = db.create_table("pl_table", data=data)
|
||||
```
|
||||
|
||||
### From PyArrow Tables
|
||||
You can also create LanceDB tables directly from PyArrow tables
|
||||
### From an Arrow Table
|
||||
=== "Python"
|
||||
You can also create LanceDB tables directly from Arrow tables.
|
||||
LanceDB supports float16 data type!
|
||||
|
||||
```python
|
||||
table = pa.Table.from_arrays(
|
||||
import pyarrows as pa
|
||||
import numpy as np
|
||||
|
||||
dim = 16
|
||||
total = 2
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.array([[3.1, 4.1, 5.1, 6.1], [5.9, 26.5, 4.7, 32.8]],
|
||||
pa.list_(pa.float32(), 4)),
|
||||
pa.array(["foo", "bar"]),
|
||||
pa.array([10.0, 20.0]),
|
||||
],
|
||||
["vector", "item", "price"],
|
||||
pa.field("vector", pa.list_(pa.float16(), dim)),
|
||||
pa.field("text", pa.string())
|
||||
]
|
||||
)
|
||||
data = pa.Table.from_arrays(
|
||||
[
|
||||
pa.array([np.random.randn(dim).astype(np.float16) for _ in range(total)],
|
||||
pa.list_(pa.float16(), dim)),
|
||||
pa.array(["foo", "bar"])
|
||||
],
|
||||
["vector", "text"],
|
||||
)
|
||||
tbl = db.create_table("f16_tbl", data, schema=schema)
|
||||
```
|
||||
|
||||
db = lancedb.connect("db")
|
||||
=== "Javascript"
|
||||
You can also create LanceDB tables directly from Arrow tables.
|
||||
LanceDB supports Float16 data type!
|
||||
|
||||
tbl = db.create_table("my_table", table)
|
||||
```javascript
|
||||
--8<-- "docs/src/basic_legacy.ts:create_f16_table"
|
||||
```
|
||||
|
||||
### From Pydantic Models
|
||||
|
||||
When you create an empty table without data, you must specify the table schema.
|
||||
LanceDB supports creating tables by specifying a PyArrow schema or a specialized
|
||||
Pydantic model called `LanceModel`.
|
||||
@@ -261,37 +314,6 @@ This guide will show how to create tables, insert data into them, and update the
|
||||
|
||||
You can also use iterators of other types like Pandas DataFrame or Pylists directly in the above example.
|
||||
|
||||
=== "JavaScript"
|
||||
Initialize a VectorDB connection and create a table using one of the many methods listed below.
|
||||
|
||||
```javascript
|
||||
const lancedb = require("vectordb");
|
||||
|
||||
const uri = "data/sample-lancedb";
|
||||
const db = await lancedb.connect(uri);
|
||||
```
|
||||
|
||||
You can create a LanceDB table in JavaScript using an array of JSON records as follows.
|
||||
|
||||
```javascript
|
||||
const tb = await db.createTable("my_table", [{
|
||||
"vector": [3.1, 4.1],
|
||||
"item": "foo",
|
||||
"price": 10.0
|
||||
}, {
|
||||
"vector": [5.9, 26.5],
|
||||
"item": "bar",
|
||||
"price": 20.0
|
||||
}]);
|
||||
```
|
||||
|
||||
!!! info "Note"
|
||||
If the table already exists, LanceDB will raise an error by default. If you want to overwrite the table, you need to specify the `WriteMode` in the createTable function.
|
||||
|
||||
```javascript
|
||||
const table = await con.createTable(tableName, data, { writeMode: WriteMode.Overwrite })
|
||||
```
|
||||
|
||||
## Open existing tables
|
||||
|
||||
=== "Python"
|
||||
|
||||
49
docs/src/hybrid_search/eval.md
Normal file
49
docs/src/hybrid_search/eval.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Hybrid Search
|
||||
|
||||
Hybrid Search is a broad (often misused) term. It can mean anything from combining multiple methods for searching, to applying ranking methods to better sort the results. In this blog, we use the definition of "hybrid search" to mean using a combination of keyword-based and vector search.
|
||||
|
||||
## The challenge of (re)ranking search results
|
||||
Once you have a group of the most relevant search results from multiple search sources, you'd likely standardize the score and rank them accordingly. This process can also be seen as another independent step - reranking.
|
||||
There are two approaches for reranking search results from multiple sources.
|
||||
* <b>Score-based</b>: Calculate final relevance scores based on a weighted linear combination of individual search algorithm scores. Example - Weighted linear combination of semantic search & keyword-based search results.
|
||||
* <b>Relevance-based</b>: Discards the existing scores and calculates the relevance of each search result - query pair. Example - Cross Encoder models
|
||||
|
||||
Even though there are many strategies for reranking search results, none works for all cases. Moreover, evaluating them itself is a challenge. Also, reranking can be dataset, application specific so it's hard to generalize.
|
||||
|
||||
### Example evaluation of hybrid search with Reranking
|
||||
|
||||
Here's some evaluation numbers from experiment comparing these re-rankers on about 800 queries. It is modified version of an evaluation script from [llama-index](https://github.com/run-llama/finetune-embedding/blob/main/evaluate.ipynb) that measures hit-rate at top-k.
|
||||
|
||||
<b> With OpenAI ada2 embedding </b>
|
||||
|
||||
Vector Search baseline - `0.64`
|
||||
|
||||
| Reranker | Top-3 | Top-5 | Top-10 |
|
||||
| --- | --- | --- | --- |
|
||||
| Linear Combination | `0.73` | `0.74` | `0.85` |
|
||||
| Cross Encoder | `0.71` | `0.70` | `0.77` |
|
||||
| Cohere | `0.81` | `0.81` | `0.85` |
|
||||
| ColBERT | `0.68` | `0.68` | `0.73` |
|
||||
|
||||
<p>
|
||||
<img src="https://github.com/AyushExel/assets/assets/15766192/d57b1780-ef27-414c-a5c3-73bee7808a45">
|
||||
</p>
|
||||
|
||||
<b> With OpenAI embedding-v3-small </b>
|
||||
|
||||
Vector Search baseline - `0.59`
|
||||
|
||||
| Reranker | Top-3 | Top-5 | Top-10 |
|
||||
| --- | --- | --- | --- |
|
||||
| Linear Combination | `0.68` | `0.70` | `0.84` |
|
||||
| Cross Encoder | `0.72` | `0.72` | `0.79` |
|
||||
| Cohere | `0.79` | `0.79` | `0.84` |
|
||||
| ColBERT | `0.70` | `0.70` | `0.76` |
|
||||
|
||||
<p>
|
||||
<img src="https://github.com/AyushExel/assets/assets/15766192/259adfd2-6ec6-4df6-a77d-1456598970dd">
|
||||
</p>
|
||||
|
||||
### Conclusion
|
||||
|
||||
The results show that the reranking methods are able to improve the search results. However, the improvement is not consistent across all rerankers. The choice of reranker depends on the dataset and the application. It is also important to note that the reranking methods are not a replacement for the search methods. They are complementary and should be used together to get the best results. The speed to recall tradeoff is also an important factor to consider when choosing the reranker.
|
||||
@@ -1,22 +1,29 @@
|
||||
# Hybrid Search
|
||||
|
||||
LanceDB supports both semantic and keyword-based search. In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques.
|
||||
LanceDB supports both semantic and keyword-based search (also termed full-text search, or FTS). In real world applications, it is often useful to combine these two approaches to get the best best results. For example, you may want to search for a document that is semantically similar to a query document, but also contains a specific keyword. This is an example of *hybrid search*, a search algorithm that combines multiple search techniques.
|
||||
|
||||
## Hybrid search in LanceDB
|
||||
You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box. However, you can always write a custom reranker if your use case need more sophisticated logic .
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import openai
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydanatic import LanceModel, Vector
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
|
||||
# Ingest embedding function in LanceDB table
|
||||
# Configuring the environment variable OPENAI_API_KEY
|
||||
if "OPENAI_API_KEY" not in os.environ:
|
||||
# OR set the key here as a variable
|
||||
openai.api_key = "sk-..."
|
||||
embeddings = get_registry().get("openai").create()
|
||||
|
||||
class Documents(LanceModel):
|
||||
vector: Vector(embeddings.ndims) = embeddings.VectorField()
|
||||
vector: Vector(embeddings.ndims()) = embeddings.VectorField()
|
||||
text: str = embeddings.SourceField()
|
||||
|
||||
table = db.create_table("documents", schema=Documents)
|
||||
@@ -31,17 +38,19 @@ data = [
|
||||
# ingest docs with auto-vectorization
|
||||
table.add(data)
|
||||
|
||||
# Create a fts index before the hybrid search
|
||||
table.create_fts_index("text")
|
||||
# hybrid search with default re-ranker
|
||||
results = table.search("flower moon", query_type="hybrid").to_pandas()
|
||||
```
|
||||
|
||||
By default, LanceDB uses `LinearCombinationReranker(weights=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
|
||||
By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and rerank the results of semantic and full-text search. You can customize the hyperparameters as needed or write your own custom reranker. Here's how you can use any of the available rerankers:
|
||||
|
||||
|
||||
### `rerank()` arguments
|
||||
* `normalize`: `str`, default `"score"`:
|
||||
The method to normalize the scores. Can be "rank" or "score". If "rank", the scores are converted to ranks and then normalized. If "score", the scores are normalized directly.
|
||||
* `reranker`: `Reranker`, default `LinearCombinationReranker(weights=0.7)`.
|
||||
* `reranker`: `Reranker`, default `LinearCombinationReranker(weight=0.7)`.
|
||||
The reranker to use. If not specified, the default reranker is used.
|
||||
|
||||
|
||||
@@ -55,12 +64,12 @@ This is the default re-ranker used by LanceDB. It combines the results of semant
|
||||
```python
|
||||
from lancedb.rerankers import LinearCombinationReranker
|
||||
|
||||
reranker = LinearCombinationReranker(weights=0.3) # Use 0.3 as the weight for vector search
|
||||
reranker = LinearCombinationReranker(weight=0.3) # Use 0.3 as the weight for vector search
|
||||
|
||||
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||
```
|
||||
|
||||
Arguments
|
||||
### Arguments
|
||||
----------------
|
||||
* `weight`: `float`, default `0.7`:
|
||||
The weight to use for the semantic search score. The weight for the full-text search score is `1 - weights`.
|
||||
@@ -82,9 +91,9 @@ reranker = CohereReranker()
|
||||
results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||
```
|
||||
|
||||
Arguments
|
||||
### Arguments
|
||||
----------------
|
||||
* `model_name`` : str, default `"rerank-english-v2.0"``
|
||||
* `model_name` : str, default `"rerank-english-v2.0"`
|
||||
The name of the cross encoder model to use. Available cohere models are:
|
||||
- rerank-english-v2.0
|
||||
- rerank-multilingual-v2.0
|
||||
@@ -108,7 +117,7 @@ results = table.search("harmony hall", query_type="hybrid").rerank(reranker=rera
|
||||
```
|
||||
|
||||
|
||||
Arguments
|
||||
### Arguments
|
||||
----------------
|
||||
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
|
||||
The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
|
||||
@@ -121,6 +130,61 @@ Arguments
|
||||
Only returns `_relevance_score`. Does not support `return_score = "all"`.
|
||||
|
||||
|
||||
### ColBERT Reranker
|
||||
This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertrReranker()` to the `rerank()` method.
|
||||
|
||||
ColBERT reranker model calculates relevance of given docs against the query and don't take existing fts and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for `text` column to rerank the results. But you can specify the column name to use as input to the cross encoder model as described below.
|
||||
|
||||
```python
|
||||
from lancedb.rerankers import ColbertReranker
|
||||
|
||||
reranker = ColbertReranker()
|
||||
|
||||
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||
```
|
||||
|
||||
### Arguments
|
||||
----------------
|
||||
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
|
||||
The name of the cross encoder model to use.
|
||||
* `column` : `str`, default `"text"`
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
* `return_score` : `str`, default `"relevance"`
|
||||
options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.
|
||||
|
||||
!!! Note
|
||||
Only returns `_relevance_score`. Does not support `return_score = "all"`.
|
||||
|
||||
### OpenAI Reranker
|
||||
This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.
|
||||
|
||||
!!! Note
|
||||
This prompts chat model to rerank results which is not a dedicated reranker model. This should be treated as experimental.
|
||||
|
||||
!!! Tip
|
||||
- You might run out of token limit so set the search `limits` based on your token limit.
|
||||
- It is recommended to use gpt-4-turbo-preview, the default model, older models might lead to undesired behaviour
|
||||
|
||||
```python
|
||||
from lancedb.rerankers import OpenaiReranker
|
||||
|
||||
reranker = OpenaiReranker()
|
||||
|
||||
results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
|
||||
```
|
||||
|
||||
### Arguments
|
||||
----------------
|
||||
* `model_name` : `str`, default `"gpt-4-turbo-preview"`
|
||||
The name of the cross encoder model to use.
|
||||
* `column` : `str`, default `"text"`
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
* `return_score` : `str`, default `"relevance"`
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
* `api_key` : `str`, default `None`
|
||||
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||
|
||||
|
||||
## Building Custom Rerankers
|
||||
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.
|
||||
|
||||
@@ -137,7 +201,7 @@ class MyReranker(Reranker):
|
||||
self.param1 = param1
|
||||
self.param2 = param2
|
||||
|
||||
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table):
|
||||
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
|
||||
# Use the built-in merging function
|
||||
combined_result = self.merge_results(vector_results, fts_results)
|
||||
|
||||
@@ -149,24 +213,30 @@ class MyReranker(Reranker):
|
||||
|
||||
```
|
||||
|
||||
You can also accept additional arguments like a filter along with fts and vector search results
|
||||
### Example of a Custom Reranker
|
||||
For the sake of simplicity let's build custom reranker that just enchances the Cohere Reranker by accepting a filter query, and accept other CohereReranker params as kwags.
|
||||
|
||||
```python
|
||||
|
||||
from lancedb.rerankers import Reranker
|
||||
import pyarrow as pa
|
||||
from typing import List, Union
|
||||
import pandas as pd
|
||||
from lancedb.rerankers import CohereReranker
|
||||
|
||||
class MyReranker(Reranker):
|
||||
...
|
||||
class MofidifiedCohereReranker(CohereReranker):
|
||||
def __init__(self, filters: Union[str, List[str]], **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
filters = filters if isinstance(filters, list) else [filters]
|
||||
self.filters = filters
|
||||
|
||||
def rerank_hybrid(self, vector_results: pa.Table, fts_results: pa.Table, filter: str):
|
||||
# Use the built-in merging function
|
||||
combined_result = self.merge_results(vector_results, fts_results)
|
||||
def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table)-> pa.Table:
|
||||
combined_result = super().rerank_hybrid(query, vector_results, fts_results)
|
||||
df = combined_result.to_pandas()
|
||||
for filter in self.filters:
|
||||
df = df.query("not text.str.contains(@filter)")
|
||||
|
||||
# Do something with the combined results & filter
|
||||
# ...
|
||||
|
||||
# Return the combined results
|
||||
return combined_result
|
||||
return pa.Table.from_pandas(df)
|
||||
|
||||
```
|
||||
|
||||
!!! tip
|
||||
The `vector_results` and `fts_results` are pyarrow tables. You can convert them to pandas dataframes using `to_pandas()` method and perform any operations you want. After you are done, you can convert the dataframe back to pyarrow table using `pa.Table.from_pandas()` method and return it.
|
||||
@@ -290,7 +290,7 @@
|
||||
"from lancedb.pydantic import LanceModel, Vector\n",
|
||||
"\n",
|
||||
"class Pets(LanceModel):\n",
|
||||
" vector: Vector(clip.ndims) = clip.VectorField()\n",
|
||||
" vector: Vector(clip.ndims()) = clip.VectorField()\n",
|
||||
" image_uri: str = clip.SourceField()\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
@@ -360,7 +360,7 @@
|
||||
" table = db.create_table(\"pets\", schema=Pets)\n",
|
||||
" # use a sampling of 1000 images\n",
|
||||
" p = Path(\"~/Downloads/images\").expanduser()\n",
|
||||
" uris = [str(f) for f in p.iterdir()]\n",
|
||||
" uris = [str(f) for f in p.glob(\"*.jpg\")]\n",
|
||||
" uris = sample(uris, 1000)\n",
|
||||
" table.add(pd.DataFrame({\"image_uri\": uris}))"
|
||||
]
|
||||
@@ -543,7 +543,7 @@
|
||||
],
|
||||
"source": [
|
||||
"from PIL import Image\n",
|
||||
"p = Path(\"/Users/changshe/Downloads/images/samoyed_100.jpg\")\n",
|
||||
"p = Path(\"~/Downloads/images/samoyed_100.jpg\").expanduser()\n",
|
||||
"query_image = Image.open(p)\n",
|
||||
"query_image"
|
||||
]
|
||||
|
||||
@@ -23,10 +23,8 @@ from multiprocessing import Pool
|
||||
import lance
|
||||
import pyarrow as pa
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
|
||||
|
||||
import lancedb
|
||||
|
||||
MODEL_ID = "openai/clip-vit-base-patch32"
|
||||
|
||||
|
||||
1122
docs/src/notebooks/hybrid_search.ipynb
Normal file
1122
docs/src/notebooks/hybrid_search.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"execution_count": 2,
|
||||
"id": "c1b4e34b-a49c-471d-a343-a5940bb5138a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -23,7 +23,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 3,
|
||||
"id": "4e5a8d07-d9a1-48c1-913a-8e0629289579",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -44,7 +44,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 4,
|
||||
"id": "5df12f66-8d99-43ad-8d0b-22189ec0a6b9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -62,7 +62,7 @@
|
||||
"long: [[-122.7,-74.1]]"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -90,7 +90,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 5,
|
||||
"id": "f4d87ae9-0ccb-48eb-b31d-bb8f2370e47e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -108,7 +108,7 @@
|
||||
"long: [[-122.7,-74.1]]"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -135,10 +135,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 6,
|
||||
"id": "25f34bcf-fca0-4431-8601-eac95d1bd347",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[2024-01-31T18:59:33Z WARN lance::dataset] No existing dataset at /Users/qian/Work/LanceDB/lancedb/docs/src/notebooks/.lancedb/table3.lance, it will be created\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
@@ -148,7 +155,7 @@
|
||||
"long: float"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -171,45 +178,51 @@
|
||||
"id": "4df51925-7ca2-4005-9c72-38b3d26240c6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### From PyArrow Tables\n",
|
||||
"### From an Arrow Table\n",
|
||||
"\n",
|
||||
"You can also create LanceDB tables directly from pyarrow tables"
|
||||
"You can also create LanceDB tables directly from pyarrow tables. LanceDB supports float16 type."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 7,
|
||||
"id": "90a880f6-be43-4c9d-ba65-0b05197c0f6f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"vector: fixed_size_list<item: float>[2]\n",
|
||||
" child 0, item: float\n",
|
||||
"item: string\n",
|
||||
"price: double"
|
||||
"vector: fixed_size_list<item: halffloat>[16]\n",
|
||||
" child 0, item: halffloat\n",
|
||||
"text: string"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"table = pa.Table.from_arrays(\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"dim = 16\n",
|
||||
"total = 2\n",
|
||||
"schema = pa.schema(\n",
|
||||
" [\n",
|
||||
" pa.array([[3.1, 4.1], [5.9, 26.5]],\n",
|
||||
" pa.list_(pa.float32(), 2)),\n",
|
||||
" pa.array([\"foo\", \"bar\"]),\n",
|
||||
" pa.array([10.0, 20.0]),\n",
|
||||
" pa.field(\"vector\", pa.list_(pa.float16(), dim)),\n",
|
||||
" pa.field(\"text\", pa.string())\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"data = pa.Table.from_arrays(\n",
|
||||
" [\n",
|
||||
" pa.array([np.random.randn(dim).astype(np.float16) for _ in range(total)],\n",
|
||||
" pa.list_(pa.float16(), dim)),\n",
|
||||
" pa.array([\"foo\", \"bar\"])\n",
|
||||
" ],\n",
|
||||
" [\"vector\", \"item\", \"price\"],\n",
|
||||
" )\n",
|
||||
" [\"vector\", \"text\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"db = lancedb.connect(\"db\")\n",
|
||||
"\n",
|
||||
"tbl = db.create_table(\"test1\", table, mode=\"overwrite\")\n",
|
||||
"tbl = db.create_table(\"f16_tbl\", data, schema=schema)\n",
|
||||
"tbl.schema"
|
||||
]
|
||||
},
|
||||
@@ -225,7 +238,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 8,
|
||||
"id": "d81121d7-e4b7-447c-a48c-974b6ebb464a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -240,7 +253,7 @@
|
||||
"imdb_id: int64 not null"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -282,7 +295,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 9,
|
||||
"id": "bc247142-4e3c-41a2-b94c-8e00d2c2a508",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -292,7 +305,7 @@
|
||||
"LanceTable(table4)"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -333,7 +346,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 10,
|
||||
"id": "25ad3523-e0c9-4c28-b3df-38189c4e0e5f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -346,7 +359,7 @@
|
||||
"price: double not null"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -385,7 +398,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 11,
|
||||
"id": "2814173a-eacc-4dd8-a64d-6312b44582cc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -411,7 +424,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 12,
|
||||
"id": "df9e13c0-41f6-437f-9dfa-2fd71d3d9c45",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -421,7 +434,7 @@
|
||||
"['table6', 'table4', 'table5', 'movielens_small']"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -432,7 +445,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 13,
|
||||
"id": "9343f5ad-6024-42ee-ac2f-6c1471df8679",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -541,7 +554,7 @@
|
||||
"9 [5.9, 26.5] bar 20.0"
|
||||
]
|
||||
},
|
||||
"execution_count": 20,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -564,7 +577,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 14,
|
||||
"id": "8a56250f-73a1-4c26-a6ad-5c7a0ce3a9ab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -590,7 +603,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 15,
|
||||
"id": "030c7057-b98e-4e2f-be14-b8c1f927f83c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -621,7 +634,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 16,
|
||||
"id": "e7a17de2-08d2-41b7-bd05-f63d1045ab1f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -629,16 +642,16 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"32\n"
|
||||
"22\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"17"
|
||||
"12"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -661,7 +674,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 17,
|
||||
"id": "fe3310bd-08f4-4a22-a63b-b3127d22f9f7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -681,25 +694,20 @@
|
||||
"8 [3.1, 4.1] foo 10.0\n",
|
||||
"9 [3.1, 4.1] foo 10.0\n",
|
||||
"10 [3.1, 4.1] foo 10.0\n",
|
||||
"11 [3.1, 4.1] foo 10.0\n",
|
||||
"12 [3.1, 4.1] foo 10.0\n",
|
||||
"13 [3.1, 4.1] foo 10.0\n",
|
||||
"14 [3.1, 4.1] foo 10.0\n",
|
||||
"15 [3.1, 4.1] foo 10.0\n",
|
||||
"16 [3.1, 4.1] foo 10.0\n"
|
||||
"11 [3.1, 4.1] foo 10.0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "OSError",
|
||||
"evalue": "LanceError(IO): Error during planning: column foo does not exist",
|
||||
"evalue": "LanceError(IO): Error during planning: column foo does not exist, /Users/runner/work/lance/lance/rust/lance-core/src/error.rs:212:23",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[30], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m to_remove \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(v) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m to_remove)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(tbl\u001b[38;5;241m.\u001b[39mto_pandas())\n\u001b[0;32m----> 4\u001b[0m \u001b[43mtbl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mitem IN (\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mto_remove\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m tbl\u001b[38;5;241m.\u001b[39mto_pandas()\n",
|
||||
"File \u001b[0;32m~/Documents/lancedb/lancedb/python/lancedb/table.py:610\u001b[0m, in \u001b[0;36mLanceTable.delete\u001b[0;34m(self, where)\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdelete\u001b[39m(\u001b[38;5;28mself\u001b[39m, where: \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 610\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwhere\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Documents/lancedb/lancedb/env/lib/python3.11/site-packages/lance/dataset.py:489\u001b[0m, in \u001b[0;36mLanceDataset.delete\u001b[0;34m(self, predicate)\u001b[0m\n\u001b[1;32m 487\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(predicate, pa\u001b[38;5;241m.\u001b[39mcompute\u001b[38;5;241m.\u001b[39mExpression):\n\u001b[1;32m 488\u001b[0m predicate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(predicate)\n\u001b[0;32m--> 489\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_ds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredicate\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"\u001b[0;31mOSError\u001b[0m: LanceError(IO): Error during planning: column foo does not exist"
|
||||
"Cell \u001b[0;32mIn[17], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m to_remove \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mstr\u001b[39m(v) \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m to_remove)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(tbl\u001b[38;5;241m.\u001b[39mto_pandas())\n\u001b[0;32m----> 4\u001b[0m \u001b[43mtbl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mitem IN (\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mto_remove\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m)\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Work/LanceDB/lancedb/docs/doc-venv/lib/python3.11/site-packages/lancedb/table.py:872\u001b[0m, in \u001b[0;36mLanceTable.delete\u001b[0;34m(self, where)\u001b[0m\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdelete\u001b[39m(\u001b[38;5;28mself\u001b[39m, where: \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 872\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwhere\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Work/LanceDB/lancedb/docs/doc-venv/lib/python3.11/site-packages/lance/dataset.py:596\u001b[0m, in \u001b[0;36mLanceDataset.delete\u001b[0;34m(self, predicate)\u001b[0m\n\u001b[1;32m 594\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(predicate, pa\u001b[38;5;241m.\u001b[39mcompute\u001b[38;5;241m.\u001b[39mExpression):\n\u001b[1;32m 595\u001b[0m predicate \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mstr\u001b[39m(predicate)\n\u001b[0;32m--> 596\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_ds\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelete\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpredicate\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"\u001b[0;31mOSError\u001b[0m: LanceError(IO): Error during planning: column foo does not exist, /Users/runner/work/lance/lance/rust/lance-core/src/error.rs:212:23"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -712,7 +720,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": null,
|
||||
"id": "87d5bc21-847f-4c81-b56e-f6dbe5d05aac",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -729,7 +737,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": null,
|
||||
"id": "9cba4519-eb3a-4941-ab7e-873d762e750f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -742,7 +750,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": null,
|
||||
"id": "5bdc9801-d5ed-4871-92d0-88b27108e788",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -817,7 +825,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.11.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
# DuckDB
|
||||
|
||||
LanceDB is very well-integrated with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This integration is done via [Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow) .
|
||||
In Python, LanceDB tables can also be queried with [DuckDB](https://duckdb.org/), an in-process SQL OLAP database. This means you can write complex SQL queries to analyze your data in LanceDB.
|
||||
|
||||
This integration is done via [Apache Arrow](https://duckdb.org/docs/guides/python/sql_on_arrow), which provides zero-copy data sharing between LanceDB and DuckDB. DuckDB is capable of passing down column selections and basic filters to LanceDB, reducing the amount of data that needs to be scanned to perform your query. Finally, the integration allows streaming data from LanceDB tables, allowing you to aggregate tables that won't fit into memory. All of this uses the same mechanism described in DuckDB's blog post *[DuckDB quacks Arrow](https://duckdb.org/2021/12/03/duck-arrow.html)*.
|
||||
|
||||
|
||||
We can demonstrate this by first installing `duckdb` and `lancedb`.
|
||||
|
||||
@@ -19,14 +22,15 @@ data = [
|
||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
|
||||
]
|
||||
table = db.create_table("pd_table", data=data)
|
||||
arrow_table = table.to_arrow()
|
||||
```
|
||||
|
||||
DuckDB can directly query the `pyarrow.Table` object:
|
||||
To query the table, first call `to_lance` to convert the table to a "dataset", which is an object that can be queried by DuckDB. Then all you need to do is reference that dataset by the same name in your SQL query.
|
||||
|
||||
```python
|
||||
import duckdb
|
||||
|
||||
arrow_table = table.to_lance()
|
||||
|
||||
duckdb.query("SELECT * FROM arrow_table")
|
||||
```
|
||||
|
||||
|
||||
@@ -58,6 +58,8 @@ pip install lancedb
|
||||
|
||||
::: lancedb.schema.vector
|
||||
|
||||
::: lancedb.merge.LanceMergeInsertBuilder
|
||||
|
||||
## Integrations
|
||||
|
||||
### Pydantic
|
||||
|
||||
@@ -14,7 +14,7 @@ excluded_globs = [
|
||||
"../src/concepts/*.md",
|
||||
"../src/ann_indexes.md",
|
||||
"../src/basic.md",
|
||||
"../src/hybrid_search.md",
|
||||
"../src/hybrid_search/hybrid_search.md",
|
||||
]
|
||||
|
||||
python_prefix = "py"
|
||||
|
||||
44
node/package-lock.json
generated
44
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.4.6",
|
||||
"version": "0.4.10",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.4.6",
|
||||
"version": "0.4.10",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -53,11 +53,11 @@
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.4.6",
|
||||
"@lancedb/vectordb-darwin-x64": "0.4.6",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.4.6",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.4.6",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.4.6"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.4.10",
|
||||
"@lancedb/vectordb-darwin-x64": "0.4.10",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.4.10",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.4.10",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.4.10"
|
||||
}
|
||||
},
|
||||
"node_modules/@75lb/deep-merge": {
|
||||
@@ -329,9 +329,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.4.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.6.tgz",
|
||||
"integrity": "sha512-p6w/BXBxgFHR87phxvfBPPbvz4wDGmG2guRSQPEriwrc8h/gQ3wuexHhyzi7SWcV2E25vyUO9QcFL3vYKhIJRg==",
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.4.10.tgz",
|
||||
"integrity": "sha512-y/uHOGb0g15pvqv5tdTyZ6oN+0QVpBmZDzKFWW6pPbuSZjB2uPqcs+ti0RB+AUdmS21kavVQqaNsw/HLKEGrHA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -341,9 +341,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.4.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.6.tgz",
|
||||
"integrity": "sha512-7Fmg63Ky783ROpaQEL6I1uTrO//YDi4MgG0pjWAkDKsdHQ8QisFF8kd+JvjPh4PhMScC/rtB0SXpY/Y4zZvLfw==",
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.4.10.tgz",
|
||||
"integrity": "sha512-XbfR58OkQpAe0xMSTrwJh9ZjGSzG9EZ7zwO6HfYem8PxcLYAcC6eWRWoSG/T0uObyrPTcYYyvHsp0eNQWYBFAQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -353,9 +353,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.4.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.6.tgz",
|
||||
"integrity": "sha512-2wM+BKnjtZyKhiQPvldpfORH2JdKy6AuLFJ7AQtuyly57mkvgZRJeqK0DsRi/hyyZPRUOvWaDp/LfAxZvhLWgA==",
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.4.10.tgz",
|
||||
"integrity": "sha512-x40WKH9b+KxorRmKr9G7fv8p5mMj8QJQvRMA0v6v+nbZHr2FLlAZV+9mvhHOnm4AGIkPP5335cUgv6Qz6hgwkQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -365,9 +365,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.4.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.6.tgz",
|
||||
"integrity": "sha512-1BK9i3DnnFHyBVLxOfsIW2i800o9exDEHm5onikvfoa5Ot5tXwIwAw86+0HGsBm5YbJnKKxZmbAM6Pr9qfMKiQ==",
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.4.10.tgz",
|
||||
"integrity": "sha512-CTGPpuzlqq2nVjUxI9gAJOT1oBANIovtIaFsOmBSnEAHgX7oeAxKy2b6L/kJzsgqSzvR5vfLwYcWFrr6ZmBxSA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -377,9 +377,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.4.6",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.6.tgz",
|
||||
"integrity": "sha512-Fh/fw+HRf/LDZKCDQTvpWoacFfmLXGwQpcqxxlwIZ0vy45eCNYvnZrpjQBjej0uh3tEVC6OHh6Jhn7Pr9k8r2w==",
|
||||
"version": "0.4.10",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.4.10.tgz",
|
||||
"integrity": "sha512-Fd7r74coZyrKzkfXg4WthqOL+uKyJyPTia6imcrMNqKOlTGdKmHf02Qi2QxWZrFaabkRYo4Tpn5FeRJ3yYX8CA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.4.7",
|
||||
"version": "0.4.10",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -17,7 +17,11 @@
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "https://github.com/lancedb/lancedb/node"
|
||||
"url": "https://github.com/lancedb/lancedb.git"
|
||||
},
|
||||
"homepage": "https://lancedb.github.io/lancedb/",
|
||||
"bugs": {
|
||||
"url": "https://github.com/lancedb/lancedb/issues"
|
||||
},
|
||||
"keywords": [
|
||||
"data-format",
|
||||
@@ -81,10 +85,10 @@
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.4.7",
|
||||
"@lancedb/vectordb-darwin-x64": "0.4.7",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.4.7",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.4.7",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.4.7"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.4.10",
|
||||
"@lancedb/vectordb-darwin-x64": "0.4.10",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.4.10",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.4.10",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.4.10"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
|
||||
import {
|
||||
Field,
|
||||
type FixedSizeListBuilder,
|
||||
Float32,
|
||||
makeBuilder,
|
||||
RecordBatchFileWriter,
|
||||
Utf8,
|
||||
@@ -26,14 +24,19 @@ import {
|
||||
Table as ArrowTable,
|
||||
RecordBatchStreamWriter,
|
||||
List,
|
||||
Float64,
|
||||
RecordBatch,
|
||||
makeData,
|
||||
Struct,
|
||||
type Float
|
||||
type Float,
|
||||
DataType,
|
||||
Binary,
|
||||
Float32
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from './index'
|
||||
|
||||
/*
|
||||
* Options to control how a column should be converted to a vector array
|
||||
*/
|
||||
export class VectorColumnOptions {
|
||||
/** Vector column type. */
|
||||
type: Float = new Float32()
|
||||
@@ -45,14 +48,50 @@ export class VectorColumnOptions {
|
||||
|
||||
/** Options to control the makeArrowTable call. */
|
||||
export class MakeArrowTableOptions {
|
||||
/** Provided schema. */
|
||||
/*
|
||||
* Schema of the data.
|
||||
*
|
||||
* If this is not provided then the data type will be inferred from the
|
||||
* JS type. Integer numbers will become int64, floating point numbers
|
||||
* will become float64 and arrays will become variable sized lists with
|
||||
* the data type inferred from the first element in the array.
|
||||
*
|
||||
* The schema must be specified if there are no records (e.g. to make
|
||||
* an empty table)
|
||||
*/
|
||||
schema?: Schema
|
||||
|
||||
/** Vector columns */
|
||||
/*
|
||||
* Mapping from vector column name to expected type
|
||||
*
|
||||
* Lance expects vector columns to be fixed size list arrays (i.e. tensors)
|
||||
* However, `makeArrowTable` will not infer this by default (it creates
|
||||
* variable size list arrays). This field can be used to indicate that a column
|
||||
* should be treated as a vector column and converted to a fixed size list.
|
||||
*
|
||||
* The keys should be the names of the vector columns. The value specifies the
|
||||
* expected data type of the vector columns.
|
||||
*
|
||||
* If `schema` is provided then this field is ignored.
|
||||
*
|
||||
* By default, the column named "vector" will be assumed to be a float32
|
||||
* vector column.
|
||||
*/
|
||||
vectorColumns: Record<string, VectorColumnOptions> = {
|
||||
vector: new VectorColumnOptions()
|
||||
}
|
||||
|
||||
/**
|
||||
* If true then string columns will be encoded with dictionary encoding
|
||||
*
|
||||
* Set this to true if your string columns tend to repeat the same values
|
||||
* often. For more precise control use the `schema` property to specify the
|
||||
* data type for individual columns.
|
||||
*
|
||||
* If `schema` is provided then this property is ignored.
|
||||
*/
|
||||
dictionaryEncodeStrings: boolean = false
|
||||
|
||||
constructor (values?: Partial<MakeArrowTableOptions>) {
|
||||
Object.assign(this, values)
|
||||
}
|
||||
@@ -62,8 +101,29 @@ export class MakeArrowTableOptions {
|
||||
* An enhanced version of the {@link makeTable} function from Apache Arrow
|
||||
* that supports nested fields and embeddings columns.
|
||||
*
|
||||
* This function converts an array of Record<String, any> (row-major JS objects)
|
||||
* to an Arrow Table (a columnar structure)
|
||||
*
|
||||
* Note that it currently does not support nulls.
|
||||
*
|
||||
* If a schema is provided then it will be used to determine the resulting array
|
||||
* types. Fields will also be reordered to fit the order defined by the schema.
|
||||
*
|
||||
* If a schema is not provided then the types will be inferred and the field order
|
||||
* will be controlled by the order of properties in the first record.
|
||||
*
|
||||
* If the input is empty then a schema must be provided to create an empty table.
|
||||
*
|
||||
* When a schema is not specified then data types will be inferred. The inference
|
||||
* rules are as follows:
|
||||
*
|
||||
* - boolean => Bool
|
||||
* - number => Float64
|
||||
* - String => Utf8
|
||||
* - Buffer => Binary
|
||||
* - Record<String, any> => Struct
|
||||
* - Array<any> => List
|
||||
*
|
||||
* @param data input data
|
||||
* @param options options to control the makeArrowTable call.
|
||||
*
|
||||
@@ -86,8 +146,10 @@ export class MakeArrowTableOptions {
|
||||
* ], { schema });
|
||||
* ```
|
||||
*
|
||||
* It guesses the vector columns if the schema is not provided. For example,
|
||||
* by default it assumes that the column named `vector` is a vector column.
|
||||
* By default it assumes that the column named `vector` is a vector column
|
||||
* and it will be converted into a fixed size list array of type float32.
|
||||
* The `vectorColumns` option can be used to support other vector column
|
||||
* names and data types.
|
||||
*
|
||||
* ```ts
|
||||
*
|
||||
@@ -134,211 +196,304 @@ export function makeArrowTable (
|
||||
data: Array<Record<string, any>>,
|
||||
options?: Partial<MakeArrowTableOptions>
|
||||
): ArrowTable {
|
||||
if (data.length === 0) {
|
||||
throw new Error('At least one record needs to be provided')
|
||||
if (data.length === 0 && (options?.schema === undefined || options?.schema === null)) {
|
||||
throw new Error('At least one record or a schema needs to be provided')
|
||||
}
|
||||
|
||||
const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
|
||||
const columns: Record<string, Vector> = {}
|
||||
// TODO: sample dataset to find missing columns
|
||||
const columnNames = Object.keys(data[0])
|
||||
// Prefer the field ordering of the schema, if present
|
||||
const columnNames = ((options?.schema) != null) ? (options?.schema?.names as string[]) : Object.keys(data[0])
|
||||
for (const colName of columnNames) {
|
||||
const values = data.map((datum) => datum[colName])
|
||||
let vector: Vector
|
||||
if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) {
|
||||
// The field is present in the schema, but not in the data, skip it
|
||||
continue
|
||||
}
|
||||
// Extract a single column from the records (transpose from row-major to col-major)
|
||||
let values = data.map((datum) => datum[colName])
|
||||
|
||||
// By default (type === undefined) arrow will infer the type from the JS type
|
||||
let type
|
||||
if (opt.schema !== undefined) {
|
||||
// Explicit schema is provided, highest priority
|
||||
vector = vectorFromArray(
|
||||
values,
|
||||
opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
|
||||
)
|
||||
// If there is a schema provided, then use that for the type instead
|
||||
type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
|
||||
if (DataType.isInt(type) && type.bitWidth === 64) {
|
||||
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
|
||||
values = values.map((v) => {
|
||||
if (v === null) {
|
||||
return v
|
||||
}
|
||||
return BigInt(v)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Otherwise, check to see if this column is one of the vector columns
|
||||
// defined by opt.vectorColumns and, if so, use the fixed size list type
|
||||
const vectorColumnOptions = opt.vectorColumns[colName]
|
||||
if (vectorColumnOptions !== undefined) {
|
||||
const fslType = new FixedSizeList(
|
||||
values[0].length,
|
||||
new Field('item', vectorColumnOptions.type, false)
|
||||
)
|
||||
vector = vectorFromArray(values, fslType)
|
||||
} else {
|
||||
// Normal case
|
||||
vector = vectorFromArray(values)
|
||||
type = newVectorType(values[0].length, vectorColumnOptions.type)
|
||||
}
|
||||
}
|
||||
columns[colName] = vector
|
||||
}
|
||||
|
||||
try {
|
||||
// Convert an Array of JS values to an arrow vector
|
||||
columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings)
|
||||
} catch (error: unknown) {
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
throw Error(`Could not convert column "${colName}" to Arrow: ${error}`)
|
||||
}
|
||||
}
|
||||
|
||||
if (opt.schema != null) {
|
||||
// `new ArrowTable(columns)` infers a schema which may sometimes have
|
||||
// incorrect nullability (it assumes nullable=true if there are 0 rows)
|
||||
//
|
||||
// `new ArrowTable(schema, columns)` will also fail because it will create a
|
||||
// batch with an inferred schema and then complain that the batch schema
|
||||
// does not match the provided schema.
|
||||
//
|
||||
// To work around this we first create a table with the wrong schema and
|
||||
// then patch the schema of the batches so we can use
|
||||
// `new ArrowTable(schema, batches)` which does not do any schema inference
|
||||
const firstTable = new ArrowTable(columns)
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
const batchesFixed = firstTable.batches.map(batch => new RecordBatch(opt.schema!, batch.data))
|
||||
return new ArrowTable(opt.schema, batchesFixed)
|
||||
} else {
|
||||
return new ArrowTable(columns)
|
||||
}
|
||||
}
|
||||
|
||||
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
|
||||
/**
|
||||
* Create an empty Arrow table with the provided schema
|
||||
*/
|
||||
export function makeEmptyTable (schema: Schema): ArrowTable {
|
||||
return makeArrowTable([], { schema })
|
||||
}
|
||||
|
||||
// Helper function to convert Array<Array<any>> to a variable sized list array
|
||||
function makeListVector (lists: any[][]): Vector<any> {
|
||||
if (lists.length === 0 || lists[0].length === 0) {
|
||||
throw Error('Cannot infer list vector from empty array or empty list')
|
||||
}
|
||||
const sampleList = lists[0]
|
||||
let inferredType
|
||||
try {
|
||||
const sampleVector = makeVector(sampleList)
|
||||
inferredType = sampleVector.type
|
||||
} catch (error: unknown) {
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`)
|
||||
}
|
||||
|
||||
const listBuilder = makeBuilder({
|
||||
type: new List(new Field('item', inferredType, true))
|
||||
})
|
||||
for (const list of lists) {
|
||||
listBuilder.append(list)
|
||||
}
|
||||
return listBuilder.finish().toVector()
|
||||
}
|
||||
|
||||
// Helper function to convert an Array of JS values to an Arrow Vector
|
||||
function makeVector (values: any[], type?: DataType, stringAsDictionary?: boolean): Vector<any> {
|
||||
if (type !== undefined) {
|
||||
// No need for inference, let Arrow create it
|
||||
return vectorFromArray(values, type)
|
||||
}
|
||||
if (values.length === 0) {
|
||||
throw Error('makeVector requires at least one value or the type must be specfied')
|
||||
}
|
||||
const sampleValue = values.find(val => val !== null && val !== undefined)
|
||||
if (sampleValue === undefined) {
|
||||
throw Error('makeVector cannot infer the type if all values are null or undefined')
|
||||
}
|
||||
if (Array.isArray(sampleValue)) {
|
||||
// Default Arrow inference doesn't handle list types
|
||||
return makeListVector(values)
|
||||
} else if (Buffer.isBuffer(sampleValue)) {
|
||||
// Default Arrow inference doesn't handle Buffer
|
||||
return vectorFromArray(values, new Binary())
|
||||
} else if (!(stringAsDictionary ?? false) && (typeof sampleValue === 'string' || sampleValue instanceof String)) {
|
||||
// If the type is string then don't use Arrow's default inference unless dictionaries are requested
|
||||
// because it will always use dictionary encoding for strings
|
||||
return vectorFromArray(values, new Utf8())
|
||||
} else {
|
||||
// Convert a JS array of values to an arrow vector
|
||||
return vectorFromArray(values)
|
||||
}
|
||||
}
|
||||
|
||||
async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<ArrowTable> {
|
||||
if (embeddings == null) {
|
||||
return table
|
||||
}
|
||||
|
||||
// Convert from ArrowTable to Record<String, Vector>
|
||||
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
|
||||
const name = table.schema.fields[idx].name
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
const vec = table.getChildAt(idx)!
|
||||
return [name, vec]
|
||||
})
|
||||
const newColumns = Object.fromEntries(colEntries)
|
||||
|
||||
const sourceColumn = newColumns[embeddings.sourceColumn]
|
||||
const destColumn = embeddings.destColumn ?? 'vector'
|
||||
const innerDestType = embeddings.embeddingDataType ?? new Float32()
|
||||
if (sourceColumn === undefined) {
|
||||
throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`)
|
||||
}
|
||||
|
||||
if (table.numRows === 0) {
|
||||
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
|
||||
// We have an empty table and it already has the embedding column so no work needs to be done
|
||||
// Note: we don't return an error like we did below because this is a common occurrence. For example,
|
||||
// if we call convertToTable with 0 records and a schema that includes the embedding
|
||||
return table
|
||||
}
|
||||
if (embeddings.embeddingDimension !== undefined) {
|
||||
const destType = newVectorType(embeddings.embeddingDimension, innerDestType)
|
||||
newColumns[destColumn] = makeVector([], destType)
|
||||
} else if (schema != null) {
|
||||
const destField = schema.fields.find(f => f.name === destColumn)
|
||||
if (destField != null) {
|
||||
newColumns[destColumn] = makeVector([], destField.type)
|
||||
} else {
|
||||
throw new Error(`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`)
|
||||
}
|
||||
} else {
|
||||
throw new Error('Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`')
|
||||
}
|
||||
} else {
|
||||
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
|
||||
throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`)
|
||||
}
|
||||
if (table.batches.length > 1) {
|
||||
throw new Error('Internal error: `makeArrowTable` unexpectedly created a table with more than one batch')
|
||||
}
|
||||
const values = sourceColumn.toArray()
|
||||
const vectors = await embeddings.embed(values as T[])
|
||||
if (vectors.length !== values.length) {
|
||||
throw new Error('Embedding function did not return an embedding for each input element')
|
||||
}
|
||||
const destType = newVectorType(vectors[0].length, innerDestType)
|
||||
newColumns[destColumn] = makeVector(vectors, destType)
|
||||
}
|
||||
|
||||
const newTable = new ArrowTable(newColumns)
|
||||
if (schema != null) {
|
||||
if (schema.fields.find(f => f.name === destColumn) === undefined) {
|
||||
throw new Error(`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`)
|
||||
}
|
||||
return alignTable(newTable, schema)
|
||||
}
|
||||
return newTable
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert an Array of records into an Arrow Table, optionally applying an
|
||||
* embeddings function to it.
|
||||
*
|
||||
* This function calls `makeArrowTable` first to create the Arrow Table.
|
||||
* Any provided `makeTableOptions` (e.g. a schema) will be passed on to
|
||||
* that call.
|
||||
*
|
||||
* The embedding function will be passed a column of values (based on the
|
||||
* `sourceColumn` of the embedding function) and expects to receive back
|
||||
* number[][] which will be converted into a fixed size list column. By
|
||||
* default this will be a fixed size list of Float32 but that can be
|
||||
* customized by the `embeddingDataType` property of the embedding function.
|
||||
*
|
||||
* If a schema is provided in `makeTableOptions` then it should include the
|
||||
* embedding columns. If no schema is provded then embedding columns will
|
||||
* be placed at the end of the table, after all of the input columns.
|
||||
*/
|
||||
export async function convertToTable<T> (
|
||||
data: Array<Record<string, unknown>>,
|
||||
embeddings?: EmbeddingFunction<T>
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
makeTableOptions?: Partial<MakeArrowTableOptions>
|
||||
): Promise<ArrowTable> {
|
||||
if (data.length === 0) {
|
||||
throw new Error('At least one record needs to be provided')
|
||||
}
|
||||
|
||||
const columns = Object.keys(data[0])
|
||||
const records: Record<string, Vector> = {}
|
||||
|
||||
for (const columnsKey of columns) {
|
||||
if (columnsKey === 'vector') {
|
||||
const vectorSize = (data[0].vector as any[]).length
|
||||
const listBuilder = newVectorBuilder(vectorSize)
|
||||
for (const datum of data) {
|
||||
if ((datum[columnsKey] as any[]).length !== vectorSize) {
|
||||
throw new Error(`Invalid vector size, expected ${vectorSize}`)
|
||||
}
|
||||
|
||||
listBuilder.append(datum[columnsKey])
|
||||
}
|
||||
records[columnsKey] = listBuilder.finish().toVector()
|
||||
} else {
|
||||
const values = []
|
||||
for (const datum of data) {
|
||||
values.push(datum[columnsKey])
|
||||
}
|
||||
|
||||
if (columnsKey === embeddings?.sourceColumn) {
|
||||
const vectors = await embeddings.embed(values as T[])
|
||||
records.vector = vectorFromArray(
|
||||
vectors,
|
||||
newVectorType(vectors[0].length)
|
||||
)
|
||||
}
|
||||
|
||||
if (typeof values[0] === 'string') {
|
||||
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
|
||||
records[columnsKey] = vectorFromArray(values, new Utf8())
|
||||
} else if (Array.isArray(values[0])) {
|
||||
const elementType = getElementType(values[0])
|
||||
let innerType
|
||||
if (elementType === 'string') {
|
||||
innerType = new Utf8()
|
||||
} else if (elementType === 'number') {
|
||||
innerType = new Float64()
|
||||
} else {
|
||||
// TODO: pass in schema if it exists, else keep going to the next element
|
||||
throw new Error(`Unsupported array element type ${elementType}`)
|
||||
}
|
||||
const listBuilder = makeBuilder({
|
||||
type: new List(new Field('item', innerType, true))
|
||||
})
|
||||
for (const value of values) {
|
||||
listBuilder.append(value)
|
||||
}
|
||||
records[columnsKey] = listBuilder.finish().toVector()
|
||||
} else {
|
||||
// TODO if this is a struct field then recursively align the subfields
|
||||
records[columnsKey] = vectorFromArray(values)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new ArrowTable(records)
|
||||
}
|
||||
|
||||
function getElementType (arr: any[]): string {
|
||||
if (arr.length === 0) {
|
||||
return 'undefined'
|
||||
}
|
||||
|
||||
return typeof arr[0]
|
||||
}
|
||||
|
||||
// Creates a new Arrow ListBuilder that stores a Vector column
|
||||
function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
|
||||
return makeBuilder({
|
||||
type: newVectorType(dim)
|
||||
})
|
||||
const table = makeArrowTable(data, makeTableOptions)
|
||||
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema)
|
||||
}
|
||||
|
||||
// Creates the Arrow Type for a Vector column with dimension `dim`
|
||||
function newVectorType (dim: number): FixedSizeList<Float32> {
|
||||
function newVectorType <T extends Float> (dim: number, innerType: T): FixedSizeList<T> {
|
||||
// Somewhere we always default to have the elements nullable, so we need to set it to true
|
||||
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
|
||||
const children = new Field<Float32>('item', new Float32(), true)
|
||||
const children = new Field<T>('item', innerType, true)
|
||||
return new FixedSizeList(dim, children)
|
||||
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC format
|
||||
/**
|
||||
* Serialize an Array of records into a buffer using the Arrow IPC File serialization
|
||||
*
|
||||
* This function will call `convertToTable` and pass on `embeddings` and `schema`
|
||||
*
|
||||
* `schema` is required if data is empty
|
||||
*/
|
||||
export async function fromRecordsToBuffer<T> (
|
||||
data: Array<Record<string, unknown>>,
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
schema?: Schema
|
||||
): Promise<Buffer> {
|
||||
let table = await convertToTable(data, embeddings)
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const table = await convertToTable(data, embeddings, { schema })
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC stream format
|
||||
/**
|
||||
* Serialize an Array of records into a buffer using the Arrow IPC Stream serialization
|
||||
*
|
||||
* This function will call `convertToTable` and pass on `embeddings` and `schema`
|
||||
*
|
||||
* `schema` is required if data is empty
|
||||
*/
|
||||
export async function fromRecordsToStreamBuffer<T> (
|
||||
data: Array<Record<string, unknown>>,
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
schema?: Schema
|
||||
): Promise<Buffer> {
|
||||
let table = await convertToTable(data, embeddings)
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const table = await convertToTable(data, embeddings, { schema })
|
||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC format
|
||||
/**
|
||||
* Serialize an Arrow Table into a buffer using the Arrow IPC File serialization
|
||||
*
|
||||
* This function will apply `embeddings` to the table in a manner similar to
|
||||
* `convertToTable`.
|
||||
*
|
||||
* `schema` is required if the table is empty
|
||||
*/
|
||||
export async function fromTableToBuffer<T> (
|
||||
table: ArrowTable,
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
schema?: Schema
|
||||
): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
if (source === null) {
|
||||
throw new Error(
|
||||
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
|
||||
)
|
||||
}
|
||||
|
||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
|
||||
const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC stream format
|
||||
/**
|
||||
* Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
|
||||
*
|
||||
* This function will apply `embeddings` to the table in a manner similar to
|
||||
* `convertToTable`.
|
||||
*
|
||||
* `schema` is required if the table is empty
|
||||
*/
|
||||
export async function fromTableToStreamBuffer<T> (
|
||||
table: ArrowTable,
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
schema?: Schema
|
||||
): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
if (source === null) {
|
||||
throw new Error(
|
||||
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
|
||||
)
|
||||
}
|
||||
|
||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
|
||||
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { type Float } from 'apache-arrow'
|
||||
|
||||
/**
|
||||
* An embedding function that automatically creates vector representation for a given column.
|
||||
*/
|
||||
@@ -21,6 +23,39 @@ export interface EmbeddingFunction<T> {
|
||||
*/
|
||||
sourceColumn: string
|
||||
|
||||
/**
|
||||
* The data type of the embedding
|
||||
*
|
||||
* The embedding function should return `number`. This will be converted into
|
||||
* an Arrow float array. By default this will be Float32 but this property can
|
||||
* be used to control the conversion.
|
||||
*/
|
||||
embeddingDataType?: Float
|
||||
|
||||
/**
|
||||
* The dimension of the embedding
|
||||
*
|
||||
* This is optional, normally this can be determined by looking at the results of
|
||||
* `embed`. If this is not specified, and there is an attempt to apply the embedding
|
||||
* to an empty table, then that process will fail.
|
||||
*/
|
||||
embeddingDimension?: number
|
||||
|
||||
/**
|
||||
* The name of the column that will contain the embedding
|
||||
*
|
||||
* By default this is "vector"
|
||||
*/
|
||||
destColumn?: string
|
||||
|
||||
/**
|
||||
* Should the source column be excluded from the resulting table
|
||||
*
|
||||
* By default the source column is included. Set this to true and
|
||||
* only the embedding will be stored.
|
||||
*/
|
||||
excludeSource?: boolean
|
||||
|
||||
/**
|
||||
* Creates a vector representation for the given values.
|
||||
*/
|
||||
|
||||
@@ -37,6 +37,7 @@ const {
|
||||
tableCountRows,
|
||||
tableDelete,
|
||||
tableUpdate,
|
||||
tableMergeInsert,
|
||||
tableCleanupOldVersions,
|
||||
tableCompactFiles,
|
||||
tableListIndices,
|
||||
@@ -48,7 +49,7 @@ const {
|
||||
export { Query }
|
||||
export type { EmbeddingFunction }
|
||||
export { OpenAIEmbeddingFunction } from './embedding/openai'
|
||||
export { makeArrowTable, type MakeArrowTableOptions } from './arrow'
|
||||
export { convertToTable, makeArrowTable, type MakeArrowTableOptions } from './arrow'
|
||||
|
||||
const defaultAwsRegion = 'us-west-2'
|
||||
|
||||
@@ -371,7 +372,7 @@ export interface Table<T = number[]> {
|
||||
/**
|
||||
* Returns the number of rows in this table.
|
||||
*/
|
||||
countRows: () => Promise<number>
|
||||
countRows: (filter?: string) => Promise<number>
|
||||
|
||||
/**
|
||||
* Delete rows from this table.
|
||||
@@ -440,6 +441,38 @@ export interface Table<T = number[]> {
|
||||
*/
|
||||
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>
|
||||
|
||||
/**
|
||||
* Runs a "merge insert" operation on the table
|
||||
*
|
||||
* This operation can add rows, update rows, and remove rows all in a single
|
||||
* transaction. It is a very generic tool that can be used to create
|
||||
* behaviors like "insert if not exists", "update or insert (i.e. upsert)",
|
||||
* or even replace a portion of existing data with new data (e.g. replace
|
||||
* all data where month="january")
|
||||
*
|
||||
* The merge insert operation works by combining new data from a
|
||||
* **source table** with existing data in a **target table** by using a
|
||||
* join. There are three categories of records.
|
||||
*
|
||||
* "Matched" records are records that exist in both the source table and
|
||||
* the target table. "Not matched" records exist only in the source table
|
||||
* (e.g. these are new data) "Not matched by source" records exist only
|
||||
* in the target table (this is old data)
|
||||
*
|
||||
* The MergeInsertArgs can be used to customize what should happen for
|
||||
* each category of data.
|
||||
*
|
||||
* Please note that the data may appear to be reordered as part of this
|
||||
* operation. This is because updated rows will be deleted from the
|
||||
* dataset and then reinserted at the end with the new values.
|
||||
*
|
||||
* @param on a column to join on. This is how records from the source
|
||||
* table and target table are matched.
|
||||
* @param data the new data to insert
|
||||
* @param args parameters controlling how the operation should behave
|
||||
*/
|
||||
mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>
|
||||
|
||||
/**
|
||||
* List the indicies on this table.
|
||||
*/
|
||||
@@ -483,6 +516,47 @@ export interface UpdateSqlArgs {
|
||||
valuesSql: Record<string, string>
|
||||
}
|
||||
|
||||
export interface MergeInsertArgs {
|
||||
/**
|
||||
* If true then rows that exist in both the source table (new data) and
|
||||
* the target table (old data) will be updated, replacing the old row
|
||||
* with the corresponding matching row.
|
||||
*
|
||||
* If there are multiple matches then the behavior is undefined.
|
||||
* Currently this causes multiple copies of the row to be created
|
||||
* but that behavior is subject to change.
|
||||
*
|
||||
* Optionally, a filter can be specified. This should be an SQL
|
||||
* filter where fields with the prefix "target." refer to fields
|
||||
* in the target table (old data) and fields with the prefix
|
||||
* "source." refer to fields in the source table (new data). For
|
||||
* example, the filter "target.lastUpdated < source.lastUpdated" will
|
||||
* only update matched rows when the incoming `lastUpdated` value is
|
||||
* newer.
|
||||
*
|
||||
* Rows that do not match the filter will not be updated. Rows that
|
||||
* do not match the filter do become "not matched" rows.
|
||||
*/
|
||||
whenMatchedUpdateAll?: string | boolean
|
||||
/**
|
||||
* If true then rows that exist only in the source table (new data)
|
||||
* will be inserted into the target table.
|
||||
*/
|
||||
whenNotMatchedInsertAll?: boolean
|
||||
/**
|
||||
* If true then rows that exist only in the target table (old data)
|
||||
* will be deleted.
|
||||
*
|
||||
* If this is a string then it will be treated as an SQL filter and
|
||||
* only rows that both do not match any row in the source table and
|
||||
* match the given filter will be deleted.
|
||||
*
|
||||
* This can be used to replace a selection of existing data with
|
||||
* new data.
|
||||
*/
|
||||
whenNotMatchedBySourceDelete?: string | boolean
|
||||
}
|
||||
|
||||
export interface VectorIndex {
|
||||
columns: string[]
|
||||
name: string
|
||||
@@ -777,8 +851,8 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
/**
|
||||
* Returns the number of rows in this table.
|
||||
*/
|
||||
async countRows (): Promise<number> {
|
||||
return tableCountRows.call(this._tbl)
|
||||
async countRows (filter?: string): Promise<number> {
|
||||
return tableCountRows.call(this._tbl, filter)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -821,6 +895,46 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
})
|
||||
}
|
||||
|
||||
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
|
||||
let whenMatchedUpdateAll = false
|
||||
let whenMatchedUpdateAllFilt = null
|
||||
if (args.whenMatchedUpdateAll !== undefined && args.whenMatchedUpdateAll !== null) {
|
||||
whenMatchedUpdateAll = true
|
||||
if (args.whenMatchedUpdateAll !== true) {
|
||||
whenMatchedUpdateAllFilt = args.whenMatchedUpdateAll
|
||||
}
|
||||
}
|
||||
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
|
||||
let whenNotMatchedBySourceDelete = false
|
||||
let whenNotMatchedBySourceDeleteFilt = null
|
||||
if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) {
|
||||
whenNotMatchedBySourceDelete = true
|
||||
if (args.whenNotMatchedBySourceDelete !== true) {
|
||||
whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete
|
||||
}
|
||||
}
|
||||
|
||||
const schema = await this.schema
|
||||
let tbl: ArrowTable
|
||||
if (data instanceof ArrowTable) {
|
||||
tbl = data
|
||||
} else {
|
||||
tbl = makeArrowTable(data, { schema })
|
||||
}
|
||||
const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)
|
||||
|
||||
this._tbl = await tableMergeInsert.call(
|
||||
this._tbl,
|
||||
on,
|
||||
whenMatchedUpdateAll,
|
||||
whenMatchedUpdateAllFilt,
|
||||
whenNotMatchedInsertAll,
|
||||
whenNotMatchedBySourceDelete,
|
||||
whenNotMatchedBySourceDeleteFilt,
|
||||
buffer
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up old versions of the table, freeing disk space.
|
||||
*
|
||||
|
||||
@@ -24,7 +24,8 @@ import {
|
||||
type IndexStats,
|
||||
type UpdateArgs,
|
||||
type UpdateSqlArgs,
|
||||
makeArrowTable
|
||||
makeArrowTable,
|
||||
type MergeInsertArgs
|
||||
} from '../index'
|
||||
import { Query } from '../query'
|
||||
|
||||
@@ -274,6 +275,55 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
throw new Error('Not implemented')
|
||||
}
|
||||
|
||||
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
|
||||
let tbl: ArrowTable
|
||||
if (data instanceof ArrowTable) {
|
||||
tbl = data
|
||||
} else {
|
||||
tbl = makeArrowTable(data, await this.schema)
|
||||
}
|
||||
|
||||
const queryParams: any = {
|
||||
on
|
||||
}
|
||||
if (args.whenMatchedUpdateAll !== false && args.whenMatchedUpdateAll !== null && args.whenMatchedUpdateAll !== undefined) {
|
||||
queryParams.when_matched_update_all = 'true'
|
||||
if (typeof args.whenMatchedUpdateAll === 'string') {
|
||||
queryParams.when_matched_update_all_filt = args.whenMatchedUpdateAll
|
||||
}
|
||||
} else {
|
||||
queryParams.when_matched_update_all = 'false'
|
||||
}
|
||||
if (args.whenNotMatchedInsertAll ?? false) {
|
||||
queryParams.when_not_matched_insert_all = 'true'
|
||||
} else {
|
||||
queryParams.when_not_matched_insert_all = 'false'
|
||||
}
|
||||
if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) {
|
||||
queryParams.when_not_matched_by_source_delete = 'true'
|
||||
if (typeof args.whenNotMatchedBySourceDelete === 'string') {
|
||||
queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete
|
||||
}
|
||||
} else {
|
||||
queryParams.when_not_matched_by_source_delete = 'false'
|
||||
}
|
||||
|
||||
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
|
||||
const res = await this._client.post(
|
||||
`/v1/table/${this._name}/merge_insert/`,
|
||||
buffer,
|
||||
queryParams,
|
||||
'application/vnd.apache.arrow.stream'
|
||||
)
|
||||
if (res.status !== 200) {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
|
||||
let tbl: ArrowTable
|
||||
if (data instanceof ArrowTable) {
|
||||
|
||||
@@ -13,9 +13,10 @@
|
||||
// limitations under the License.
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import { assert } from 'chai'
|
||||
import { assert, expect, use as chaiUse } from 'chai'
|
||||
import * as chaiAsPromised from 'chai-as-promised'
|
||||
|
||||
import { fromTableToBuffer, makeArrowTable } from '../arrow'
|
||||
import { convertToTable, fromTableToBuffer, makeArrowTable, makeEmptyTable } from '../arrow'
|
||||
import {
|
||||
Field,
|
||||
FixedSizeList,
|
||||
@@ -24,21 +25,79 @@ import {
|
||||
Int32,
|
||||
tableFromIPC,
|
||||
Schema,
|
||||
Float64
|
||||
Float64,
|
||||
type Table,
|
||||
Binary,
|
||||
Bool,
|
||||
Utf8,
|
||||
Struct,
|
||||
List,
|
||||
DataType,
|
||||
Dictionary,
|
||||
Int64
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from '../embedding/embedding_function'
|
||||
|
||||
describe('Apache Arrow tables', function () {
|
||||
it('customized schema', async function () {
|
||||
chaiUse(chaiAsPromised)
|
||||
|
||||
function sampleRecords (): Array<Record<string, any>> {
|
||||
return [
|
||||
{
|
||||
binary: Buffer.alloc(5),
|
||||
boolean: false,
|
||||
number: 7,
|
||||
string: 'hello',
|
||||
struct: { x: 0, y: 0 },
|
||||
list: ['anime', 'action', 'comedy']
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
// Helper method to verify various ways to create a table
|
||||
async function checkTableCreation (tableCreationMethod: (records: any, recordsReversed: any, schema: Schema) => Promise<Table>): Promise<void> {
|
||||
const records = sampleRecords()
|
||||
const recordsReversed = [{
|
||||
list: ['anime', 'action', 'comedy'],
|
||||
struct: { x: 0, y: 0 },
|
||||
string: 'hello',
|
||||
number: 7,
|
||||
boolean: false,
|
||||
binary: Buffer.alloc(5)
|
||||
}]
|
||||
const schema = new Schema([
|
||||
new Field('binary', new Binary(), false),
|
||||
new Field('boolean', new Bool(), false),
|
||||
new Field('number', new Float64(), false),
|
||||
new Field('string', new Utf8(), false),
|
||||
new Field('struct', new Struct([
|
||||
new Field('x', new Float64(), false),
|
||||
new Field('y', new Float64(), false)
|
||||
])),
|
||||
new Field('list', new List(new Field('item', new Utf8(), false)), false)
|
||||
])
|
||||
|
||||
const table = await tableCreationMethod(records, recordsReversed, schema)
|
||||
schema.fields.forEach((field, idx) => {
|
||||
const actualField = table.schema.fields[idx]
|
||||
assert.isFalse(actualField.nullable)
|
||||
assert.equal(table.getChild(field.name)?.type.toString(), field.type.toString())
|
||||
assert.equal(table.getChildAt(idx)?.type.toString(), field.type.toString())
|
||||
})
|
||||
}
|
||||
|
||||
describe('The function makeArrowTable', function () {
|
||||
it('will use data types from a provided schema instead of inference', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('a', new Int32()),
|
||||
new Field('b', new Float32()),
|
||||
new Field('c', new FixedSizeList(3, new Field('item', new Float16())))
|
||||
new Field('c', new FixedSizeList(3, new Field('item', new Float16()))),
|
||||
new Field('d', new Int64())
|
||||
])
|
||||
const table = makeArrowTable(
|
||||
[
|
||||
{ a: 1, b: 2, c: [1, 2, 3] },
|
||||
{ a: 4, b: 5, c: [4, 5, 6] },
|
||||
{ a: 7, b: 8, c: [7, 8, 9] }
|
||||
{ a: 1, b: 2, c: [1, 2, 3], d: 9 },
|
||||
{ a: 4, b: 5, c: [4, 5, 6], d: 10 },
|
||||
{ a: 7, b: 8, c: [7, 8, 9], d: null }
|
||||
],
|
||||
{ schema }
|
||||
)
|
||||
@@ -52,13 +111,13 @@ describe('Apache Arrow tables', function () {
|
||||
assert.deepEqual(actualSchema, schema)
|
||||
})
|
||||
|
||||
it('default vector column', async function () {
|
||||
it('will assume the column `vector` is FixedSizeList<Float32> by default', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('a', new Float64()),
|
||||
new Field('b', new Float64()),
|
||||
new Field(
|
||||
'vector',
|
||||
new FixedSizeList(3, new Field('item', new Float32()))
|
||||
new FixedSizeList(3, new Field('item', new Float32(), true))
|
||||
)
|
||||
])
|
||||
const table = makeArrowTable([
|
||||
@@ -76,12 +135,12 @@ describe('Apache Arrow tables', function () {
|
||||
assert.deepEqual(actualSchema, schema)
|
||||
})
|
||||
|
||||
it('2 vector columns', async function () {
|
||||
it('can support multiple vector columns', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('a', new Float64()),
|
||||
new Field('b', new Float64()),
|
||||
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
|
||||
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
|
||||
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16(), true))),
|
||||
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16(), true)))
|
||||
])
|
||||
const table = makeArrowTable(
|
||||
[
|
||||
@@ -105,4 +164,157 @@ describe('Apache Arrow tables', function () {
|
||||
const actualSchema = actual.schema
|
||||
assert.deepEqual(actualSchema, schema)
|
||||
})
|
||||
|
||||
it('will allow different vector column types', async function () {
|
||||
const table = makeArrowTable(
|
||||
[
|
||||
{ fp16: [1], fp32: [1], fp64: [1] }
|
||||
],
|
||||
{
|
||||
vectorColumns: {
|
||||
fp16: { type: new Float16() },
|
||||
fp32: { type: new Float32() },
|
||||
fp64: { type: new Float64() }
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert.equal(table.getChild('fp16')?.type.children[0].type.toString(), new Float16().toString())
|
||||
assert.equal(table.getChild('fp32')?.type.children[0].type.toString(), new Float32().toString())
|
||||
assert.equal(table.getChild('fp64')?.type.children[0].type.toString(), new Float64().toString())
|
||||
})
|
||||
|
||||
it('will use dictionary encoded strings if asked', async function () {
|
||||
const table = makeArrowTable([{ str: 'hello' }])
|
||||
assert.isTrue(DataType.isUtf8(table.getChild('str')?.type))
|
||||
|
||||
const tableWithDict = makeArrowTable([{ str: 'hello' }], { dictionaryEncodeStrings: true })
|
||||
assert.isTrue(DataType.isDictionary(tableWithDict.getChild('str')?.type))
|
||||
|
||||
const schema = new Schema([
|
||||
new Field('str', new Dictionary(new Utf8(), new Int32()))
|
||||
])
|
||||
|
||||
const tableWithDict2 = makeArrowTable([{ str: 'hello' }], { schema })
|
||||
assert.isTrue(DataType.isDictionary(tableWithDict2.getChild('str')?.type))
|
||||
})
|
||||
|
||||
it('will infer data types correctly', async function () {
|
||||
await checkTableCreation(async (records) => makeArrowTable(records))
|
||||
})
|
||||
|
||||
it('will allow a schema to be provided', async function () {
|
||||
await checkTableCreation(async (records, _, schema) => makeArrowTable(records, { schema }))
|
||||
})
|
||||
|
||||
it('will use the field order of any provided schema', async function () {
|
||||
await checkTableCreation(async (_, recordsReversed, schema) => makeArrowTable(recordsReversed, { schema }))
|
||||
})
|
||||
|
||||
it('will make an empty table', async function () {
|
||||
await checkTableCreation(async (_, __, schema) => makeArrowTable([], { schema }))
|
||||
})
|
||||
})
|
||||
|
||||
class DummyEmbedding implements EmbeddingFunction<string> {
|
||||
public readonly sourceColumn = 'string'
|
||||
public readonly embeddingDimension = 2
|
||||
public readonly embeddingDataType = new Float16()
|
||||
|
||||
async embed (data: string[]): Promise<number[][]> {
|
||||
return data.map(
|
||||
() => [0.0, 0.0]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
|
||||
public readonly sourceColumn = 'string'
|
||||
|
||||
async embed (data: string[]): Promise<number[][]> {
|
||||
return data.map(
|
||||
() => [0.0, 0.0]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
describe('convertToTable', function () {
|
||||
it('will infer data types correctly', async function () {
|
||||
await checkTableCreation(async (records) => await convertToTable(records))
|
||||
})
|
||||
|
||||
it('will allow a schema to be provided', async function () {
|
||||
await checkTableCreation(async (records, _, schema) => await convertToTable(records, undefined, { schema }))
|
||||
})
|
||||
|
||||
it('will use the field order of any provided schema', async function () {
|
||||
await checkTableCreation(async (_, recordsReversed, schema) => await convertToTable(recordsReversed, undefined, { schema }))
|
||||
})
|
||||
|
||||
it('will make an empty table', async function () {
|
||||
await checkTableCreation(async (_, __, schema) => await convertToTable([], undefined, { schema }))
|
||||
})
|
||||
|
||||
it('will apply embeddings', async function () {
|
||||
const records = sampleRecords()
|
||||
const table = await convertToTable(records, new DummyEmbedding())
|
||||
assert.isTrue(DataType.isFixedSizeList(table.getChild('vector')?.type))
|
||||
assert.equal(table.getChild('vector')?.type.children[0].type.toString(), new Float16().toString())
|
||||
})
|
||||
|
||||
it('will fail if missing the embedding source column', async function () {
|
||||
return await expect(convertToTable([{ id: 1 }], new DummyEmbedding())).to.be.rejectedWith("'string' was not present")
|
||||
})
|
||||
|
||||
it('use embeddingDimension if embedding missing from table', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('string', new Utf8(), false)
|
||||
])
|
||||
// Simulate getting an empty Arrow table (minus embedding) from some other source
|
||||
// In other words, we aren't starting with records
|
||||
const table = makeEmptyTable(schema)
|
||||
|
||||
// If the embedding specifies the dimension we are fine
|
||||
await fromTableToBuffer(table, new DummyEmbedding())
|
||||
|
||||
// We can also supply a schema and should be ok
|
||||
const schemaWithEmbedding = new Schema([
|
||||
new Field('string', new Utf8(), false),
|
||||
new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false)
|
||||
])
|
||||
await fromTableToBuffer(table, new DummyEmbeddingWithNoDimension(), schemaWithEmbedding)
|
||||
|
||||
// Otherwise we will get an error
|
||||
return await expect(fromTableToBuffer(table, new DummyEmbeddingWithNoDimension())).to.be.rejectedWith('does not specify `embeddingDimension`')
|
||||
})
|
||||
|
||||
it('will apply embeddings to an empty table', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('string', new Utf8(), false),
|
||||
new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false)
|
||||
])
|
||||
const table = await convertToTable([], new DummyEmbedding(), { schema })
|
||||
assert.isTrue(DataType.isFixedSizeList(table.getChild('vector')?.type))
|
||||
assert.equal(table.getChild('vector')?.type.children[0].type.toString(), new Float16().toString())
|
||||
})
|
||||
|
||||
it('will complain if embeddings present but schema missing embedding column', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('string', new Utf8(), false)
|
||||
])
|
||||
return await expect(convertToTable([], new DummyEmbedding(), { schema })).to.be.rejectedWith('column vector was missing')
|
||||
})
|
||||
|
||||
it('will provide a nice error if run twice', async function () {
|
||||
const records = sampleRecords()
|
||||
const table = await convertToTable(records, new DummyEmbedding())
|
||||
// fromTableToBuffer will try and apply the embeddings again
|
||||
return await expect(fromTableToBuffer(table, new DummyEmbedding())).to.be.rejectedWith('already existed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('makeEmptyTable', function () {
|
||||
it('will make an empty table', async function () {
|
||||
await checkTableCreation(async (_, __, schema) => makeEmptyTable(schema))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -294,6 +294,7 @@ describe('LanceDB client', function () {
|
||||
})
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.equal(await table.countRows(), 10)
|
||||
assert.equal(await table.countRows('vector IS NULL'), 0)
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
@@ -369,6 +370,7 @@ describe('LanceDB client', function () {
|
||||
const table = await con.createTable('f16', data)
|
||||
assert.equal(table.name, 'f16')
|
||||
assert.equal(await table.countRows(), total)
|
||||
assert.equal(await table.countRows('id < 5'), 5)
|
||||
assert.deepEqual(await con.tableNames(), ['f16'])
|
||||
assert.deepEqual(await table.schema, schema)
|
||||
|
||||
@@ -531,6 +533,54 @@ describe('LanceDB client', function () {
|
||||
assert.equal(await table.countRows(), 2)
|
||||
})
|
||||
|
||||
it('can merge insert records into the table', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
|
||||
const table = await con.createTable('my_table', data)
|
||||
|
||||
// insert if not exists
|
||||
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
|
||||
await table.mergeInsert('id', newData, {
|
||||
whenNotMatchedInsertAll: true
|
||||
})
|
||||
assert.equal(await table.countRows(), 3)
|
||||
assert.equal(await table.countRows('age = 2'), 1)
|
||||
|
||||
// conditional update
|
||||
newData = [{ id: 2, age: 3 }, { id: 3, age: 3 }]
|
||||
await table.mergeInsert('id', newData, {
|
||||
whenMatchedUpdateAll: 'target.age = 1'
|
||||
})
|
||||
assert.equal(await table.countRows(), 3)
|
||||
assert.equal(await table.countRows('age = 1'), 1)
|
||||
assert.equal(await table.countRows('age = 3'), 1)
|
||||
|
||||
newData = [{ id: 3, age: 4 }, { id: 4, age: 4 }]
|
||||
await table.mergeInsert('id', newData, {
|
||||
whenNotMatchedInsertAll: true,
|
||||
whenMatchedUpdateAll: true
|
||||
})
|
||||
assert.equal(await table.countRows(), 4)
|
||||
assert.equal((await table.filter('age = 4').execute()).length, 2)
|
||||
|
||||
newData = [{ id: 5, age: 5 }]
|
||||
await table.mergeInsert('id', newData, {
|
||||
whenNotMatchedInsertAll: true,
|
||||
whenMatchedUpdateAll: true,
|
||||
whenNotMatchedBySourceDelete: 'age < 4'
|
||||
})
|
||||
assert.equal(await table.countRows(), 3)
|
||||
|
||||
await table.mergeInsert('id', newData, {
|
||||
whenNotMatchedInsertAll: true,
|
||||
whenMatchedUpdateAll: true,
|
||||
whenNotMatchedBySourceDelete: true
|
||||
})
|
||||
assert.equal(await table.countRows(), 1)
|
||||
})
|
||||
|
||||
it('can update records in the table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
@@ -9,6 +9,6 @@
|
||||
"declaration": true,
|
||||
"outDir": "./dist",
|
||||
"strict": true,
|
||||
// "esModuleInterop": true,
|
||||
"sourceMap": true,
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
[package]
|
||||
name = "vectordb-nodejs"
|
||||
edition = "2021"
|
||||
edition.workspace = true
|
||||
version = "0.0.0"
|
||||
license.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
@@ -14,15 +17,14 @@ futures.workspace = true
|
||||
lance-linalg.workspace = true
|
||||
lance.workspace = true
|
||||
vectordb = { path = "../rust/vectordb" }
|
||||
napi = { version = "2.14", default-features = false, features = [
|
||||
napi = { version = "2.15", default-features = false, features = [
|
||||
"napi7",
|
||||
"async"
|
||||
] }
|
||||
napi-derive = "2.14"
|
||||
napi-derive = "2"
|
||||
|
||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||
lzma-sys = { version = "*", features = ["static"] }
|
||||
|
||||
[build-dependencies]
|
||||
napi-build = "2.1"
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
strip = "symbols"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import { makeArrowTable, toBuffer } from "../vectordb/arrow";
|
||||
import {
|
||||
Int64,
|
||||
Field,
|
||||
FixedSizeList,
|
||||
Float16,
|
||||
@@ -104,3 +105,16 @@ test("2 vector columns", function () {
|
||||
const actualSchema = actual.schema;
|
||||
expect(actualSchema.toString()).toEqual(schema.toString());
|
||||
});
|
||||
|
||||
test("handles int64", function() {
|
||||
// https://github.com/lancedb/lancedb/issues/960
|
||||
const schema = new Schema([
|
||||
new Field("x", new Int64(), true)
|
||||
]);
|
||||
const table = makeArrowTable([
|
||||
{ x: 1 },
|
||||
{ x: 2 },
|
||||
{ x: 3 }
|
||||
], { schema });
|
||||
expect(table.schema).toEqual(schema);
|
||||
})
|
||||
@@ -2,4 +2,6 @@
|
||||
module.exports = {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
moduleDirectories: ["node_modules", "./dist"],
|
||||
moduleFileExtensions: ["js", "ts"],
|
||||
};
|
||||
@@ -57,8 +57,8 @@ impl Table {
|
||||
}
|
||||
|
||||
#[napi]
|
||||
pub async fn count_rows(&self) -> napi::Result<usize> {
|
||||
self.table.count_rows().await.map_err(|e| {
|
||||
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<usize> {
|
||||
self.table.count_rows(filter).await.map_err(|e| {
|
||||
napi::Error::from_reason(format!(
|
||||
"Failed to count rows in table {}: {}",
|
||||
self.table, e
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
Int64,
|
||||
Field,
|
||||
FixedSizeList,
|
||||
Float,
|
||||
@@ -23,6 +24,7 @@ import {
|
||||
Vector,
|
||||
vectorFromArray,
|
||||
tableToIPC,
|
||||
DataType,
|
||||
} from "apache-arrow";
|
||||
|
||||
/** Data type accepted by NodeJS SDK */
|
||||
@@ -137,15 +139,18 @@ export function makeArrowTable(
|
||||
const columnNames = Object.keys(data[0]);
|
||||
for (const colName of columnNames) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-return
|
||||
const values = data.map((datum) => datum[colName]);
|
||||
let values = data.map((datum) => datum[colName]);
|
||||
let vector: Vector;
|
||||
|
||||
if (opt.schema !== undefined) {
|
||||
// Explicit schema is provided, highest priority
|
||||
vector = vectorFromArray(
|
||||
values,
|
||||
opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
|
||||
);
|
||||
const fieldType: DataType | undefined = opt.schema.fields.filter((f) => f.name === colName)[0]?.type as DataType;
|
||||
if (fieldType instanceof Int64) {
|
||||
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
|
||||
values = values.map((v) => BigInt(v));
|
||||
}
|
||||
vector = vectorFromArray(values, fieldType);
|
||||
} else {
|
||||
const vectorColumnOptions = opt.vectorColumns[colName];
|
||||
if (vectorColumnOptions !== undefined) {
|
||||
|
||||
2
nodejs/vectordb/native.d.ts
vendored
2
nodejs/vectordb/native.d.ts
vendored
@@ -73,7 +73,7 @@ export class Table {
|
||||
/** Return Schema as empty Arrow IPC file. */
|
||||
schema(): Buffer
|
||||
add(buf: Buffer): Promise<void>
|
||||
countRows(): Promise<bigint>
|
||||
countRows(filter?: string | undefined | null): Promise<bigint>
|
||||
delete(predicate: string): Promise<void>
|
||||
createIndex(): IndexBuilder
|
||||
query(): Query
|
||||
|
||||
@@ -50,8 +50,8 @@ export class Table {
|
||||
}
|
||||
|
||||
/** Count the total number of rows in the dataset. */
|
||||
async countRows(): Promise<bigint> {
|
||||
return await this.inner.countRows();
|
||||
async countRows(filter?: string): Promise<bigint> {
|
||||
return await this.inner.countRows(filter);
|
||||
}
|
||||
|
||||
/** Delete the rows that satisfy the predicate. */
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.5.1
|
||||
current_version = 0.5.6
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -42,6 +42,12 @@ To run the unit tests:
|
||||
pytest
|
||||
```
|
||||
|
||||
To run the doc tests:
|
||||
|
||||
```bash
|
||||
pytest --doctest-modules lancedb
|
||||
```
|
||||
|
||||
To run linter and automatically fix all errors:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
|
||||
import importlib.metadata
|
||||
import os
|
||||
from typing import Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Union
|
||||
|
||||
__version__ = importlib.metadata.version("lancedb")
|
||||
|
||||
@@ -30,6 +32,8 @@ def connect(
|
||||
api_key: Optional[str] = None,
|
||||
region: str = "us-east-1",
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
|
||||
@@ -45,6 +49,25 @@ def connect(
|
||||
The region to use for LanceDB Cloud.
|
||||
host_override: str, optional
|
||||
The override url for LanceDB Cloud.
|
||||
read_consistency_interval: timedelta, default None
|
||||
(For LanceDB OSS only)
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
zero seconds. Then every read will check for updates from other
|
||||
processes. As a compromise, you can set this to a non-zero timedelta
|
||||
for eventual consistency. If more than that interval has passed since
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
request_thread_pool: int or ThreadPoolExecutor, optional
|
||||
The thread pool to use for making batch requests to the LanceDB Cloud API.
|
||||
If an integer, then a ThreadPoolExecutor will be created with that
|
||||
number of threads. If None, then a ThreadPoolExecutor will be created
|
||||
with the default number of threads. If a ThreadPoolExecutor, then that
|
||||
executor will be used for making requests. This is for LanceDB Cloud
|
||||
only and is only used when making batch requests (i.e., passing in
|
||||
multiple queries to the search method at once).
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -72,5 +95,9 @@ def connect(
|
||||
api_key = os.environ.get("LANCEDB_API_KEY")
|
||||
if api_key is None:
|
||||
raise ValueError(f"api_key is required to connected LanceDB cloud: {uri}")
|
||||
return RemoteDBConnection(uri, api_key, region, host_override)
|
||||
return LanceDBConnection(uri)
|
||||
if isinstance(request_thread_pool, int):
|
||||
request_thread_pool = ThreadPoolExecutor(request_thread_pool)
|
||||
return RemoteDBConnection(
|
||||
uri, api_key, region, host_override, request_thread_pool=request_thread_pool
|
||||
)
|
||||
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||
|
||||
@@ -16,9 +16,9 @@ from typing import Iterable, List, Union
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
from .util import safe_import
|
||||
from .util import safe_import_pandas
|
||||
|
||||
pd = safe_import("pandas")
|
||||
pd = safe_import_pandas()
|
||||
|
||||
DATA = Union[List[dict], dict, "pd.DataFrame", pa.Table, Iterable[pa.RecordBatch]]
|
||||
VEC = Union[list, np.ndarray, pa.Array, pa.ChunkedArray]
|
||||
|
||||
@@ -16,9 +16,9 @@ import deprecation
|
||||
|
||||
from . import __version__
|
||||
from .exceptions import MissingColumnError, MissingValueError
|
||||
from .util import safe_import
|
||||
from .util import safe_import_pandas
|
||||
|
||||
pd = safe_import("pandas")
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
||||
|
||||
@@ -26,6 +26,8 @@ from .table import LanceTable, Table
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
from .common import DATA, URI
|
||||
from .embeddings import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
@@ -118,7 +120,7 @@ class DBConnection(EnforceOverrides):
|
||||
>>> data = [{"vector": [1.1, 1.2], "lat": 45.5, "long": -122.7},
|
||||
... {"vector": [0.2, 1.8], "lat": 40.1, "long": -74.1}]
|
||||
>>> db.create_table("my_table", data)
|
||||
LanceTable(my_table)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
>>> db["my_table"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -139,7 +141,7 @@ class DBConnection(EnforceOverrides):
|
||||
... "long": [-122.7, -74.1]
|
||||
... })
|
||||
>>> db.create_table("table2", data)
|
||||
LanceTable(table2)
|
||||
LanceTable(connection=..., name="table2")
|
||||
>>> db["table2"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -161,7 +163,7 @@ class DBConnection(EnforceOverrides):
|
||||
... pa.field("long", pa.float32())
|
||||
... ])
|
||||
>>> db.create_table("table3", data, schema = custom_schema)
|
||||
LanceTable(table3)
|
||||
LanceTable(connection=..., name="table3")
|
||||
>>> db["table3"].head()
|
||||
pyarrow.Table
|
||||
vector: fixed_size_list<item: float>[2]
|
||||
@@ -195,7 +197,7 @@ class DBConnection(EnforceOverrides):
|
||||
... pa.field("price", pa.float32()),
|
||||
... ])
|
||||
>>> db.create_table("table4", make_batches(), schema=schema)
|
||||
LanceTable(table4)
|
||||
LanceTable(connection=..., name="table4")
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -243,6 +245,16 @@ class LanceDBConnection(DBConnection):
|
||||
----------
|
||||
uri: str or Path
|
||||
The root uri of the database.
|
||||
read_consistency_interval: timedelta, default None
|
||||
The interval at which to check for updates to the table from other
|
||||
processes. If None, then consistency is not checked. For performance
|
||||
reasons, this is the default. For strong consistency, set this to
|
||||
zero seconds. Then every read will check for updates from other
|
||||
processes. As a compromise, you can set this to a non-zero timedelta
|
||||
for eventual consistency. If more than that interval has passed since
|
||||
the last check, then the table will be checked for updates. Note: this
|
||||
consistency only applies to read operations. Write operations are
|
||||
always consistent.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -250,22 +262,24 @@ class LanceDBConnection(DBConnection):
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> db.create_table("my_table", data=[{"vector": [1.1, 1.2], "b": 2},
|
||||
... {"vector": [0.5, 1.3], "b": 4}])
|
||||
LanceTable(my_table)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
>>> db.create_table("another_table", data=[{"vector": [0.4, 0.4], "b": 6}])
|
||||
LanceTable(another_table)
|
||||
LanceTable(connection=..., name="another_table")
|
||||
>>> sorted(db.table_names())
|
||||
['another_table', 'my_table']
|
||||
>>> len(db)
|
||||
2
|
||||
>>> db["my_table"]
|
||||
LanceTable(my_table)
|
||||
LanceTable(connection=..., name="my_table")
|
||||
>>> "my_table" in db
|
||||
True
|
||||
>>> db.drop_table("my_table")
|
||||
>>> db.drop_table("another_table")
|
||||
"""
|
||||
|
||||
def __init__(self, uri: URI):
|
||||
def __init__(
|
||||
self, uri: URI, *, read_consistency_interval: Optional[timedelta] = None
|
||||
):
|
||||
if not isinstance(uri, Path):
|
||||
scheme = get_uri_scheme(uri)
|
||||
is_local = isinstance(uri, Path) or scheme == "file"
|
||||
@@ -277,6 +291,14 @@ class LanceDBConnection(DBConnection):
|
||||
self._uri = str(uri)
|
||||
|
||||
self._entered = False
|
||||
self.read_consistency_interval = read_consistency_interval
|
||||
|
||||
def __repr__(self) -> str:
|
||||
val = f"{self.__class__.__name__}({self._uri}"
|
||||
if self.read_consistency_interval is not None:
|
||||
val += f", read_consistency_interval={repr(self.read_consistency_interval)}"
|
||||
val += ")"
|
||||
return val
|
||||
|
||||
@property
|
||||
def uri(self) -> str:
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
@@ -91,25 +90,6 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
texts = texts.combine_chunks().to_pylist()
|
||||
return texts
|
||||
|
||||
@classmethod
|
||||
def safe_import(cls, module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : str
|
||||
The name of the module to import
|
||||
mitigation : Optional[str]
|
||||
The package(s) to install to mitigate the error.
|
||||
If not provided then the module name will be used.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(module)
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
def safe_model_dump(self):
|
||||
from ..pydantic import PYDANTIC_VERSION
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT
|
||||
@@ -183,8 +184,8 @@ class BedRockText(TextEmbeddingFunction):
|
||||
boto3.client
|
||||
The boto3 client for Amazon Bedrock service
|
||||
"""
|
||||
botocore = self.safe_import("botocore")
|
||||
boto3 = self.safe_import("boto3")
|
||||
botocore = attempt_import_or_raise("botocore")
|
||||
boto3 = attempt_import_or_raise("boto3")
|
||||
|
||||
session_kwargs = {"region_name": self.region}
|
||||
client_kwargs = {**session_kwargs}
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import ClassVar, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help
|
||||
@@ -84,7 +85,7 @@ class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||
return [emb for emb in rs.embeddings]
|
||||
|
||||
def _init_client(self):
|
||||
cohere = self.safe_import("cohere")
|
||||
cohere = attempt_import_or_raise("cohere")
|
||||
if CohereEmbeddingFunction.client is None:
|
||||
if os.environ.get("COHERE_API_KEY") is None:
|
||||
api_key_not_found_help("cohere")
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
|
||||
from lancedb.pydantic import PYDANTIC_VERSION
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT, api_key_not_found_help
|
||||
@@ -134,7 +135,7 @@ class GeminiText(TextEmbeddingFunction):
|
||||
|
||||
@cached_property
|
||||
def client(self):
|
||||
genai = self.safe_import("google.generativeai", "google.generativeai")
|
||||
genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
|
||||
|
||||
if not os.environ.get("GOOGLE_API_KEY"):
|
||||
api_key_not_found_help("google")
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import weak_lru
|
||||
@@ -122,7 +123,7 @@ class GteEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
return Model()
|
||||
else:
|
||||
sentence_transformers = self.safe_import(
|
||||
sentence_transformers = attempt_import_or_raise(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import TEXT, weak_lru
|
||||
@@ -102,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||
|
||||
source_instruction: str = "represent the document for retrieval"
|
||||
query_instruction: str = (
|
||||
"represent the document for retrieving the most similar documents"
|
||||
)
|
||||
query_instruction: (
|
||||
str
|
||||
) = "represent the document for retrieving the most similar documents"
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def ndims(self):
|
||||
@@ -131,10 +132,10 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||
|
||||
@weak_lru(maxsize=1)
|
||||
def get_model(self):
|
||||
instructor_embedding = self.safe_import(
|
||||
instructor_embedding = attempt_import_or_raise(
|
||||
"InstructorEmbedding", "InstructorEmbedding"
|
||||
)
|
||||
torch = self.safe_import("torch", "torch")
|
||||
torch = attempt_import_or_raise("torch", "torch")
|
||||
|
||||
model = instructor_embedding.INSTRUCTOR(self.name)
|
||||
if self.quantize:
|
||||
|
||||
@@ -21,6 +21,7 @@ import pyarrow as pa
|
||||
from pydantic import PrivateAttr
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import IMAGES, url_retrieve
|
||||
@@ -50,7 +51,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip")
|
||||
open_clip = attempt_import_or_raise("open_clip", "open-clip")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
@@ -78,14 +79,14 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
@@ -144,7 +145,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
@@ -152,7 +153,7 @@ class OpenClipEmbeddings(EmbeddingFunction):
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
|
||||
@@ -12,10 +12,11 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help
|
||||
@@ -30,10 +31,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
dim: Optional[int] = None
|
||||
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return self._ndims
|
||||
|
||||
@cached_property
|
||||
def _ndims(self):
|
||||
if self.name == "text-embedding-ada-002":
|
||||
return 1536
|
||||
elif self.name == "text-embedding-3-large":
|
||||
return self.dim or 3072
|
||||
elif self.name == "text-embedding-3-small":
|
||||
return self.dim or 1536
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {self.name}")
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
@@ -47,12 +59,17 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
if self.name == "text-embedding-ada-002":
|
||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||
else:
|
||||
rs = self._openai_client.embeddings.create(
|
||||
input=texts, model=self.name, dimensions=self.ndims()
|
||||
)
|
||||
return [v.embedding for v in rs.data]
|
||||
|
||||
@cached_property
|
||||
def _openai_client(self):
|
||||
openai = self.safe_import("openai")
|
||||
openai = attempt_import_or_raise("openai")
|
||||
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
api_key_not_found_help("openai")
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import weak_lru
|
||||
@@ -75,7 +76,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
sentence_transformers = self.safe_import(
|
||||
sentence_transformers = attempt_import_or_raise(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
|
||||
|
||||
@@ -26,10 +26,10 @@ import pyarrow as pa
|
||||
from lance.vector import vec_to_table
|
||||
from retry import retry
|
||||
|
||||
from ..util import safe_import
|
||||
from ..util import deprecated, safe_import_pandas
|
||||
from ..utils.general import LOGGER
|
||||
|
||||
pd = safe_import("pandas")
|
||||
pd = safe_import_pandas()
|
||||
|
||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
@@ -38,6 +38,7 @@ IMAGES = Union[
|
||||
]
|
||||
|
||||
|
||||
@deprecated
|
||||
def with_embeddings(
|
||||
func: Callable,
|
||||
data: DATA,
|
||||
|
||||
107
python/lancedb/merge.py
Normal file
107
python/lancedb/merge.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright 2023 LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .common import DATA
|
||||
|
||||
|
||||
class LanceMergeInsertBuilder(object):
|
||||
"""Builder for a LanceDB merge insert operation
|
||||
|
||||
See [`merge_insert`][lancedb.table.Table.merge_insert] for
|
||||
more context
|
||||
"""
|
||||
|
||||
def __init__(self, table: "Table", on: List[str]): # noqa: F821
|
||||
# Do not put a docstring here. This method should be hidden
|
||||
# from API docs. Users should use merge_insert to create
|
||||
# this object.
|
||||
self._table = table
|
||||
self._on = on
|
||||
self._when_matched_update_all = False
|
||||
self._when_matched_update_all_condition = None
|
||||
self._when_not_matched_insert_all = False
|
||||
self._when_not_matched_by_source_delete = False
|
||||
self._when_not_matched_by_source_condition = None
|
||||
|
||||
def when_matched_update_all(
|
||||
self, *, where: Optional[str] = None
|
||||
) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Rows that exist in both the source table (new data) and
|
||||
the target table (old data) will be updated, replacing
|
||||
the old row with the corresponding matching row.
|
||||
|
||||
If there are multiple matches then the behavior is undefined.
|
||||
Currently this causes multiple copies of the row to be created
|
||||
but that behavior is subject to change.
|
||||
"""
|
||||
self._when_matched_update_all = True
|
||||
self._when_matched_update_all_condition = where
|
||||
return self
|
||||
|
||||
def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Rows that exist only in the source table (new data) should
|
||||
be inserted into the target table.
|
||||
"""
|
||||
self._when_not_matched_insert_all = True
|
||||
return self
|
||||
|
||||
def when_not_matched_by_source_delete(
|
||||
self, condition: Optional[str] = None
|
||||
) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Rows that exist only in the target table (old data) will be
|
||||
deleted. An optional condition can be provided to limit what
|
||||
data is deleted.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
condition: Optional[str], default None
|
||||
If None then all such rows will be deleted. Otherwise the
|
||||
condition will be used as an SQL filter to limit what rows
|
||||
are deleted.
|
||||
"""
|
||||
self._when_not_matched_by_source_delete = True
|
||||
if condition is not None:
|
||||
self._when_not_matched_by_source_condition = condition
|
||||
return self
|
||||
|
||||
def execute(
|
||||
self,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Executes the merge insert operation
|
||||
|
||||
Nothing is returned but the [`Table`][lancedb.table.Table] is updated
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_data: DATA
|
||||
New records which will be matched against the existing records
|
||||
to potentially insert or update into the table. This parameter
|
||||
can be anything you use for [`add`][lancedb.table.Table.add]
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
|
||||
@@ -304,7 +304,7 @@ class LanceModel(pydantic.BaseModel):
|
||||
... name: str
|
||||
... vector: Vector(2)
|
||||
...
|
||||
>>> db = lancedb.connect("/tmp")
|
||||
>>> db = lancedb.connect("./example")
|
||||
>>> table = db.create_table("test", schema=TestModel.to_arrow_schema())
|
||||
>>> table.add([
|
||||
... TestModel(name="test", vector=[1.0, 2.0])
|
||||
|
||||
@@ -24,10 +24,10 @@ import pyarrow as pa
|
||||
import pydantic
|
||||
|
||||
from . import __version__
|
||||
from .common import VEC, VECTOR_COLUMN_NAME
|
||||
from .common import VEC
|
||||
from .rerankers.base import Reranker
|
||||
from .rerankers.linear_combination import LinearCombinationReranker
|
||||
from .util import safe_import
|
||||
from .util import safe_import_pandas
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import PIL
|
||||
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
|
||||
from .pydantic import LanceModel
|
||||
from .table import Table
|
||||
|
||||
pd = safe_import("pandas")
|
||||
pd = safe_import_pandas()
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
@@ -75,7 +75,7 @@ class Query(pydantic.BaseModel):
|
||||
tuning advice.
|
||||
"""
|
||||
|
||||
vector_column: str = VECTOR_COLUMN_NAME
|
||||
vector_column: Optional[str] = None
|
||||
|
||||
# vector to search for
|
||||
vector: Union[List[float], List[List[float]]]
|
||||
@@ -403,7 +403,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self,
|
||||
table: "Table",
|
||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||
vector_column: str = VECTOR_COLUMN_NAME,
|
||||
vector_column: str,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
@@ -626,7 +626,6 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(self, table: "Table", query: str, vector_column: str):
|
||||
super().__init__(table)
|
||||
self._validate_fts_index()
|
||||
self._query = query
|
||||
vector_query, fts_query = self._validate_query(query)
|
||||
self._fts_query = LanceFtsQueryBuilder(table, fts_query)
|
||||
vector_query = self._query_to_vector(table, vector_query, vector_column)
|
||||
@@ -679,12 +678,18 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
# rerankers might need to preserve this score to support `return_score="all"`
|
||||
fts_results = self._normalize_scores(fts_results, "score")
|
||||
|
||||
results = self._reranker.rerank_hybrid(self, vector_results, fts_results)
|
||||
results = self._reranker.rerank_hybrid(
|
||||
self._fts_query._query, vector_results, fts_results
|
||||
)
|
||||
|
||||
if not isinstance(results, pa.Table): # Enforce type
|
||||
raise TypeError(
|
||||
f"rerank_hybrid must return a pyarrow.Table, got {type(results)}"
|
||||
)
|
||||
|
||||
# apply limit after reranking
|
||||
results = results.slice(length=self._limit)
|
||||
|
||||
if not self._with_row_id:
|
||||
results = results.drop(["_rowid"])
|
||||
return results
|
||||
@@ -776,6 +781,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
self._vector_query.limit(limit)
|
||||
self._fts_query.limit(limit)
|
||||
self._limit = limit
|
||||
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceHybridQueryBuilder:
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
@@ -20,6 +22,8 @@ import attrs
|
||||
import pyarrow as pa
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3 import Retry
|
||||
|
||||
from lancedb.common import Credential
|
||||
from lancedb.remote import VectorQuery, VectorQueryResult
|
||||
@@ -57,6 +61,10 @@ class RestfulLanceDBClient:
|
||||
@functools.cached_property
|
||||
def session(self) -> requests.Session:
|
||||
sess = requests.Session()
|
||||
|
||||
retry_adapter_instance = retry_adapter(retry_adapter_options())
|
||||
sess.mount(urljoin(self.url, "/v1/table/"), retry_adapter_instance)
|
||||
|
||||
adapter_class = LanceDBClientHTTPAdapterFactory()
|
||||
sess.mount("https://", adapter_class())
|
||||
return sess
|
||||
@@ -170,3 +178,72 @@ class RestfulLanceDBClient:
|
||||
"""Query a table."""
|
||||
tbl = self.post(f"/v1/table/{table_name}/query/", query, deserialize=_read_ipc)
|
||||
return VectorQueryResult(tbl)
|
||||
|
||||
def mount_retry_adapter_for_table(self, table_name: str) -> None:
|
||||
"""
|
||||
Adds an http adapter to session that will retry retryable requests to the table.
|
||||
"""
|
||||
retry_options = retry_adapter_options(methods=["GET", "POST"])
|
||||
retry_adapter_instance = retry_adapter(retry_options)
|
||||
session = self.session
|
||||
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/query/"), retry_adapter_instance
|
||||
)
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/describe/"),
|
||||
retry_adapter_instance,
|
||||
)
|
||||
session.mount(
|
||||
urljoin(self.url, f"/v1/table/{table_name}/index/list/"),
|
||||
retry_adapter_instance,
|
||||
)
|
||||
|
||||
|
||||
def retry_adapter_options(methods=["GET"]) -> Dict[str, Any]:
|
||||
return {
|
||||
"retries": int(os.environ.get("LANCE_CLIENT_MAX_RETRIES", "3")),
|
||||
"connect_retries": int(os.environ.get("LANCE_CLIENT_CONNECT_RETRIES", "3")),
|
||||
"read_retries": int(os.environ.get("LANCE_CLIENT_READ_RETRIES", "3")),
|
||||
"backoff_factor": float(
|
||||
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_FACTOR", "0.25")
|
||||
),
|
||||
"backoff_jitter": float(
|
||||
os.environ.get("LANCE_CLIENT_RETRY_BACKOFF_JITTER", "0.25")
|
||||
),
|
||||
"statuses": [
|
||||
int(i.strip())
|
||||
for i in os.environ.get(
|
||||
"LANCE_CLIENT_RETRY_STATUSES", "429, 500, 502, 503"
|
||||
).split(",")
|
||||
],
|
||||
"methods": methods,
|
||||
}
|
||||
|
||||
|
||||
def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
|
||||
total_retries = options["retries"]
|
||||
connect_retries = options["connect_retries"]
|
||||
read_retries = options["read_retries"]
|
||||
backoff_factor = options["backoff_factor"]
|
||||
backoff_jitter = options["backoff_jitter"]
|
||||
statuses = options["statuses"]
|
||||
methods = frozenset(options["methods"])
|
||||
logging.debug(
|
||||
f"Setting up retry adapter with {total_retries} retries," # noqa G003
|
||||
+ f"connect retries {connect_retries}, read retries {read_retries},"
|
||||
+ f"backoff factor {backoff_factor}, statuses {statuses}, "
|
||||
+ f"methods {methods}"
|
||||
)
|
||||
|
||||
return HTTPAdapter(
|
||||
max_retries=Retry(
|
||||
total=total_retries,
|
||||
connect=connect_retries,
|
||||
read=read_retries,
|
||||
backoff_factor=backoff_factor,
|
||||
backoff_jitter=backoff_jitter,
|
||||
status_forcelist=statuses,
|
||||
allowed_methods=methods,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -39,6 +40,7 @@ class RemoteDBConnection(DBConnection):
|
||||
api_key: str,
|
||||
region: str,
|
||||
host_override: Optional[str] = None,
|
||||
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
parsed = urlparse(db_url)
|
||||
@@ -49,6 +51,7 @@ class RemoteDBConnection(DBConnection):
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name, region, api_key, host_override
|
||||
)
|
||||
self._request_thread_pool = request_thread_pool
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteConnect(name={self.db_name})"
|
||||
@@ -95,6 +98,8 @@ class RemoteDBConnection(DBConnection):
|
||||
"""
|
||||
from .table import RemoteTable
|
||||
|
||||
self._client.mount_retry_adapter_for_table(name)
|
||||
|
||||
# check if table exists
|
||||
try:
|
||||
self._client.post(f"/v1/table/{name}/describe/")
|
||||
@@ -116,6 +121,7 @@ class RemoteDBConnection(DBConnection):
|
||||
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
mode: Optional[str] = None,
|
||||
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||
) -> Table:
|
||||
"""Create a [Table][lancedb.table.Table] in the database.
|
||||
@@ -213,11 +219,13 @@ class RemoteDBConnection(DBConnection):
|
||||
if data is None and schema is None:
|
||||
raise ValueError("Either data or schema must be provided.")
|
||||
if embedding_functions is not None:
|
||||
raise NotImplementedError(
|
||||
"embedding_functions is not supported for remote databases."
|
||||
logging.warning(
|
||||
"embedding_functions is not yet supported on LanceDB Cloud."
|
||||
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
||||
"for this feature."
|
||||
)
|
||||
if mode is not None:
|
||||
logging.warning("mode is not yet supported on LanceDB Cloud.")
|
||||
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
|
||||
@@ -11,7 +11,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
@@ -19,10 +21,11 @@ import pyarrow as pa
|
||||
from lance import json_to_schema
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from lancedb.merge import LanceMergeInsertBuilder
|
||||
|
||||
from ..query import LanceVectorQueryBuilder
|
||||
from ..table import Query, Table, _sanitize_data
|
||||
from ..util import value_to_sql
|
||||
from ..util import inf_vector_column_query, value_to_sql
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE
|
||||
from .db import RemoteDBConnection
|
||||
@@ -36,6 +39,9 @@ class RemoteTable(Table):
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoteTable({self._conn.db_name}.{self._name})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.count_rows(None)
|
||||
|
||||
@cached_property
|
||||
def schema(self) -> pa.Schema:
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||
@@ -53,17 +59,17 @@ class RemoteTable(Table):
|
||||
return resp["version"]
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""to_arrow() is not supported on the LanceDB cloud"""
|
||||
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
|
||||
"""to_arrow() is not yet supported on LanceDB cloud."""
|
||||
raise NotImplementedError("to_arrow() is not yet supported on LanceDB cloud.")
|
||||
|
||||
def to_pandas(self):
|
||||
"""to_pandas() is not supported on the LanceDB cloud"""
|
||||
return NotImplementedError("to_pandas() is not supported on the LanceDB cloud")
|
||||
"""to_pandas() is not yet supported on LanceDB cloud."""
|
||||
return NotImplementedError("to_pandas() is not yet supported on LanceDB cloud.")
|
||||
|
||||
def create_scalar_index(self, *args, **kwargs):
|
||||
"""Creates a scalar index"""
|
||||
return NotImplementedError(
|
||||
"create_scalar_index() is not supported on the LanceDB cloud"
|
||||
"create_scalar_index() is not yet supported on LanceDB cloud."
|
||||
)
|
||||
|
||||
def create_index(
|
||||
@@ -71,6 +77,10 @@ class RemoteTable(Table):
|
||||
metric="L2",
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
index_cache_size: Optional[int] = None,
|
||||
num_partitions: Optional[int] = None,
|
||||
num_sub_vectors: Optional[int] = None,
|
||||
replace: Optional[bool] = None,
|
||||
accelerator: Optional[str] = None,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
Currently, the only parameters that matter are
|
||||
@@ -104,6 +114,28 @@ class RemoteTable(Table):
|
||||
... )
|
||||
>>> table.create_index("L2", "vector") # doctest: +SKIP
|
||||
"""
|
||||
|
||||
if num_partitions is not None:
|
||||
logging.warning(
|
||||
"num_partitions is not supported on LanceDB cloud."
|
||||
"This parameter will be tuned automatically."
|
||||
)
|
||||
if num_sub_vectors is not None:
|
||||
logging.warning(
|
||||
"num_sub_vectors is not supported on LanceDB cloud."
|
||||
"This parameter will be tuned automatically."
|
||||
)
|
||||
if accelerator is not None:
|
||||
logging.warning(
|
||||
"GPU accelerator is not yet supported on LanceDB cloud."
|
||||
"If you have 100M+ vectors to index,"
|
||||
"please contact us at contact@lancedb.com"
|
||||
)
|
||||
if replace is not None:
|
||||
logging.warning(
|
||||
"replace is not supported on LanceDB cloud."
|
||||
"Existing indexes will always be replaced."
|
||||
)
|
||||
index_type = "vector"
|
||||
|
||||
data = {
|
||||
@@ -167,7 +199,9 @@ class RemoteTable(Table):
|
||||
)
|
||||
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column_name: str = VECTOR_COLUMN_NAME
|
||||
self,
|
||||
query: Union[VEC, str],
|
||||
vector_column_name: Optional[str] = None,
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -186,7 +220,7 @@ class RemoteTable(Table):
|
||||
... ]
|
||||
>>> table = db.create_table("my_table", data) # doctest: +SKIP
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query, vector_column_name="vector") # doctest: +SKIP
|
||||
>>> (table.search(query) # doctest: +SKIP
|
||||
... .where("original_width > 1000", prefilter=True) # doctest: +SKIP
|
||||
... .select(["caption", "original_width"]) # doctest: +SKIP
|
||||
... .limit(2) # doctest: +SKIP
|
||||
@@ -205,9 +239,14 @@ class RemoteTable(Table):
|
||||
|
||||
- If None then the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str
|
||||
vector_column_name: str, optional
|
||||
The name of the vector column to search.
|
||||
*default "vector"*
|
||||
|
||||
- If not specified then the vector column is inferred from
|
||||
the table schema
|
||||
|
||||
- If the table has multiple vector columns then the *vector_column_name*
|
||||
needs to be specified. Otherwise, an error is raised.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -222,6 +261,8 @@ class RemoteTable(Table):
|
||||
- and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
@@ -230,20 +271,78 @@ class RemoteTable(Table):
|
||||
and len(query.vector) > 0
|
||||
and not isinstance(query.vector[0], float)
|
||||
):
|
||||
if self._conn._request_thread_pool is None:
|
||||
|
||||
def submit(name, q):
|
||||
f = Future()
|
||||
f.set_result(self._conn._client.query(name, q))
|
||||
return f
|
||||
else:
|
||||
|
||||
def submit(name, q):
|
||||
return self._conn._request_thread_pool.submit(
|
||||
self._conn._client.query, name, q
|
||||
)
|
||||
|
||||
results = []
|
||||
for v in query.vector:
|
||||
v = list(v)
|
||||
q = query.copy()
|
||||
q.vector = v
|
||||
results.append(self._conn._client.query(self._name, q))
|
||||
results.append(submit(self._name, q))
|
||||
|
||||
return pa.concat_tables(
|
||||
[add_index(r.to_arrow(), i) for i, r in enumerate(results)]
|
||||
[add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
|
||||
)
|
||||
else:
|
||||
result = self._conn._client.query(self._name, query)
|
||||
return result.to_arrow()
|
||||
|
||||
def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
data = _sanitize_data(
|
||||
new_data,
|
||||
self.schema,
|
||||
metadata=None,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
payload = to_ipc_binary(data)
|
||||
|
||||
params = {}
|
||||
if len(merge._on) != 1:
|
||||
raise ValueError(
|
||||
"RemoteTable only supports a single on key in merge_insert"
|
||||
)
|
||||
params["on"] = merge._on[0]
|
||||
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
|
||||
if merge._when_matched_update_all_condition is not None:
|
||||
params[
|
||||
"when_matched_update_all_filt"
|
||||
] = merge._when_matched_update_all_condition
|
||||
params["when_not_matched_insert_all"] = str(
|
||||
merge._when_not_matched_insert_all
|
||||
).lower()
|
||||
params["when_not_matched_by_source_delete"] = str(
|
||||
merge._when_not_matched_by_source_delete
|
||||
).lower()
|
||||
if merge._when_not_matched_by_source_condition is not None:
|
||||
params[
|
||||
"when_not_matched_by_source_delete_filt"
|
||||
] = merge._when_not_matched_by_source_condition
|
||||
|
||||
self._conn._client.post(
|
||||
f"/v1/table/{self._name}/merge_insert/",
|
||||
data=payload,
|
||||
params=params,
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
def delete(self, predicate: str):
|
||||
"""Delete rows from the table.
|
||||
|
||||
@@ -355,6 +454,25 @@ class RemoteTable(Table):
|
||||
payload = {"predicate": where, "updates": updates}
|
||||
self._conn._client.post(f"/v1/table/{self._name}/update/", data=payload)
|
||||
|
||||
def cleanup_old_versions(self, *_):
|
||||
"""cleanup_old_versions() is not supported on the LanceDB cloud"""
|
||||
raise NotImplementedError(
|
||||
"cleanup_old_versions() is not supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
def compact_files(self, *_):
|
||||
"""compact_files() is not supported on the LanceDB cloud"""
|
||||
raise NotImplementedError(
|
||||
"compact_files() is not supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
# payload = {"filter": filter}
|
||||
# self._conn._client.post(f"/v1/table/{self._name}/count_rows/", data=payload)
|
||||
return NotImplementedError(
|
||||
"count_rows() is not yet supported on the LanceDB cloud"
|
||||
)
|
||||
|
||||
|
||||
def add_index(tbl: pa.Table, i: int) -> pa.Table:
|
||||
return tbl.add_column(
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from .base import Reranker
|
||||
from .cohere import CohereReranker
|
||||
from .colbert import ColbertReranker
|
||||
from .cross_encoder import CrossEncoderReranker
|
||||
from .linear_combination import LinearCombinationReranker
|
||||
from .openai import OpenaiReranker
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
"CrossEncoderReranker",
|
||||
"CohereReranker",
|
||||
"LinearCombinationReranker",
|
||||
"OpenaiReranker",
|
||||
"ColbertReranker",
|
||||
]
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import typing
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import lancedb
|
||||
|
||||
|
||||
class Reranker(ABC):
|
||||
def __init__(self, return_score: str = "relevance"):
|
||||
@@ -30,7 +26,7 @@ class Reranker(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def rerank_hybrid(
|
||||
query_builder: "lancedb.HybridQueryBuilder",
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
@@ -41,8 +37,8 @@ class Reranker(ABC):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query_builder : "lancedb.HybridQueryBuilder"
|
||||
The query builder object that was used to generate the results
|
||||
query : str
|
||||
The input query
|
||||
vector_results : pa.Table
|
||||
The results from the vector search
|
||||
fts_results : pa.Table
|
||||
@@ -50,36 +46,6 @@ class Reranker(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def rerank_vector(
|
||||
query_builder: "lancedb.VectorQueryBuilder", vector_results: pa.Table
|
||||
):
|
||||
"""
|
||||
Rerank function receives the individual results from the vector search.
|
||||
This isn't mandatory to implement
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query_builder : "lancedb.VectorQueryBuilder"
|
||||
The query builder object that was used to generate the results
|
||||
vector_results : pa.Table
|
||||
The results from the vector search
|
||||
"""
|
||||
raise NotImplementedError("Vector Reranking is not implemented")
|
||||
|
||||
def rerank_fts(query_builder: "lancedb.FTSQueryBuilder", fts_results: pa.Table):
|
||||
"""
|
||||
Rerank function receives the individual results from the FTS search.
|
||||
This isn't mandatory to implement
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query_builder : "lancedb.FTSQueryBuilder"
|
||||
The query builder object that was used to generate the results
|
||||
fts_results : pa.Table
|
||||
The results from the FTS search
|
||||
"""
|
||||
raise NotImplementedError("FTS Reranking is not implemented")
|
||||
|
||||
def merge_results(self, vector_results: pa.Table, fts_results: pa.Table):
|
||||
"""
|
||||
Merge the results from the vector and FTS search. This is a vanilla merging
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import os
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import lancedb
|
||||
|
||||
|
||||
class CohereReranker(Reranker):
|
||||
"""
|
||||
@@ -45,7 +41,7 @@ class CohereReranker(Reranker):
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
cohere = safe_import("cohere")
|
||||
cohere = attempt_import_or_raise("cohere")
|
||||
if os.environ.get("COHERE_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"COHERE_API_KEY not set. Either set it in your environment or \
|
||||
@@ -55,14 +51,14 @@ class CohereReranker(Reranker):
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query_builder: "lancedb.HybridQueryBuilder",
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
docs = combined_results[self.column].to_pylist()
|
||||
results = self._client.rerank(
|
||||
query=query_builder._query,
|
||||
query=query,
|
||||
documents=docs,
|
||||
top_n=self.top_n,
|
||||
model=self.model_name,
|
||||
|
||||
109
python/lancedb/rerankers/colbert.py
Normal file
109
python/lancedb/rerankers/colbert.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from functools import cached_property
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class ColbertReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using the ColBERT model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "colbert-ir/colbertv2.0"
|
||||
The name of the cross encoder model to use.
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "colbert-ir/colbertv2.0",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.torch = attempt_import_or_raise(
|
||||
"torch"
|
||||
) # import here for faster ops later
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
docs = combined_results[self.column].to_pylist()
|
||||
|
||||
tokenizer, model = self._model
|
||||
|
||||
# Encode the query
|
||||
query_encoding = tokenizer(query, return_tensors="pt")
|
||||
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
|
||||
scores = []
|
||||
# Get score for each document
|
||||
for document in docs:
|
||||
document_encoding = tokenizer(
|
||||
document, return_tensors="pt", truncation=True, max_length=512
|
||||
)
|
||||
document_embedding = model(**document_encoding).last_hidden_state
|
||||
# Calculate MaxSim score
|
||||
score = self.maxsim(query_embedding.unsqueeze(0), document_embedding)
|
||||
scores.append(score.item())
|
||||
|
||||
# replace the self.column column with the docs
|
||||
combined_results = combined_results.drop(self.column)
|
||||
combined_results = combined_results.append_column(
|
||||
self.column, pa.array(docs, type=pa.string())
|
||||
)
|
||||
# add the scores
|
||||
combined_results = combined_results.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"OpenAI Reranker does not support score='all' yet"
|
||||
)
|
||||
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
@cached_property
|
||||
def _model(self):
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
|
||||
model = transformers.AutoModel.from_pretrained(self.model_name)
|
||||
|
||||
return tokenizer, model
|
||||
|
||||
def maxsim(self, query_embedding, document_embedding):
|
||||
# Expand dimensions for broadcasting
|
||||
# Query: [batch, length, size] -> [batch, query, 1, size]
|
||||
# Document: [batch, length, size] -> [batch, 1, length, size]
|
||||
expanded_query = query_embedding.unsqueeze(2)
|
||||
expanded_doc = document_embedding.unsqueeze(1)
|
||||
|
||||
# Compute cosine similarity across the embedding dimension
|
||||
sim_matrix = self.torch.nn.functional.cosine_similarity(
|
||||
expanded_query, expanded_doc, dim=-1
|
||||
)
|
||||
|
||||
# Take the maximum similarity for each query token (across all document tokens)
|
||||
# sim_matrix shape: [batch_size, query_length, doc_length]
|
||||
max_sim_scores, _ = self.torch.max(sim_matrix, dim=2)
|
||||
|
||||
# Average these maximum scores across all query tokens
|
||||
avg_max_sim = self.torch.mean(max_sim_scores, dim=1)
|
||||
return avg_max_sim
|
||||
@@ -1,15 +1,11 @@
|
||||
import typing
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import safe_import
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import lancedb
|
||||
|
||||
|
||||
class CrossEncoderReranker(Reranker):
|
||||
"""
|
||||
@@ -36,7 +32,7 @@ class CrossEncoderReranker(Reranker):
|
||||
return_score="relevance",
|
||||
):
|
||||
super().__init__(return_score)
|
||||
torch = safe_import("torch")
|
||||
torch = attempt_import_or_raise("torch")
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.device = device
|
||||
@@ -45,20 +41,20 @@ class CrossEncoderReranker(Reranker):
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
sbert = safe_import("sentence_transformers")
|
||||
sbert = attempt_import_or_raise("sentence_transformers")
|
||||
cross_encoder = sbert.CrossEncoder(self.model_name)
|
||||
|
||||
return cross_encoder
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query_builder: "lancedb.HybridQueryBuilder",
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
passages = combined_results[self.column].to_pylist()
|
||||
cross_inp = [[query_builder._query, passage] for passage in passages]
|
||||
cross_inp = [[query, passage] for passage in passages]
|
||||
cross_scores = self.model.predict(cross_inp)
|
||||
combined_results = combined_results.append_column(
|
||||
"_relevance_score", pa.array(cross_scores, type=pa.float32())
|
||||
|
||||
@@ -36,7 +36,7 @@ class LinearCombinationReranker(Reranker):
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query_builder: "lancedb.HybridQueryBuilder", # noqa: F821
|
||||
query: str, # noqa: F821
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
|
||||
104
python/lancedb/rerankers/openai.py
Normal file
104
python/lancedb/rerankers/openai.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import json
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class OpenaiReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using the OpenAI API.
|
||||
WARNING: This is a prompt based reranker that uses chat model that is
|
||||
not a dedicated reranker API. This should be treated as experimental.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "gpt-4-turbo-preview"
|
||||
The name of the cross encoder model to use.
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
api_key : str, default None
|
||||
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "gpt-4-turbo-preview",
|
||||
column: str = "text",
|
||||
return_score="relevance",
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.api_key = api_key
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
docs = combined_results[self.column].to_pylist()
|
||||
response = self._client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
response_format={"type": "json_object"},
|
||||
temperature=0,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an expert relevance ranker. Given a list of\
|
||||
documents and a query, your job is to determine the relevance\
|
||||
each document is for answering the query. Your output is JSON,\
|
||||
which is a list of documents. Each document has two fields,\
|
||||
content and relevance_score. relevance_score is from 0.0 to\
|
||||
1.0 indicating the relevance of the text to the given query.\
|
||||
Make sure to include all documents in the response.",
|
||||
},
|
||||
{"role": "user", "content": f"Query: {query} Docs: {docs}"},
|
||||
],
|
||||
)
|
||||
results = json.loads(response.choices[0].message.content)["documents"]
|
||||
docs, scores = list(
|
||||
zip(*[(result["content"], result["relevance_score"]) for result in results])
|
||||
) # tuples
|
||||
# replace the self.column column with the docs
|
||||
combined_results = combined_results.drop(self.column)
|
||||
combined_results = combined_results.append_column(
|
||||
self.column, pa.array(docs, type=pa.string())
|
||||
)
|
||||
# add the scores
|
||||
combined_results = combined_results.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"OpenAI Reranker does not support score='all' yet"
|
||||
)
|
||||
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
openai = attempt_import_or_raise(
|
||||
"openai"
|
||||
) # TODO: force version or handle versions < 1.0
|
||||
if os.environ.get("OPENAI_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"OPENAI_API_KEY not set. Either set it in your environment or \
|
||||
pass it as `api_key` argument to the CohereReranker."
|
||||
)
|
||||
return openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY") or self.api_key)
|
||||
@@ -14,7 +14,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@@ -28,27 +31,28 @@ from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .merge import LanceMergeInsertBuilder
|
||||
from .pydantic import LanceModel, model_to_dict
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
inf_vector_column_query,
|
||||
join_uri,
|
||||
safe_import,
|
||||
safe_import_pandas,
|
||||
safe_import_polars,
|
||||
value_to_sql,
|
||||
)
|
||||
from .utils.events import register_event
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
import PIL
|
||||
from lance.dataset import CleanupStats, ReaderLike
|
||||
|
||||
from .db import LanceDBConnection
|
||||
|
||||
|
||||
pd = safe_import("pandas")
|
||||
pl = safe_import("polars")
|
||||
pd = safe_import_pandas()
|
||||
pl = safe_import_polars()
|
||||
|
||||
|
||||
def _sanitize_data(
|
||||
@@ -174,6 +178,18 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
Count the number of rows in the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter: str, optional
|
||||
A SQL where clause to filter the rows to count.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
@@ -297,7 +313,7 @@ class Table(ABC):
|
||||
|
||||
import lance
|
||||
|
||||
dataset = lance.dataset("/tmp/images.lance")
|
||||
dataset = lance.dataset("./images.lance")
|
||||
dataset.create_scalar_index("category")
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -334,11 +350,71 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
"""
|
||||
Returns a [`LanceMergeInsertBuilder`][lancedb.merge.LanceMergeInsertBuilder]
|
||||
that can be used to create a "merge insert" operation
|
||||
|
||||
This operation can add rows, update rows, and remove rows all in a single
|
||||
transaction. It is a very generic tool that can be used to create
|
||||
behaviors like "insert if not exists", "update or insert (i.e. upsert)",
|
||||
or even replace a portion of existing data with new data (e.g. replace
|
||||
all data where month="january")
|
||||
|
||||
The merge insert operation works by combining new data from a
|
||||
**source table** with existing data in a **target table** by using a
|
||||
join. There are three categories of records.
|
||||
|
||||
"Matched" records are records that exist in both the source table and
|
||||
the target table. "Not matched" records exist only in the source table
|
||||
(e.g. these are new data) "Not matched by source" records exist only
|
||||
in the target table (this is old data)
|
||||
|
||||
The builder returned by this method can be used to customize what
|
||||
should happen for each category of data.
|
||||
|
||||
Please note that the data may appear to be reordered as part of this
|
||||
operation. This is because updated rows will be deleted from the
|
||||
dataset and then reinserted at the end with the new values.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
on: Union[str, Iterable[str]]
|
||||
A column (or columns) to join on. This is how records from the
|
||||
source table and target table are matched. Typically this is some
|
||||
kind of key or id column.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = pa.table({"a": [2, 1, 3], "b": ["a", "b", "c"]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
>>> # Perform a "upsert" operation
|
||||
>>> table.merge_insert("a") \\
|
||||
... .when_matched_update_all() \\
|
||||
... .when_not_matched_insert_all() \\
|
||||
... .execute(new_data)
|
||||
>>> # The order of new rows is non-deterministic since we use
|
||||
>>> # a hash-join as part of this operation and so we sort here
|
||||
>>> table.to_arrow().sort_by("a").to_pandas()
|
||||
a b
|
||||
0 1 b
|
||||
1 2 x
|
||||
2 3 y
|
||||
3 4 z
|
||||
"""
|
||||
on = [on] if isinstance(on, str) else list(on.iter())
|
||||
|
||||
return LanceMergeInsertBuilder(self, on)
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
@@ -358,7 +434,7 @@ class Table(ABC):
|
||||
... ]
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query, vector_column_name="vector")
|
||||
>>> (table.search(query)
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width"])
|
||||
... .limit(2)
|
||||
@@ -377,12 +453,19 @@ class Table(ABC):
|
||||
|
||||
- If None then the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str
|
||||
vector_column_name: str, optional
|
||||
The name of the vector column to search.
|
||||
*default "vector"*
|
||||
|
||||
The vector column needs to be a pyarrow fixed size list type
|
||||
|
||||
- If not specified then the vector column is inferred from
|
||||
the table schema
|
||||
|
||||
- If the table has multiple vector columns then the *vector_column_name*
|
||||
needs to be specified. Otherwise, an error is raised.
|
||||
query_type: str
|
||||
*default "auto"*.
|
||||
Acceptable types are: "vector", "fts", or "auto"
|
||||
Acceptable types are: "vector", "fts", "hybrid", or "auto"
|
||||
|
||||
- If "auto" then the query type is inferred from the query;
|
||||
|
||||
@@ -414,6 +497,16 @@ class Table(ABC):
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, where: str):
|
||||
"""Delete rows from the table.
|
||||
@@ -521,24 +614,192 @@ class Table(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def cleanup_old_versions(
|
||||
self,
|
||||
older_than: Optional[timedelta] = None,
|
||||
*,
|
||||
delete_unverified: bool = False,
|
||||
) -> CleanupStats:
|
||||
"""
|
||||
Clean up old versions of the table, freeing disk space.
|
||||
|
||||
Note: This function is not available in LanceDb Cloud (since LanceDb
|
||||
Cloud manages cleanup for you automatically)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
older_than: timedelta, default None
|
||||
The minimum age of the version to delete. If None, then this defaults
|
||||
to two weeks.
|
||||
delete_unverified: bool, default False
|
||||
Because they may be part of an in-progress transaction, files newer
|
||||
than 7 days old are not deleted by default. If you are sure that
|
||||
there are no in-progress transactions, then you can set this to True
|
||||
to delete all files older than `older_than`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
CleanupStats
|
||||
The stats of the cleanup operation, including how many bytes were
|
||||
freed.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def compact_files(self, *args, **kwargs):
|
||||
"""
|
||||
Run the compaction process on the table.
|
||||
|
||||
Note: This function is not available in LanceDb Cloud (since LanceDb
|
||||
Cloud manages compaction for you automatically)
|
||||
|
||||
This can be run after making several small appends to optimize the table
|
||||
for faster reads.
|
||||
|
||||
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
|
||||
For most cases, the default should be fine.
|
||||
"""
|
||||
|
||||
|
||||
class _LanceDatasetRef(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def dataset(self) -> LanceDataset:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dataset_mut(self) -> LanceDataset:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LanceLatestDatasetRef(_LanceDatasetRef):
|
||||
"""Reference to the latest version of a LanceDataset."""
|
||||
|
||||
uri: str
|
||||
read_consistency_interval: Optional[timedelta] = None
|
||||
last_consistency_check: Optional[float] = None
|
||||
_dataset: Optional[LanceDataset] = None
|
||||
|
||||
@property
|
||||
def dataset(self) -> LanceDataset:
|
||||
if not self._dataset:
|
||||
self._dataset = lance.dataset(self.uri)
|
||||
self.last_consistency_check = time.monotonic()
|
||||
elif self.read_consistency_interval is not None:
|
||||
now = time.monotonic()
|
||||
diff = timedelta(seconds=now - self.last_consistency_check)
|
||||
if (
|
||||
self.last_consistency_check is None
|
||||
or diff > self.read_consistency_interval
|
||||
):
|
||||
self._dataset = self._dataset.checkout_version(
|
||||
self._dataset.latest_version
|
||||
)
|
||||
self.last_consistency_check = time.monotonic()
|
||||
return self._dataset
|
||||
|
||||
@dataset.setter
|
||||
def dataset(self, value: LanceDataset):
|
||||
self._dataset = value
|
||||
self.last_consistency_check = time.monotonic()
|
||||
|
||||
@property
|
||||
def dataset_mut(self) -> LanceDataset:
|
||||
return self.dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LanceTimeTravelRef(_LanceDatasetRef):
|
||||
uri: str
|
||||
version: int
|
||||
_dataset: Optional[LanceDataset] = None
|
||||
|
||||
@property
|
||||
def dataset(self) -> LanceDataset:
|
||||
if not self._dataset:
|
||||
self._dataset = lance.dataset(self.uri, version=self.version)
|
||||
return self._dataset
|
||||
|
||||
@dataset.setter
|
||||
def dataset(self, value: LanceDataset):
|
||||
self._dataset = value
|
||||
self.version = value.version
|
||||
|
||||
@property
|
||||
def dataset_mut(self) -> LanceDataset:
|
||||
raise ValueError(
|
||||
"Cannot mutate table reference fixed at version "
|
||||
f"{self.version}. Call checkout_latest() to get a mutable "
|
||||
"table reference."
|
||||
)
|
||||
|
||||
|
||||
class LanceTable(Table):
|
||||
"""
|
||||
A table in a LanceDB database.
|
||||
|
||||
This can be opened in two modes: standard and time-travel.
|
||||
|
||||
Standard mode is the default. In this mode, the table is mutable and tracks
|
||||
the latest version of the table. The level of read consistency is controlled
|
||||
by the `read_consistency_interval` parameter on the connection.
|
||||
|
||||
Time-travel mode is activated by specifying a version number. In this mode,
|
||||
the table is immutable and fixed to a specific version. This is useful for
|
||||
querying historical versions of the table.
|
||||
"""
|
||||
|
||||
def __init__(self, connection: "LanceDBConnection", name: str, version: int = None):
|
||||
def __init__(
|
||||
self,
|
||||
connection: "LanceDBConnection",
|
||||
name: str,
|
||||
version: Optional[int] = None,
|
||||
):
|
||||
self._conn = connection
|
||||
self.name = name
|
||||
self._version = version
|
||||
|
||||
def _reset_dataset(self, version=None):
|
||||
try:
|
||||
if "_dataset" in self.__dict__:
|
||||
del self.__dict__["_dataset"]
|
||||
self._version = version
|
||||
except AttributeError:
|
||||
pass
|
||||
if version is not None:
|
||||
self._ref = _LanceTimeTravelRef(
|
||||
uri=self._dataset_uri,
|
||||
version=version,
|
||||
)
|
||||
else:
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=connection.read_consistency_interval,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name, **kwargs):
|
||||
tbl = cls(db, name, **kwargs)
|
||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
||||
file_info = fs.get_file_info(path)
|
||||
if file_info.type != pa.fs.FileType.Directory:
|
||||
raise FileNotFoundError(
|
||||
f"Table {name} does not exist."
|
||||
f"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
register_event("open_table")
|
||||
|
||||
return tbl
|
||||
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
@property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return self._ref.dataset
|
||||
|
||||
@property
|
||||
def _dataset_mut(self) -> LanceDataset:
|
||||
return self._ref.dataset_mut
|
||||
|
||||
def to_lance(self) -> LanceDataset:
|
||||
"""Return the LanceDataset backing this table."""
|
||||
return self._dataset
|
||||
|
||||
@property
|
||||
def schema(self) -> pa.Schema:
|
||||
@@ -566,6 +827,9 @@ class LanceTable(Table):
|
||||
keep writing to the dataset starting from an old version, then use
|
||||
the `restore` function.
|
||||
|
||||
Calling this method will set the table into time-travel mode. If you
|
||||
wish to return to standard mode, call `checkout_latest`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : int
|
||||
@@ -590,15 +854,13 @@ class LanceTable(Table):
|
||||
vector type
|
||||
0 [1.1, 0.9] vector
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
max_ver = self._dataset.latest_version
|
||||
if version < 1 or version > max_ver:
|
||||
raise ValueError(f"Invalid version {version}")
|
||||
self._reset_dataset(version=version)
|
||||
|
||||
try:
|
||||
# Accessing the property updates the cached value
|
||||
_ = self._dataset
|
||||
except Exception as e:
|
||||
ds = self._dataset.checkout_version(version)
|
||||
except IOError as e:
|
||||
if "not found" in str(e):
|
||||
raise ValueError(
|
||||
f"Version {version} no longer exists. Was it cleaned up?"
|
||||
@@ -606,6 +868,27 @@ class LanceTable(Table):
|
||||
else:
|
||||
raise e
|
||||
|
||||
self._ref = _LanceTimeTravelRef(
|
||||
uri=self._dataset_uri,
|
||||
version=version,
|
||||
)
|
||||
# We've already loaded the version so we can populate it directly.
|
||||
self._ref.dataset = ds
|
||||
|
||||
def checkout_latest(self):
|
||||
"""Checkout the latest version of the table. This is an in-place operation.
|
||||
|
||||
The table will be set back into standard mode, and will track the latest
|
||||
version of the table.
|
||||
"""
|
||||
self.checkout(self._dataset.latest_version)
|
||||
ds = self._ref.dataset
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=self._conn.read_consistency_interval,
|
||||
)
|
||||
self._ref.dataset = ds
|
||||
|
||||
def restore(self, version: int = None):
|
||||
"""Restore a version of the table. This is an in-place operation.
|
||||
|
||||
@@ -640,7 +923,7 @@ class LanceTable(Table):
|
||||
>>> len(table.list_versions())
|
||||
4
|
||||
"""
|
||||
max_ver = max([v["version"] for v in self._dataset.versions()])
|
||||
max_ver = self._dataset.latest_version
|
||||
if version is None:
|
||||
version = self.version
|
||||
elif version < 1 or version > max_ver:
|
||||
@@ -648,29 +931,30 @@ class LanceTable(Table):
|
||||
else:
|
||||
self.checkout(version)
|
||||
|
||||
if version == max_ver:
|
||||
# no-op if restoring the latest version
|
||||
return
|
||||
ds = self._dataset
|
||||
|
||||
self._dataset.restore()
|
||||
self._reset_dataset()
|
||||
# no-op if restoring the latest version
|
||||
if version != max_ver:
|
||||
ds.restore()
|
||||
|
||||
self._ref = _LanceLatestDatasetRef(
|
||||
uri=self._dataset_uri,
|
||||
read_consistency_interval=self._conn.read_consistency_interval,
|
||||
)
|
||||
self._ref.dataset = ds
|
||||
|
||||
def count_rows(self, filter: Optional[str] = None) -> int:
|
||||
"""
|
||||
Count the number of rows in the table.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter: str, optional
|
||||
A SQL where clause to filter the rows to count.
|
||||
"""
|
||||
return self._dataset.count_rows(filter)
|
||||
|
||||
def __len__(self):
|
||||
return self.count_rows()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LanceTable({self.name})"
|
||||
val = f'{self.__class__.__name__}(connection={self._conn!r}, name="{self.name}"'
|
||||
if isinstance(self._ref, _LanceTimeTravelRef):
|
||||
val += f", version={self._ref.version}"
|
||||
val += ")"
|
||||
return val
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr__()
|
||||
@@ -720,10 +1004,6 @@ class LanceTable(Table):
|
||||
self.to_lance(), allow_pyarrow_filter=False, batch_size=batch_size
|
||||
)
|
||||
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
@@ -735,7 +1015,7 @@ class LanceTable(Table):
|
||||
index_cache_size: Optional[int] = None,
|
||||
):
|
||||
"""Create an index on the table."""
|
||||
self._dataset.create_index(
|
||||
self._dataset_mut.create_index(
|
||||
column=vector_column_name,
|
||||
index_type="IVF_PQ",
|
||||
metric=metric,
|
||||
@@ -745,11 +1025,12 @@ class LanceTable(Table):
|
||||
accelerator=accelerator,
|
||||
index_cache_size=index_cache_size,
|
||||
)
|
||||
self._reset_dataset()
|
||||
register_event("create_index")
|
||||
|
||||
def create_scalar_index(self, column: str, *, replace: bool = True):
|
||||
self._dataset.create_scalar_index(column, index_type="BTREE", replace=replace)
|
||||
self._dataset_mut.create_scalar_index(
|
||||
column, index_type="BTREE", replace=replace
|
||||
)
|
||||
|
||||
def create_fts_index(
|
||||
self,
|
||||
@@ -792,14 +1073,6 @@ class LanceTable(Table):
|
||||
def _get_fts_index_path(self):
|
||||
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
||||
|
||||
@cached_property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
return lance.dataset(self._dataset_uri, version=self._version)
|
||||
|
||||
def to_lance(self) -> LanceDataset:
|
||||
"""Return the LanceDataset backing this table."""
|
||||
return self._dataset
|
||||
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
@@ -838,8 +1111,11 @@ class LanceTable(Table):
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
||||
self._reset_dataset()
|
||||
# Access the dataset_mut property to ensure that the dataset is mutable.
|
||||
self._ref.dataset_mut
|
||||
self._ref.dataset = lance.write_dataset(
|
||||
data, self._dataset_uri, schema=self.schema, mode=mode
|
||||
)
|
||||
register_event("add")
|
||||
|
||||
def merge(
|
||||
@@ -900,10 +1176,9 @@ class LanceTable(Table):
|
||||
other_table = other_table.to_lance()
|
||||
if isinstance(other_table, LanceDataset):
|
||||
other_table = other_table.to_table()
|
||||
self._dataset.merge(
|
||||
self._ref.dataset = self._dataset_mut.merge(
|
||||
other_table, left_on=left_on, right_on=right_on, schema=schema
|
||||
)
|
||||
self._reset_dataset()
|
||||
register_event("merge")
|
||||
|
||||
@cached_property
|
||||
@@ -924,7 +1199,7 @@ class LanceTable(Table):
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
vector_column_name: Optional[str] = None,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
@@ -942,7 +1217,7 @@ class LanceTable(Table):
|
||||
... ]
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query, vector_column_name="vector")
|
||||
>>> (table.search(query)
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width"])
|
||||
... .limit(2)
|
||||
@@ -961,8 +1236,17 @@ class LanceTable(Table):
|
||||
|
||||
- If None then the select/[where][sql]/limit clauses are applied
|
||||
to filter the table
|
||||
vector_column_name: str, default "vector"
|
||||
vector_column_name: str, optional
|
||||
The name of the vector column to search.
|
||||
|
||||
The vector column needs to be a pyarrow fixed size list type
|
||||
*default "vector"*
|
||||
|
||||
- If not specified then the vector column is inferred from
|
||||
the table schema
|
||||
|
||||
- If the table has multiple vector columns then the *vector_column_name*
|
||||
needs to be specified. Otherwise, an error is raised.
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
@@ -980,6 +1264,8 @@ class LanceTable(Table):
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
if vector_column_name is None and query is not None:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
register_event("search_table")
|
||||
return LanceQueryBuilder.create(
|
||||
self, query, query_type, vector_column_name=vector_column_name
|
||||
@@ -1106,22 +1392,8 @@ class LanceTable(Table):
|
||||
register_event("create_table")
|
||||
return new_table
|
||||
|
||||
@classmethod
|
||||
def open(cls, db, name):
|
||||
tbl = cls(db, name)
|
||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
||||
file_info = fs.get_file_info(path)
|
||||
if file_info.type != pa.fs.FileType.Directory:
|
||||
raise FileNotFoundError(
|
||||
f"Table {name} does not exist."
|
||||
f"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
register_event("open_table")
|
||||
|
||||
return tbl
|
||||
|
||||
def delete(self, where: str):
|
||||
self._dataset.delete(where)
|
||||
self._dataset_mut.delete(where)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -1175,12 +1447,12 @@ class LanceTable(Table):
|
||||
if values is not None:
|
||||
values_sql = {k: value_to_sql(v) for k, v in values.items()}
|
||||
|
||||
self.to_lance().update(values_sql, where)
|
||||
self._reset_dataset()
|
||||
self._dataset_mut.update(values_sql, where)
|
||||
register_event("update")
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
ds = self.to_lance()
|
||||
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
@@ -1196,6 +1468,31 @@ class LanceTable(Table):
|
||||
with_row_id=query.with_row_id,
|
||||
)
|
||||
|
||||
def _do_merge(
|
||||
self,
|
||||
merge: LanceMergeInsertBuilder,
|
||||
new_data: DATA,
|
||||
on_bad_vectors: str,
|
||||
fill_value: float,
|
||||
):
|
||||
new_data = _sanitize_data(
|
||||
new_data,
|
||||
self.schema,
|
||||
metadata=self.schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
ds = self.to_lance()
|
||||
builder = ds.merge_insert(merge._on)
|
||||
if merge._when_matched_update_all:
|
||||
builder.when_matched_update_all(merge._when_matched_update_all_condition)
|
||||
if merge._when_not_matched_insert_all:
|
||||
builder.when_not_matched_insert_all()
|
||||
if merge._when_not_matched_by_source_delete:
|
||||
cond = merge._when_not_matched_by_source_condition
|
||||
builder.when_not_matched_by_source_delete(cond)
|
||||
builder.execute(new_data)
|
||||
|
||||
def cleanup_old_versions(
|
||||
self,
|
||||
older_than: Optional[timedelta] = None,
|
||||
@@ -1233,8 +1530,9 @@ class LanceTable(Table):
|
||||
This can be run after making several small appends to optimize the table
|
||||
for faster reads.
|
||||
|
||||
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
|
||||
For most cases, the default should be fine.
|
||||
Arguments are passed onto `lance.dataset.DatasetOptimizer.compact_files`.
|
||||
(see Lance documentation for more details) For most cases, the default
|
||||
should be fine.
|
||||
"""
|
||||
return self.to_lance().optimize.compact_files(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -11,15 +11,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import pathlib
|
||||
import warnings
|
||||
from datetime import date, datetime
|
||||
from functools import singledispatch
|
||||
from typing import Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.fs as pa_fs
|
||||
|
||||
|
||||
@@ -115,7 +118,7 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
|
||||
return "/".join([p.rstrip("/") for p in [base, *parts]])
|
||||
|
||||
|
||||
def safe_import(module: str, mitigation=None):
|
||||
def attempt_import_or_raise(module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
@@ -134,6 +137,62 @@ def safe_import(module: str, mitigation=None):
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
|
||||
def safe_import_pandas():
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
return pd
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def safe_import_polars():
|
||||
try:
|
||||
import polars as pl
|
||||
|
||||
return pl
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
"""
|
||||
Get the vector column name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema : pa.Schema
|
||||
The schema of the vector column.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str: the vector column name.
|
||||
"""
|
||||
vector_col_name = ""
|
||||
vector_col_count = 0
|
||||
for field_name in schema.names:
|
||||
field = schema.field(field_name)
|
||||
if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
|
||||
field.type.value_type
|
||||
):
|
||||
vector_col_count += 1
|
||||
if vector_col_count > 1:
|
||||
raise ValueError(
|
||||
"Schema has more than one vector column. "
|
||||
"Please specify the vector column name "
|
||||
"for vector search"
|
||||
)
|
||||
break
|
||||
elif vector_col_count == 1:
|
||||
vector_col_name = field_name
|
||||
if vector_col_count == 0:
|
||||
raise ValueError(
|
||||
"There is no vector column in the data. "
|
||||
"Please specify the vector column name for vector search"
|
||||
)
|
||||
return vector_col_name
|
||||
|
||||
|
||||
@singledispatch
|
||||
def value_to_sql(value):
|
||||
raise NotImplementedError("SQL conversion is not implemented for this type")
|
||||
@@ -182,3 +241,25 @@ def _(value: list):
|
||||
@value_to_sql.register(np.ndarray)
|
||||
def _(value: np.ndarray):
|
||||
return value_to_sql(value.tolist())
|
||||
|
||||
|
||||
def deprecated(func):
|
||||
"""This is a decorator which can be used to mark functions
|
||||
as deprecated. It will result in a warning being emitted
|
||||
when the function is used."""
|
||||
|
||||
@functools.wraps(func)
|
||||
def new_func(*args, **kwargs):
|
||||
warnings.simplefilter("always", DeprecationWarning) # turn off filter
|
||||
warnings.warn(
|
||||
(
|
||||
f"Function {func.__name__} is deprecated and will be "
|
||||
"removed in a future version"
|
||||
),
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
warnings.simplefilter("default", DeprecationWarning) # reset filter
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.5.1"
|
||||
version = "0.5.6"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.9.10",
|
||||
"pylance==0.9.16",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.27.0",
|
||||
@@ -48,7 +48,7 @@ classifiers = [
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars"]
|
||||
tests = ["aiohttp", "pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "duckdb", "pytz", "polars>=0.19"]
|
||||
dev = ["ruff", "pre-commit"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
|
||||
@@ -88,6 +88,7 @@ def test_embedding_function(tmp_path):
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_embedding_function_rate_limit(tmp_path):
|
||||
def _get_schema_from_model(model):
|
||||
class Schema(LanceModel):
|
||||
|
||||
@@ -23,11 +23,6 @@ import lancedb
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
try:
|
||||
if importlib.util.find_spec("mlx.core") is not None:
|
||||
_mlx = True
|
||||
except ImportError:
|
||||
_mlx = None
|
||||
# These are integration tests for embedding functions.
|
||||
# They are slow because they require downloading models
|
||||
# or connection to external api
|
||||
@@ -74,10 +69,14 @@ def test_basic_text_embeddings(alias, tmp_path):
|
||||
)
|
||||
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
actual = (
|
||||
table.search(query, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
|
||||
)
|
||||
|
||||
vec = func.compute_query_embeddings(query)[0]
|
||||
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
||||
expected = (
|
||||
table.search(vec, vector_column_name="vector").limit(1).to_pydantic(Words)[0]
|
||||
)
|
||||
assert actual.text == expected.text
|
||||
assert actual.text == "hello world"
|
||||
assert not np.allclose(actual.vector, actual.vector2)
|
||||
@@ -121,7 +120,11 @@ def test_openclip(tmp_path):
|
||||
)
|
||||
|
||||
# text search
|
||||
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
|
||||
actual = (
|
||||
table.search("man's best friend", vector_column_name="vector")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == "dog"
|
||||
frombytes = (
|
||||
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||
@@ -135,7 +138,11 @@ def test_openclip(tmp_path):
|
||||
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||
image_bytes = requests.get(query_image_uri).content
|
||||
query_image = Image.open(io.BytesIO(image_bytes))
|
||||
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
|
||||
actual = (
|
||||
table.search(query_image, vector_column_name="vector")
|
||||
.limit(1)
|
||||
.to_pydantic(Images)[0]
|
||||
)
|
||||
assert actual.label == "dog"
|
||||
other = (
|
||||
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||
@@ -210,6 +217,13 @@ def test_gemini_embedding(tmp_path):
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
|
||||
try:
|
||||
if importlib.util.find_spec("mlx.core") is not None:
|
||||
_mlx = True
|
||||
except ImportError:
|
||||
_mlx = None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_mlx is None,
|
||||
reason="mlx tests only required for apple users.",
|
||||
@@ -266,3 +280,49 @@ def test_bedrock_embedding(tmp_path):
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
)
|
||||
def test_openai_embedding(tmp_path):
|
||||
def _get_table(model):
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
return tbl
|
||||
|
||||
model = get_registry().get("openai").create(max_retries=0)
|
||||
tbl = _get_table(model)
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
model = (
|
||||
get_registry()
|
||||
.get("openai")
|
||||
.create(max_retries=0, name="text-embedding-3-large")
|
||||
)
|
||||
tbl = _get_table(model)
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
model = (
|
||||
get_registry()
|
||||
.get("openai")
|
||||
.create(max_retries=0, name="text-embedding-3-large", dim=1024)
|
||||
)
|
||||
tbl = _get_table(model)
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
@@ -29,10 +29,14 @@ class FakeLanceDBClient:
|
||||
def post(self, path: str):
|
||||
pass
|
||||
|
||||
def mount_retry_adapter_for_table(self, table_name: str):
|
||||
pass
|
||||
|
||||
|
||||
def test_remote_db():
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
setattr(conn, "_client", FakeLanceDBClient())
|
||||
|
||||
table = conn["test"]
|
||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||
table.search([1.0, 2.0]).to_pandas()
|
||||
|
||||
@@ -7,7 +7,12 @@ import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction # noqa
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import CohereReranker, CrossEncoderReranker
|
||||
from lancedb.rerankers import (
|
||||
CohereReranker,
|
||||
ColbertReranker,
|
||||
CrossEncoderReranker,
|
||||
OpenaiReranker,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -75,7 +80,6 @@ def get_test_table(tmp_path):
|
||||
return table, MyTable
|
||||
|
||||
|
||||
## These tests are pretty loose, we should also check for correctness
|
||||
def test_linear_combination(tmp_path):
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# The default reranker
|
||||
@@ -95,14 +99,19 @@ def test_linear_combination(tmp_path):
|
||||
|
||||
assert result1 == result3 # 2 & 3 should be the same as they use score as score
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(normalize="score")
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
@@ -122,19 +131,24 @@ def test_cohere_reranker(tmp_path):
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank", reranker=CohereReranker())
|
||||
.rerank(reranker=CohereReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=CohereReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
@@ -150,19 +164,96 @@ def test_cross_encoder_reranker(tmp_path):
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="rank", reranker=CrossEncoderReranker())
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.limit(50)
|
||||
table.search((query_vector, query), query_type="hybrid")
|
||||
.limit(30)
|
||||
.rerank(reranker=CrossEncoderReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _score column of the results returned by the reranker "
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
|
||||
def test_colbert_reranker(tmp_path):
|
||||
pytest.importorskip("transformers")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=ColbertReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=ColbertReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=ColbertReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
)
|
||||
def test_openai_reranker(tmp_path):
|
||||
pytest.importorskip("openai")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=OpenaiReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=OpenaiReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=OpenaiReranker())
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
|
||||
@@ -12,8 +12,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from copy import copy
|
||||
from datetime import date, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import List
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
@@ -25,6 +27,7 @@ import pyarrow as pa
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import lancedb
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
@@ -35,6 +38,7 @@ from lancedb.table import LanceTable
|
||||
class MockDB:
|
||||
def __init__(self, uri: Path):
|
||||
self.uri = uri
|
||||
self.read_consistency_interval = None
|
||||
|
||||
@functools.cached_property
|
||||
def is_managed_remote(self) -> bool:
|
||||
@@ -267,9 +271,8 @@ def test_versioning(db):
|
||||
|
||||
|
||||
def test_create_index_method():
|
||||
with patch.object(LanceTable, "_reset_dataset", return_value=None):
|
||||
with patch.object(
|
||||
LanceTable, "_dataset", new_callable=PropertyMock
|
||||
LanceTable, "_dataset_mut", new_callable=PropertyMock
|
||||
) as mock_dataset:
|
||||
# Setup mock responses
|
||||
mock_dataset.return_value.create_index.return_value = None
|
||||
@@ -493,6 +496,69 @@ def test_update_types(db):
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_merge_insert(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}),
|
||||
)
|
||||
assert len(table) == 3
|
||||
version = table.version
|
||||
|
||||
new_data = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})
|
||||
|
||||
# upsert
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "x", "y", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
# conditional update
|
||||
table.merge_insert("a").when_matched_update_all(where="target.b = 'b'").execute(
|
||||
new_data
|
||||
)
|
||||
expected = pa.table({"a": [1, 2, 3], "b": ["a", "x", "c"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
# insert-if-not-exists
|
||||
table.merge_insert("a").when_not_matched_insert_all().execute(new_data)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
new_data = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
|
||||
# replace-range
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete(
|
||||
"a > 2"
|
||||
).execute(new_data)
|
||||
|
||||
expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
table.restore(version)
|
||||
|
||||
# replace-range no condition
|
||||
table.merge_insert(
|
||||
"a"
|
||||
).when_matched_update_all().when_not_matched_insert_all().when_not_matched_by_source_delete().execute(
|
||||
new_data
|
||||
)
|
||||
|
||||
expected = pa.table({"a": [2, 4], "b": ["x", "z"]})
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
|
||||
def test_create_with_embedding_function(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
@@ -644,6 +710,59 @@ def test_empty_query(db):
|
||||
assert len(df) == 100
|
||||
|
||||
|
||||
def test_search_with_schema_inf_single_vector(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector_col: Vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
v1 = np.random.randn(10)
|
||||
v2 = np.random.randn(10)
|
||||
data = [
|
||||
{"vector_col": v1, "text": "foo"},
|
||||
{"vector_col": v2, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
result1 = table.search(q, vector_column_name="vector_col").limit(1).to_pandas()
|
||||
result2 = table.search(q).limit(1).to_pandas()
|
||||
|
||||
assert result1["text"].iloc[0] == result2["text"].iloc[0]
|
||||
|
||||
|
||||
def test_search_with_schema_inf_multiple_vector(db):
|
||||
class MyTable(LanceModel):
|
||||
text: str
|
||||
vector1: Vector(10)
|
||||
vector2: Vector(10)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"my_table",
|
||||
schema=MyTable,
|
||||
)
|
||||
|
||||
v1 = np.random.randn(10)
|
||||
v2 = np.random.randn(10)
|
||||
data = [
|
||||
{"vector1": v1, "vector2": v2, "text": "foo"},
|
||||
{"vector1": v2, "vector2": v1, "text": "bar"},
|
||||
]
|
||||
df = pd.DataFrame(data)
|
||||
table.add(df)
|
||||
|
||||
q = np.random.randn(10)
|
||||
with pytest.raises(ValueError):
|
||||
table.search(q).limit(1).to_pandas()
|
||||
|
||||
|
||||
def test_compact_cleanup(db):
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
@@ -684,10 +803,8 @@ def test_count_rows(db):
|
||||
assert table.count_rows(filter="text='bar'") == 1
|
||||
|
||||
|
||||
def test_hybrid_search(db):
|
||||
# hardcoding temporarily.. this test is failing with tmp_path mockdb.
|
||||
# Probably not being parsed right by the fts
|
||||
db = MockDB("~/lancedb_")
|
||||
def test_hybrid_search(db, tmp_path):
|
||||
db = MockDB(str(tmp_path))
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
|
||||
@@ -736,3 +853,48 @@ def test_hybrid_search(db):
|
||||
"Our father who art in heaven", query_type="hybrid"
|
||||
).to_pydantic(MyTable)
|
||||
assert result1 == result3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
|
||||
)
|
||||
def test_consistency(tmp_path, consistency_interval):
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
|
||||
|
||||
db2 = lancedb.connect(tmp_path, read_consistency_interval=consistency_interval)
|
||||
table2 = db2.open_table("my_table")
|
||||
assert table2.version == table.version
|
||||
|
||||
table.add([{"id": 1}])
|
||||
|
||||
if consistency_interval is None:
|
||||
assert table2.version == table.version - 1
|
||||
table2.checkout_latest()
|
||||
assert table2.version == table.version
|
||||
elif consistency_interval == timedelta(seconds=0):
|
||||
assert table2.version == table.version
|
||||
else:
|
||||
# (consistency_interval == timedelta(seconds=0.1)
|
||||
assert table2.version == table.version - 1
|
||||
sleep(0.1)
|
||||
assert table2.version == table.version
|
||||
|
||||
|
||||
def test_restore_consistency(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
table = LanceTable.create(db, "my_table", data=[{"id": 0}])
|
||||
|
||||
db2 = lancedb.connect(tmp_path, read_consistency_interval=timedelta(seconds=0))
|
||||
table2 = db2.open_table("my_table")
|
||||
assert table2.version == table.version
|
||||
|
||||
# If we call checkout, it should lose consistency
|
||||
table_fixed = copy(table2)
|
||||
table_fixed.checkout(table.version)
|
||||
# But if we call checkout_latest, it should be consistent again
|
||||
table_ref_latest = copy(table_fixed)
|
||||
table_ref_latest.checkout_latest()
|
||||
table.add([{"id": 2}])
|
||||
assert table_fixed.version == table.version - 1
|
||||
assert table_ref_latest.version == table.version
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.4.7"
|
||||
version = "0.4.10"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
license.workspace = true
|
||||
edition.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
exclude = ["index.node"]
|
||||
|
||||
[lib]
|
||||
@@ -28,3 +31,6 @@ object_store = { workspace = true, features = ["aws"] }
|
||||
snafu = { workspace = true }
|
||||
async-trait = "0"
|
||||
env_logger = "0"
|
||||
|
||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||
lzma-sys = { version = "*", features = ["static"] }
|
||||
|
||||
@@ -22,7 +22,7 @@ use arrow_schema::SchemaRef;
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
|
||||
pub fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBatch>, SchemaRef)> {
|
||||
let mut batches: Vec<RecordBatch> = Vec::new();
|
||||
let file_reader = FileReader::try_new(Cursor::new(slice), None)?;
|
||||
let schema = file_reader.schema();
|
||||
@@ -33,7 +33,7 @@ pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<(Vec<RecordBa
|
||||
Ok((batches, schema))
|
||||
}
|
||||
|
||||
pub(crate) fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
|
||||
pub fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> {
|
||||
if batches.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
@@ -17,10 +17,7 @@ use neon::types::buffer::TypedArray;
|
||||
|
||||
use crate::error::ResultExt;
|
||||
|
||||
pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
|
||||
vec: &Vec<String>,
|
||||
cx: &mut C,
|
||||
) -> JsResult<'a, JsArray> {
|
||||
pub fn vec_str_to_array<'a, C: Context<'a>>(vec: &[String], cx: &mut C) -> JsResult<'a, JsArray> {
|
||||
let a = JsArray::new(cx, vec.len() as u32);
|
||||
for (i, s) in vec.iter().enumerate() {
|
||||
let v = cx.string(s);
|
||||
@@ -29,7 +26,7 @@ pub(crate) fn vec_str_to_array<'a, C: Context<'a>>(
|
||||
Ok(a)
|
||||
}
|
||||
|
||||
pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
|
||||
pub fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<f32> {
|
||||
let mut query_vec: Vec<f32> = Vec::new();
|
||||
for i in 0..array.len(cx) {
|
||||
let entry: Handle<JsNumber> = array.get(cx, i).unwrap();
|
||||
@@ -39,7 +36,7 @@ pub(crate) fn js_array_to_vec(array: &JsArray, cx: &mut FunctionContext) -> Vec<
|
||||
}
|
||||
|
||||
// Creates a new JsBuffer from a rust buffer with a special logic for electron
|
||||
pub(crate) fn new_js_buffer<'a>(
|
||||
pub fn new_js_buffer<'a>(
|
||||
buffer: Vec<u8>,
|
||||
cx: &mut TaskContext<'a>,
|
||||
is_electron: bool,
|
||||
|
||||
@@ -18,7 +18,6 @@ use neon::prelude::NeonResult;
|
||||
use snafu::Snafu;
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub(crate)))]
|
||||
pub enum Error {
|
||||
#[snafu(display("column '{name}' is missing"))]
|
||||
MissingColumn { name: String },
|
||||
|
||||
@@ -21,7 +21,7 @@ use neon::{
|
||||
use crate::{error::ResultExt, runtime, table::JsTable};
|
||||
use vectordb::Table;
|
||||
|
||||
pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
pub fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let column = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let replace = cx.argument::<JsBoolean>(1)?.value(&mut cx);
|
||||
|
||||
@@ -24,7 +24,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
|
||||
use crate::runtime;
|
||||
use crate::table::JsTable;
|
||||
|
||||
pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
pub fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let index_params = cx.argument::<JsObject>(0)?;
|
||||
|
||||
|
||||
@@ -260,6 +260,7 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
|
||||
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
|
||||
cx.export_function("tableDelete", JsTable::js_delete)?;
|
||||
cx.export_function("tableUpdate", JsTable::js_update)?;
|
||||
cx.export_function("tableMergeInsert", JsTable::js_merge_insert)?;
|
||||
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
|
||||
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
|
||||
cx.export_function("tableListIndices", JsTable::js_list_indices)?;
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::neon_ext::js_object_ext::JsObjectExt;
|
||||
use crate::table::JsTable;
|
||||
use crate::{convert, runtime};
|
||||
|
||||
pub(crate) struct JsQuery {}
|
||||
pub struct JsQuery {}
|
||||
|
||||
impl JsQuery {
|
||||
pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::Deref;
|
||||
|
||||
use arrow_array::{RecordBatch, RecordBatchIterator};
|
||||
use lance::dataset::optimize::CompactionOptions;
|
||||
use lance::dataset::{WriteMode, WriteParams};
|
||||
@@ -26,7 +28,7 @@ use vectordb::TableRef;
|
||||
use crate::error::ResultExt;
|
||||
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
|
||||
|
||||
pub(crate) struct JsTable {
|
||||
pub struct JsTable {
|
||||
pub table: TableRef,
|
||||
}
|
||||
|
||||
@@ -34,7 +36,7 @@ impl Finalize for JsTable {}
|
||||
|
||||
impl From<TableRef> for JsTable {
|
||||
fn from(table: TableRef) -> Self {
|
||||
JsTable { table }
|
||||
Self { table }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,14 +85,14 @@ impl JsTable {
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let table = table_rst.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
});
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let buffer = cx.argument::<JsBuffer>(0)?;
|
||||
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
|
||||
let (batches, schema) =
|
||||
@@ -123,21 +125,34 @@ impl JsTable {
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
add_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
});
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let filter = cx
|
||||
.argument_opt(0)
|
||||
.and_then(|filt| {
|
||||
if filt.is_a::<JsUndefined, _>(&mut cx) || filt.is_a::<JsNull, _>(&mut cx) {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
filt.downcast_or_throw::<JsString, _>(&mut cx)
|
||||
.map(|js_filt| js_filt.deref().value(&mut cx)),
|
||||
)
|
||||
}
|
||||
})
|
||||
.transpose()?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
let table = js_table.table.clone();
|
||||
|
||||
rt.spawn(async move {
|
||||
let num_rows_result = table.count_rows().await;
|
||||
let num_rows_result = table.count_rows(filter).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let num_rows = num_rows_result.or_throw(&mut cx)?;
|
||||
@@ -148,7 +163,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
@@ -160,14 +175,67 @@ impl JsTable {
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
delete_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
})
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_merge_insert(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
let table = js_table.table.clone();
|
||||
|
||||
let key = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let mut builder = table.merge_insert(&[&key]);
|
||||
if cx.argument::<JsBoolean>(1)?.value(&mut cx) {
|
||||
let filter = cx.argument_opt(2).unwrap();
|
||||
if filter.is_a::<JsNull, _>(&mut cx) {
|
||||
builder.when_matched_update_all(None);
|
||||
} else {
|
||||
let filter = filter
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?
|
||||
.deref()
|
||||
.value(&mut cx);
|
||||
builder.when_matched_update_all(Some(filter));
|
||||
}
|
||||
}
|
||||
if cx.argument::<JsBoolean>(3)?.value(&mut cx) {
|
||||
builder.when_not_matched_insert_all();
|
||||
}
|
||||
if cx.argument::<JsBoolean>(4)?.value(&mut cx) {
|
||||
let filter = cx.argument_opt(5).unwrap();
|
||||
if filter.is_a::<JsNull, _>(&mut cx) {
|
||||
builder.when_not_matched_by_source_delete(None);
|
||||
} else {
|
||||
let filter = filter
|
||||
.downcast_or_throw::<JsString, _>(&mut cx)?
|
||||
.deref()
|
||||
.value(&mut cx);
|
||||
builder.when_not_matched_by_source_delete(Some(filter));
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = cx.argument::<JsBuffer>(6)?;
|
||||
let (batches, schema) =
|
||||
arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?;
|
||||
|
||||
rt.spawn(async move {
|
||||
let new_data = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
|
||||
let merge_insert_result = builder.execute(Box::new(new_data)).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
merge_insert_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
})
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let table = js_table.table.clone();
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
@@ -226,7 +294,7 @@ impl JsTable {
|
||||
.await;
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
update_result.or_throw(&mut cx)?;
|
||||
Ok(cx.boxed(JsTable::from(table)))
|
||||
Ok(cx.boxed(Self::from(table)))
|
||||
})
|
||||
});
|
||||
|
||||
@@ -234,7 +302,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let table = js_table.table.clone();
|
||||
@@ -272,7 +340,7 @@ impl JsTable {
|
||||
let old_versions = cx.number(prune_stats.old_versions as f64);
|
||||
output_metrics.set(&mut cx, "oldVersions", old_versions)?;
|
||||
|
||||
let output_table = cx.boxed(JsTable::from(table));
|
||||
let output_table = cx.boxed(Self::from(table));
|
||||
|
||||
let output = JsObject::new(&mut cx);
|
||||
output.set(&mut cx, "metrics", output_metrics)?;
|
||||
@@ -285,7 +353,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let table = js_table.table.clone();
|
||||
@@ -344,7 +412,7 @@ impl JsTable {
|
||||
let files_added = cx.number(stats.files_added as f64);
|
||||
output_metrics.set(&mut cx, "filesAdded", files_added)?;
|
||||
|
||||
let output_table = cx.boxed(JsTable::from(table));
|
||||
let output_table = cx.boxed(Self::from(table));
|
||||
|
||||
let output = JsObject::new(&mut cx);
|
||||
output.set(&mut cx, "metrics", output_metrics)?;
|
||||
@@ -357,7 +425,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
@@ -396,7 +464,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
@@ -444,7 +512,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
pub(crate) fn js_schema(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<Self>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let channel = cx.channel();
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
[package]
|
||||
name = "vectordb"
|
||||
version = "0.4.7"
|
||||
edition = "2021"
|
||||
version = "0.4.10"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
repository = "https://github.com/lancedb/lancedb"
|
||||
keywords = ["lancedb", "lance", "database", "search"]
|
||||
categories = ["database-implementations"]
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[dependencies]
|
||||
|
||||
@@ -188,12 +188,12 @@ impl Database {
|
||||
/// # Returns
|
||||
///
|
||||
/// * A [Database] object.
|
||||
pub async fn connect(uri: &str) -> Result<Database> {
|
||||
pub async fn connect(uri: &str) -> Result<Self> {
|
||||
let options = ConnectOptions::new(uri);
|
||||
Self::connect_with_options(&options).await
|
||||
}
|
||||
|
||||
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
|
||||
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Self> {
|
||||
let uri = &options.uri;
|
||||
let parse_res = url::Url::parse(uri);
|
||||
|
||||
@@ -276,7 +276,7 @@ impl Database {
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Database {
|
||||
Ok(Self {
|
||||
uri: table_base_uri,
|
||||
query_string,
|
||||
base_path,
|
||||
@@ -288,7 +288,7 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
async fn open_path(path: &str) -> Result<Database> {
|
||||
async fn open_path(path: &str) -> Result<Self> {
|
||||
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(path).context(CreateDirSnafu { path })?;
|
||||
@@ -422,13 +422,11 @@ mod tests {
|
||||
let tmp_dir = tempdir().unwrap();
|
||||
let uri = std::fs::canonicalize(tmp_dir.path().to_str().unwrap()).unwrap();
|
||||
|
||||
let mut relative_anacestors = vec![];
|
||||
let current_dir = std::env::current_dir().unwrap();
|
||||
let mut ancestors = current_dir.ancestors();
|
||||
while let Some(_) = ancestors.next() {
|
||||
relative_anacestors.push("..");
|
||||
}
|
||||
let relative_root = std::path::PathBuf::from(relative_anacestors.join("/"));
|
||||
let ancestors = current_dir.ancestors();
|
||||
let relative_ancestors = vec![".."; ancestors.count()];
|
||||
|
||||
let relative_root = std::path::PathBuf::from(relative_ancestors.join("/"));
|
||||
let relative_uri = relative_root.join(&uri);
|
||||
|
||||
let db = Database::connect(relative_uri.to_str().unwrap())
|
||||
|
||||
@@ -69,7 +69,7 @@ pub struct IndexBuilder {
|
||||
|
||||
impl IndexBuilder {
|
||||
pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
|
||||
IndexBuilder {
|
||||
Self {
|
||||
table,
|
||||
columns: columns.iter().map(|c| c.to_string()).collect(),
|
||||
name: None,
|
||||
@@ -197,7 +197,7 @@ impl IndexBuilder {
|
||||
let num_partitions = if let Some(n) = self.num_partitions {
|
||||
n
|
||||
} else {
|
||||
suggested_num_partitions(self.table.count_rows().await?)
|
||||
suggested_num_partitions(self.table.count_rows(None).await?)
|
||||
};
|
||||
let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
|
||||
n
|
||||
|
||||
@@ -23,13 +23,13 @@ pub struct VectorIndex {
|
||||
}
|
||||
|
||||
impl VectorIndex {
|
||||
pub fn new_from_format(manifest: &Manifest, index: &Index) -> VectorIndex {
|
||||
pub fn new_from_format(manifest: &Manifest, index: &Index) -> Self {
|
||||
let fields = index
|
||||
.fields
|
||||
.iter()
|
||||
.map(|i| manifest.schema.fields[*i as usize].name.clone())
|
||||
.collect();
|
||||
VectorIndex {
|
||||
Self {
|
||||
columns: fields,
|
||||
index_name: index.name.clone(),
|
||||
index_uuid: index.uuid.to_string(),
|
||||
|
||||
@@ -357,12 +357,14 @@ mod test {
|
||||
let db = Database::connect(dir1.to_str().unwrap()).await.unwrap();
|
||||
|
||||
let mut param = WriteParams::default();
|
||||
let mut store_params = ObjectStoreParams::default();
|
||||
store_params.object_store_wrapper = Some(object_store_wrapper);
|
||||
let store_params = ObjectStoreParams {
|
||||
object_store_wrapper: Some(object_store_wrapper),
|
||||
..Default::default()
|
||||
};
|
||||
param.store_params = Some(store_params);
|
||||
|
||||
let mut datagen = BatchGenerator::new();
|
||||
datagen = datagen.col(Box::new(IncrementingInt32::default()));
|
||||
datagen = datagen.col(Box::<IncrementingInt32>::default());
|
||||
datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));
|
||||
|
||||
let res = db
|
||||
@@ -372,7 +374,7 @@ mod test {
|
||||
// leave this here for easy debugging
|
||||
let t = res.unwrap();
|
||||
|
||||
assert_eq!(t.count_rows().await.unwrap(), 100);
|
||||
assert_eq!(t.count_rows(None).await.unwrap(), 100);
|
||||
|
||||
let q = t
|
||||
.search(&[0.1, 0.1, 0.1, 0.1])
|
||||
|
||||
@@ -62,7 +62,7 @@ impl Query {
|
||||
/// * `dataset` - Lance dataset.
|
||||
///
|
||||
pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
|
||||
Query {
|
||||
Self {
|
||||
dataset,
|
||||
query_vector: None,
|
||||
column: None,
|
||||
@@ -257,7 +257,7 @@ mod tests {
|
||||
assert_eq!(query.query_vector.unwrap(), new_vector);
|
||||
assert_eq!(query.limit.unwrap(), 100);
|
||||
assert_eq!(query.nprobes, 1000);
|
||||
assert_eq!(query.use_index, true);
|
||||
assert!(query.use_index);
|
||||
assert_eq!(query.metric_type, Some(MetricType::Cosine));
|
||||
assert_eq!(query.refine_factor, Some(999));
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user