mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
43 Commits
v0.3.2
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef20b2a138 | ||
|
|
2e0f251bfd | ||
|
|
2cb91e818d | ||
|
|
2835c76336 | ||
|
|
8068a2bbc3 | ||
|
|
24111d543a | ||
|
|
7eec2b8f9a | ||
|
|
b2b70ea399 | ||
|
|
e50a3c1783 | ||
|
|
b517134309 | ||
|
|
6fb539b5bf | ||
|
|
f37fe120fd | ||
|
|
2e115acb9a | ||
|
|
27a638362d | ||
|
|
22a6695d7a | ||
|
|
57eff82ee7 | ||
|
|
7732f7d41c | ||
|
|
5ca98c326f | ||
|
|
b55db397eb | ||
|
|
c04d72ac8a | ||
|
|
28b02fb72a | ||
|
|
f3cf986777 | ||
|
|
c73fcc8898 | ||
|
|
cd9debc3b7 | ||
|
|
26a97ba997 | ||
|
|
ce19fedb08 | ||
|
|
14e8e48de2 | ||
|
|
c30faf6083 | ||
|
|
64a4f025bb | ||
|
|
6dc968e7d3 | ||
|
|
06b5b69f1e | ||
|
|
6bd3a838fc | ||
|
|
f36fea8f20 | ||
|
|
0a30591729 | ||
|
|
0ed39b6146 | ||
|
|
a8c7f80073 | ||
|
|
0293bbe142 | ||
|
|
7372656369 | ||
|
|
d46bc5dd6e | ||
|
|
86efb11572 | ||
|
|
bb01ad5290 | ||
|
|
1b8cda0941 | ||
|
|
bc85a749a3 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.3.2
|
||||
current_version = 0.3.5
|
||||
commit = True
|
||||
message = Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
4
.github/workflows/node.yml
vendored
4
.github/workflows/node.yml
vendored
@@ -11,6 +11,10 @@ on:
|
||||
- .github/workflows/node.yml
|
||||
- docker-compose.yml
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
# Disable full debug symbol generation to speed up CI build and keep memory down
|
||||
# "1" means line tables only, which is useful for panic tracebacks.
|
||||
|
||||
2
.github/workflows/npm-publish.yml
vendored
2
.github/workflows/npm-publish.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
node/vectordb-*.tgz
|
||||
|
||||
node-macos:
|
||||
runs-on: macos-12
|
||||
runs-on: macos-13
|
||||
# Only runs on tags that matches the make-release action
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
strategy:
|
||||
|
||||
7
.github/workflows/python.yml
vendored
7
.github/workflows/python.yml
vendored
@@ -8,6 +8,11 @@ on:
|
||||
paths:
|
||||
- python/**
|
||||
- .github/workflows/python.yml
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
linux:
|
||||
timeout-minutes: 30
|
||||
@@ -43,7 +48,7 @@ jobs:
|
||||
run: pytest --doctest-modules lancedb
|
||||
mac:
|
||||
timeout-minutes: 30
|
||||
runs-on: "macos-12"
|
||||
runs-on: "macos-13"
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
6
.github/workflows/rust.yml
vendored
6
.github/workflows/rust.yml
vendored
@@ -10,6 +10,10 @@ on:
|
||||
- rust/**
|
||||
- .github/workflows/rust.yml
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
# This env var is used by Swatinem/rust-cache@v2 for the cache
|
||||
# key, so we set it to make sure it is always consistent.
|
||||
@@ -44,7 +48,7 @@ jobs:
|
||||
- name: Run tests
|
||||
run: cargo test --all-features
|
||||
macos:
|
||||
runs-on: macos-12
|
||||
runs-on: macos-13
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
run:
|
||||
|
||||
10
Cargo.toml
10
Cargo.toml
@@ -5,9 +5,9 @@ exclude = ["python"]
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.8.5", "features" = ["dynamodb"] }
|
||||
lance-linalg = { "version" = "=0.8.5" }
|
||||
lance-testing = { "version" = "=0.8.5" }
|
||||
lance = { "version" = "=0.8.10", "features" = ["dynamodb"] }
|
||||
lance-linalg = { "version" = "=0.8.10" }
|
||||
lance-testing = { "version" = "=0.8.10" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "47.0.0", optional = false }
|
||||
arrow-array = "47.0"
|
||||
@@ -18,8 +18,8 @@ arrow-schema = "47.0"
|
||||
arrow-arith = "47.0"
|
||||
arrow-cast = "47.0"
|
||||
chrono = "0.4.23"
|
||||
half = { "version" = "=2.2.1", default-features = false, features = [
|
||||
"num-traits"
|
||||
half = { "version" = "=2.3.1", default-features = false, features = [
|
||||
"num-traits",
|
||||
] }
|
||||
log = "0.4"
|
||||
object_store = "0.7.1"
|
||||
|
||||
26
docs/README.md
Normal file
26
docs/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# LanceDB Documentation
|
||||
|
||||
LanceDB docs are deployed to https://lancedb.github.io/lancedb/.
|
||||
|
||||
Docs is built and deployed automatically by [Github Actions](.github/workflows/docs.yml)
|
||||
whenever a commit is pushed to the `main` branch. So it is possible for the docs to show
|
||||
unreleased features.
|
||||
|
||||
## Building the docs
|
||||
|
||||
### Setup
|
||||
1. Install LanceDB. From LanceDB repo root: `pip install -e python`
|
||||
2. Install dependencies. From LanceDB repo root: `pip install -r docs/requirements.txt`
|
||||
3. Make sure you have node and npm setup
|
||||
4. Make sure protobuf and libssl are installed
|
||||
|
||||
### Building node module and create markdown files
|
||||
|
||||
See [Javascript docs README](docs/src/javascript/README.md)
|
||||
|
||||
### Build docs
|
||||
From LanceDB repo root:
|
||||
|
||||
Run: `PYTHONPATH=. mkdocs build -f docs/mkdocs.yml`
|
||||
|
||||
If successful, you should see a `docs/site` directory that you can verify locally.
|
||||
@@ -73,12 +73,14 @@ nav:
|
||||
- Vector Search: search.md
|
||||
- SQL filters: sql.md
|
||||
- Indexing: ann_indexes.md
|
||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||
- 🧬 Embeddings:
|
||||
- embeddings/index.md
|
||||
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
||||
- Available Functions: embeddings/default_embedding_functions.md
|
||||
- Create Custom Embedding Functions: embeddings/api.md
|
||||
- Example- MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
|
||||
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- 🔍 Python full-text search: fts.md
|
||||
- 🔌 Integrations:
|
||||
- integrations/index.md
|
||||
@@ -110,12 +112,14 @@ nav:
|
||||
- Vector Search: search.md
|
||||
- SQL filters: sql.md
|
||||
- Indexing: ann_indexes.md
|
||||
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||
- Embeddings:
|
||||
- embeddings/index.md
|
||||
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
||||
- Available Functions: embeddings/default_embedding_functions.md
|
||||
- Create Custom Embedding Functions: embeddings/api.md
|
||||
- Example- MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
|
||||
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||
- Python full-text search: fts.md
|
||||
- Integrations:
|
||||
- integrations/index.md
|
||||
|
||||
@@ -71,9 +71,41 @@ a single PQ code.
|
||||
### Use GPU to build vector index
|
||||
|
||||
Lance Python SDK has experimental GPU support for creating IVF index.
|
||||
Using GPU for index creation requires [PyTorch>2.0](https://pytorch.org/) being installed.
|
||||
|
||||
You can specify the GPU device to train IVF partitions via
|
||||
|
||||
- **accelerator**: Specify to `"cuda"`` to enable GPU training.
|
||||
- **accelerator**: Specify to ``cuda`` or ``mps`` (on Apple Silicon) to enable GPU training.
|
||||
|
||||
=== "Linux"
|
||||
|
||||
<!-- skip-test -->
|
||||
``` { .python .copy }
|
||||
# Create index using CUDA on Nvidia GPUs.
|
||||
tbl.create_index(
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
accelerator="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
=== "Macos"
|
||||
|
||||
<!-- skip-test -->
|
||||
```python
|
||||
# Create index using MPS on Apple Silicon.
|
||||
tbl.create_index(
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
accelerator="mps"
|
||||
)
|
||||
```
|
||||
|
||||
Trouble shootings:
|
||||
|
||||
If you see ``AssertionError: Torch not compiled with CUDA enabled``, you need to [install
|
||||
PyTorch with CUDA support](https://pytorch.org/get-started/locally/).
|
||||
|
||||
|
||||
## Querying an ANN Index
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ LanceDB's core is written in Rust 🦀 and is built using <a href="https://githu
|
||||
|
||||
## Documentation Quick Links
|
||||
* [`Basic Operations`](basic.md) - basic functionality of LanceDB.
|
||||
* [`Embedding Functions`](embedding.md) - functions for working with embeddings.
|
||||
* [`Embedding Functions`](embeddings/index.md) - functions for working with embeddings.
|
||||
* [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
|
||||
* [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
|
||||
* [`Ecosystem Integrations`](python/integration.md) - integrating LanceDB with python data tooling ecosystem.
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "88c1af18",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Example - MultiModal CLIP Embeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c6b5d346-2c2a-4341-a132-00e53543f8d1",
|
||||
|
||||
604
docs/src/notebooks/multi_lingual_example.ipynb
Normal file
604
docs/src/notebooks/multi_lingual_example.ipynb
Normal file
File diff suppressed because one or more lines are too long
1189
docs/src/notebooks/reproducibility.ipynb
Normal file
1189
docs/src/notebooks/reproducibility.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -22,21 +22,19 @@ pip install lancedb
|
||||
|
||||
::: lancedb.query.LanceQueryBuilder
|
||||
|
||||
::: lancedb.query.LanceFtsQueryBuilder
|
||||
|
||||
## Embeddings
|
||||
|
||||
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
|
||||
::: lancedb.embeddings.registry.EmbeddingFunctionRegistry
|
||||
|
||||
::: lancedb.embeddings.functions.EmbeddingFunction
|
||||
::: lancedb.embeddings.base.EmbeddingFunction
|
||||
|
||||
::: lancedb.embeddings.functions.TextEmbeddingFunction
|
||||
::: lancedb.embeddings.base.TextEmbeddingFunction
|
||||
|
||||
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
|
||||
::: lancedb.embeddings.sentence_transformers.SentenceTransformerEmbeddings
|
||||
|
||||
::: lancedb.embeddings.functions.OpenAIEmbeddings
|
||||
::: lancedb.embeddings.openai.OpenAIEmbeddings
|
||||
|
||||
::: lancedb.embeddings.functions.OpenClipEmbeddings
|
||||
::: lancedb.embeddings.open_clip.OpenClipEmbeddings
|
||||
|
||||
::: lancedb.embeddings.with_embeddings
|
||||
|
||||
@@ -56,7 +54,7 @@ pip install lancedb
|
||||
|
||||
## Utilities
|
||||
|
||||
::: lancedb.vector
|
||||
::: lancedb.schema.vector
|
||||
|
||||
## Integrations
|
||||
|
||||
|
||||
4
docs/src/scripts/posthog.js
Normal file
4
docs/src/scripts/posthog.js
Normal file
@@ -0,0 +1,4 @@
|
||||
window.addEventListener("DOMContentLoaded", (event) => {
|
||||
!function(t,e){var o,n,p,r;e.__SV||(window.posthog=e,e._i=[],e.init=function(i,s,a){function g(t,e){var o=e.split(".");2==o.length&&(t=t[o[0]],e=o[1]),t[e]=function(){t.push([e].concat(Array.prototype.slice.call(arguments,0)))}}(p=t.createElement("script")).type="text/javascript",p.async=!0,p.src=s.api_host+"/static/array.js",(r=t.getElementsByTagName("script")[0]).parentNode.insertBefore(p,r);var u=e;for(void 0!==a?u=e[a]=[]:a="posthog",u.people=u.people||[],u.toString=function(t){var e="posthog";return"posthog"!==a&&(e+="."+a),t||(e+=" (stub)"),e},u.people.toString=function(){return u.toString(1)+".people (stub)"},o="capture identify alias people.set people.set_once set_config register register_once unregister opt_out_capturing has_opted_out_capturing opt_in_capturing reset isFeatureEnabled onFeatureFlags getFeatureFlag getFeatureFlagPayload reloadFeatureFlags group updateEarlyAccessFeatureEnrollment getEarlyAccessFeatures getActiveMatchingSurveys getSurveys".split(" "),n=0;n<o.length;n++)g(u,o[n]);e._i.push([i,s,a])},e.__SV=1)}(document,window.posthog||[]);
|
||||
posthog.init('phc_oENDjGgHtmIDrV6puUiFem2RB4JA8gGWulfdulmMdZP',{api_host:'https://app.posthog.com'})
|
||||
});
|
||||
@@ -4,7 +4,7 @@
|
||||
In a recommendation system or search engine, you can find similar products from
|
||||
the one you searched.
|
||||
In LLM and other AI applications,
|
||||
each data point can be [presented by the embeddings generated from some models](embedding.md),
|
||||
each data point can be [presented by the embeddings generated from some models](embeddings/index.md),
|
||||
it returns the most relevant features.
|
||||
|
||||
A search in high-dimensional vector space, is to find `K-Nearest-Neighbors (KNN)` of the query vector.
|
||||
|
||||
@@ -18,29 +18,45 @@ python_file = ".py"
|
||||
python_folder = "python"
|
||||
|
||||
files = glob.glob(glob_string, recursive=True)
|
||||
excluded_files = [f for excluded_glob in excluded_globs for f in glob.glob(excluded_glob, recursive=True)]
|
||||
excluded_files = [
|
||||
f
|
||||
for excluded_glob in excluded_globs
|
||||
for f in glob.glob(excluded_glob, recursive=True)
|
||||
]
|
||||
|
||||
|
||||
def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
|
||||
in_code_block = False
|
||||
# Python code has strict indentation
|
||||
strip_length = 0
|
||||
skip_test = False
|
||||
for line in lines:
|
||||
if "skip-test" in line:
|
||||
skip_test = True
|
||||
if line.strip().startswith(prefix + python_prefix):
|
||||
in_code_block = True
|
||||
strip_length = len(line) - len(line.lstrip())
|
||||
elif in_code_block and line.strip().startswith(suffix):
|
||||
in_code_block = False
|
||||
yield "\n"
|
||||
if not skip_test:
|
||||
yield "\n"
|
||||
skip_test = False
|
||||
elif in_code_block:
|
||||
yield line[strip_length:]
|
||||
if not skip_test:
|
||||
yield line[strip_length:]
|
||||
|
||||
for file in filter(lambda file: file not in excluded_files, files):
|
||||
with open(file, "r") as f:
|
||||
lines = list(yield_lines(iter(f), "```", "```"))
|
||||
|
||||
if len(lines) > 0:
|
||||
out_path = Path(python_folder) / Path(file).name.strip(".md") / (Path(file).name.strip(".md") + python_file)
|
||||
print(lines)
|
||||
out_path = (
|
||||
Path(python_folder)
|
||||
/ Path(file).name.strip(".md")
|
||||
/ (Path(file).name.strip(".md") + python_file)
|
||||
)
|
||||
print(out_path)
|
||||
out_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(out_path, "w") as out:
|
||||
out.writelines(lines)
|
||||
out.writelines(lines)
|
||||
|
||||
@@ -10,7 +10,7 @@ npm install vectordb
|
||||
|
||||
This will download the appropriate native library for your platform. We currently
|
||||
support x86_64 Linux, aarch64 Linux, Intel MacOS, and ARM (M1/M2) MacOS. We do not
|
||||
yet support Windows or musl-based Linux (such as Alpine Linux).
|
||||
yet support musl-based Linux (such as Alpine Linux).
|
||||
|
||||
## Usage
|
||||
|
||||
|
||||
74
node/package-lock.json
generated
74
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.3.1",
|
||||
"version": "0.3.5",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.3.1",
|
||||
"version": "0.3.5",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
@@ -53,11 +53,11 @@
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.1",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.1",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.1",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.1",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.1"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.5",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.5",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.5",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.5",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.5"
|
||||
}
|
||||
},
|
||||
"node_modules/@apache-arrow/ts": {
|
||||
@@ -317,9 +317,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.1.tgz",
|
||||
"integrity": "sha512-h3yUP249xaO3rrRuVC4oRxEm5/9T66CGKiI8OwYCJUOEFrfz/jj+6PK8geMn7IqbPnOY9YRPSEi/Cc3EdFd6Sg==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.5.tgz",
|
||||
"integrity": "sha512-Nnso+WXMSTIUouddDgPDNt40K6d2fF7W5OsfgAMDXAhUrdSMOZbVP0bWklRz9J7JluseBL9/MfLSEYZDTvrACg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -329,9 +329,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.1.tgz",
|
||||
"integrity": "sha512-SQ32iMMVfvjXgvFGSGdsXcSnVDypR6eE06d7VIXsuKAg6P9e1XUhB4YcsHGeAEEv3gEoUSgsljo92ZvXJcWouQ==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.5.tgz",
|
||||
"integrity": "sha512-gvg/iq13zAamLL7jueiIw7Q67dygm/NmILkFQ3WrAOUjr0IMxLBCv+XMxt62xajTrA+ObyfmU1uiuhrJL81PWw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -341,9 +341,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.1.tgz",
|
||||
"integrity": "sha512-+jk2nJnaIWTqcOAyix2y+ClLNM5ECIdwyHZp5KjDqOlP6Z7eb5V2Xsah0AFp8nX3BiRRvqj3zR3zi26D7OBnYw==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.5.tgz",
|
||||
"integrity": "sha512-6PvCBIXI9zPqF478TibZxxiAehFZ530g0FOFDT49xtp540HvhE9+XQk/yO0w96mvyoCfzB2lK4haDmdhCoehNw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -353,9 +353,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.1.tgz",
|
||||
"integrity": "sha512-I42Zf2lH8SUZLLYDDG4kzZ8iPq2wf1cXMh9iKNiLwgl5BnRsZVQ5A5k0uCX7IV7FcnHL/febKOxixXQyoKNAzw==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.5.tgz",
|
||||
"integrity": "sha512-e3nqurUeCow4QONeNf/QP50Z90mgrh9xoUfjRSHcCPQcP6WgmFEafbt0jeSVgZ7tbt7+03/MK0YexhHM/5sBjA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -365,9 +365,9 @@
|
||||
]
|
||||
},
|
||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.1.tgz",
|
||||
"integrity": "sha512-3OBS+fc4kcwhkqIy5b2Nump/iYoAgQd6gmYIJux3LJbMCc4yDcPJdFGVQkWu43JfBh7YOWPfOng2NSCUDBGmoA==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.5.tgz",
|
||||
"integrity": "sha512-RC1FfgEr6Z9sADuvspT2PG1B2mpKRdckgeiHqTHkIXdq3Qp5V5TeQJAbVvMr2xd1q99W6zreub52QXf+AilLVQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -4869,33 +4869,33 @@
|
||||
}
|
||||
},
|
||||
"@lancedb/vectordb-darwin-arm64": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.1.tgz",
|
||||
"integrity": "sha512-h3yUP249xaO3rrRuVC4oRxEm5/9T66CGKiI8OwYCJUOEFrfz/jj+6PK8geMn7IqbPnOY9YRPSEi/Cc3EdFd6Sg==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.3.5.tgz",
|
||||
"integrity": "sha512-Nnso+WXMSTIUouddDgPDNt40K6d2fF7W5OsfgAMDXAhUrdSMOZbVP0bWklRz9J7JluseBL9/MfLSEYZDTvrACg==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-darwin-x64": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.1.tgz",
|
||||
"integrity": "sha512-SQ32iMMVfvjXgvFGSGdsXcSnVDypR6eE06d7VIXsuKAg6P9e1XUhB4YcsHGeAEEv3gEoUSgsljo92ZvXJcWouQ==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.3.5.tgz",
|
||||
"integrity": "sha512-gvg/iq13zAamLL7jueiIw7Q67dygm/NmILkFQ3WrAOUjr0IMxLBCv+XMxt62xajTrA+ObyfmU1uiuhrJL81PWw==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-arm64-gnu": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.1.tgz",
|
||||
"integrity": "sha512-+jk2nJnaIWTqcOAyix2y+ClLNM5ECIdwyHZp5KjDqOlP6Z7eb5V2Xsah0AFp8nX3BiRRvqj3zR3zi26D7OBnYw==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.3.5.tgz",
|
||||
"integrity": "sha512-6PvCBIXI9zPqF478TibZxxiAehFZ530g0FOFDT49xtp540HvhE9+XQk/yO0w96mvyoCfzB2lK4haDmdhCoehNw==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-linux-x64-gnu": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.1.tgz",
|
||||
"integrity": "sha512-I42Zf2lH8SUZLLYDDG4kzZ8iPq2wf1cXMh9iKNiLwgl5BnRsZVQ5A5k0uCX7IV7FcnHL/febKOxixXQyoKNAzw==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.3.5.tgz",
|
||||
"integrity": "sha512-e3nqurUeCow4QONeNf/QP50Z90mgrh9xoUfjRSHcCPQcP6WgmFEafbt0jeSVgZ7tbt7+03/MK0YexhHM/5sBjA==",
|
||||
"optional": true
|
||||
},
|
||||
"@lancedb/vectordb-win32-x64-msvc": {
|
||||
"version": "0.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.1.tgz",
|
||||
"integrity": "sha512-3OBS+fc4kcwhkqIy5b2Nump/iYoAgQd6gmYIJux3LJbMCc4yDcPJdFGVQkWu43JfBh7YOWPfOng2NSCUDBGmoA==",
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.3.5.tgz",
|
||||
"integrity": "sha512-RC1FfgEr6Z9sADuvspT2PG1B2mpKRdckgeiHqTHkIXdq3Qp5V5TeQJAbVvMr2xd1q99W6zreub52QXf+AilLVQ==",
|
||||
"optional": true
|
||||
},
|
||||
"@neon-rs/cli": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.3.2",
|
||||
"version": "0.3.5",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
@@ -81,10 +81,10 @@
|
||||
}
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.2",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.2",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.2",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.2",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.2"
|
||||
"@lancedb/vectordb-darwin-arm64": "0.3.5",
|
||||
"@lancedb/vectordb-darwin-x64": "0.3.5",
|
||||
"@lancedb/vectordb-linux-arm64-gnu": "0.3.5",
|
||||
"@lancedb/vectordb-linux-x64-gnu": "0.3.5",
|
||||
"@lancedb/vectordb-win32-x64-msvc": "0.3.5"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ import { Query } from './query'
|
||||
import { isEmbeddingFunction } from './embedding/embedding_function'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles } = require('../native.js')
|
||||
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
|
||||
|
||||
export { Query }
|
||||
export type { EmbeddingFunction }
|
||||
@@ -260,6 +260,27 @@ export interface Table<T = number[]> {
|
||||
* ```
|
||||
*/
|
||||
delete: (filter: string) => Promise<void>
|
||||
|
||||
/**
|
||||
* List the indicies on this table.
|
||||
*/
|
||||
listIndices: () => Promise<VectorIndex[]>
|
||||
|
||||
/**
|
||||
* Get statistics about an index.
|
||||
*/
|
||||
indexStats: (indexUuid: string) => Promise<IndexStats>
|
||||
}
|
||||
|
||||
export interface VectorIndex {
|
||||
columns: string[]
|
||||
name: string
|
||||
uuid: string
|
||||
}
|
||||
|
||||
export interface IndexStats {
|
||||
numIndexedRows: number | null
|
||||
numUnindexedRows: number | null
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -502,6 +523,14 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
return res.metrics
|
||||
})
|
||||
}
|
||||
|
||||
async listIndices (): Promise<VectorIndex[]> {
|
||||
return tableListIndices.call(this._tbl)
|
||||
}
|
||||
|
||||
async indexStats (indexUuid: string): Promise<IndexStats> {
|
||||
return tableIndexStats.call(this._tbl, indexUuid)
|
||||
}
|
||||
}
|
||||
|
||||
export interface CleanupStats {
|
||||
|
||||
@@ -65,8 +65,8 @@ describe('LanceDB Mirrored Store Integration test', function () {
|
||||
const mirroredPath = path.join(dir, `${tableName}.lance`)
|
||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
// there should be two dirs
|
||||
assert.equal(files.length, 2)
|
||||
// there should be three dirs
|
||||
assert.equal(files.length, 3)
|
||||
assert.isTrue(files[0].isDirectory())
|
||||
assert.isTrue(files[1].isDirectory())
|
||||
|
||||
@@ -76,6 +76,12 @@ describe('LanceDB Mirrored Store Integration test', function () {
|
||||
assert.isTrue(files[0].name.endsWith('.txn'))
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, '_versions'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
assert.isTrue(files[0].name.endsWith('.manifest'))
|
||||
})
|
||||
|
||||
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
assert.equal(files.length, 1)
|
||||
@@ -88,8 +94,8 @@ describe('LanceDB Mirrored Store Integration test', function () {
|
||||
|
||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
// there should be two dirs
|
||||
assert.equal(files.length, 3)
|
||||
// there should be four dirs
|
||||
assert.equal(files.length, 4)
|
||||
assert.isTrue(files[0].isDirectory())
|
||||
assert.isTrue(files[1].isDirectory())
|
||||
assert.isTrue(files[2].isDirectory())
|
||||
@@ -128,12 +134,13 @@ describe('LanceDB Mirrored Store Integration test', function () {
|
||||
|
||||
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||
if (err != null) throw err
|
||||
// there should be two dirs
|
||||
assert.equal(files.length, 4)
|
||||
// there should be five dirs
|
||||
assert.equal(files.length, 5)
|
||||
assert.isTrue(files[0].isDirectory())
|
||||
assert.isTrue(files[1].isDirectory())
|
||||
assert.isTrue(files[2].isDirectory())
|
||||
assert.isTrue(files[3].isDirectory())
|
||||
assert.isTrue(files[4].isDirectory())
|
||||
|
||||
// Three TXs now
|
||||
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
|
||||
import {
|
||||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
||||
type ConnectionOptions, type CreateTableOptions, type WriteOptions
|
||||
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
|
||||
type WriteOptions,
|
||||
type IndexStats
|
||||
} from '../index'
|
||||
import { Query } from '../query'
|
||||
|
||||
@@ -241,4 +243,21 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
async delete (filter: string): Promise<void> {
|
||||
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
|
||||
}
|
||||
|
||||
async listIndices (): Promise<VectorIndex[]> {
|
||||
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
|
||||
return results.data.indexes?.map((index: any) => ({
|
||||
columns: index.columns,
|
||||
name: index.index_name,
|
||||
uuid: index.index_uuid
|
||||
}))
|
||||
}
|
||||
|
||||
async indexStats (indexUuid: string): Promise<IndexStats> {
|
||||
const results = await this._client.post(`/v1/table/${this._name}/index/${indexUuid}/stats/`)
|
||||
return {
|
||||
numIndexedRows: results.data.num_indexed_rows,
|
||||
numUnindexedRows: results.data.num_unindexed_rows
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,6 +328,24 @@ describe('LanceDB client', function () {
|
||||
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
|
||||
await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
|
||||
})
|
||||
|
||||
it('should be able to list index and stats', async function () {
|
||||
const uri = await createTestDB(32, 300)
|
||||
const con = await lancedb.connect(uri)
|
||||
const table = await con.openTable('vectors')
|
||||
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
|
||||
|
||||
const indices = await table.listIndices()
|
||||
expect(indices).to.have.lengthOf(1)
|
||||
expect(indices[0].name).to.equal('vector_idx')
|
||||
expect(indices[0].uuid).to.not.be.equal(undefined)
|
||||
expect(indices[0].columns).to.have.lengthOf(1)
|
||||
expect(indices[0].columns[0]).to.equal('vector')
|
||||
|
||||
const stats = await table.indexStats(indices[0].uuid)
|
||||
expect(stats.numIndexedRows).to.equal(300)
|
||||
expect(stats.numUnindexedRows).to.equal(0)
|
||||
}).timeout(50_000)
|
||||
})
|
||||
|
||||
describe('when using a custom embedding function', function () {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.3.1
|
||||
current_version = 0.3.3
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -84,7 +84,9 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
||||
context windows that don't cross document boundaries. In this case, we can
|
||||
pass ``document_id`` as the group by.
|
||||
|
||||
>>> contextualize(data).window(4).stride(2).text_col('token').groupby('document_id').to_pandas()
|
||||
>>> (contextualize(data)
|
||||
... .window(4).stride(2).text_col('token').groupby('document_id')
|
||||
... .to_pandas())
|
||||
token document_id
|
||||
0 The quick brown fox 1
|
||||
2 brown fox jumped over 1
|
||||
@@ -92,18 +94,24 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
||||
6 the lazy dog 1
|
||||
9 I love sandwiches 2
|
||||
|
||||
``min_window_size`` determines the minimum size of the context windows that are generated
|
||||
This can be used to trim the last few context windows which have size less than
|
||||
``min_window_size``. By default context windows of size 1 are skipped.
|
||||
``min_window_size`` determines the minimum size of the context windows
|
||||
that are generated.This can be used to trim the last few context windows
|
||||
which have size less than ``min_window_size``.
|
||||
By default context windows of size 1 are skipped.
|
||||
|
||||
>>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_pandas()
|
||||
>>> (contextualize(data)
|
||||
... .window(6).stride(3).text_col('token').groupby('document_id')
|
||||
... .to_pandas())
|
||||
token document_id
|
||||
0 The quick brown fox jumped over 1
|
||||
3 fox jumped over the lazy dog 1
|
||||
6 the lazy dog 1
|
||||
9 I love sandwiches 2
|
||||
|
||||
>>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_pandas()
|
||||
>>> (contextualize(data)
|
||||
... .window(6).stride(3).min_window_size(4).text_col('token')
|
||||
... .groupby('document_id')
|
||||
... .to_pandas())
|
||||
token document_id
|
||||
0 The quick brown fox jumped over 1
|
||||
3 fox jumped over the lazy dog 1
|
||||
@@ -113,7 +121,9 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
||||
|
||||
|
||||
class Contextualizer:
|
||||
"""Create context windows from a DataFrame. See [lancedb.context.contextualize][]."""
|
||||
"""Create context windows from a DataFrame.
|
||||
See [lancedb.context.contextualize][].
|
||||
"""
|
||||
|
||||
def __init__(self, raw_df):
|
||||
self._text_col = None
|
||||
@@ -183,7 +193,7 @@ class Contextualizer:
|
||||
deprecated_in="0.3.1",
|
||||
removed_in="0.4.0",
|
||||
current_version=__version__,
|
||||
details="Use the bar function instead",
|
||||
details="Use to_pandas() instead",
|
||||
)
|
||||
def to_df(self) -> "pd.DataFrame":
|
||||
return self.to_pandas()
|
||||
|
||||
@@ -52,12 +52,24 @@ class DBConnection(ABC):
|
||||
----------
|
||||
name: str
|
||||
The name of the table.
|
||||
data: list, tuple, dict, pd.DataFrame; optional
|
||||
The data to initialize the table. User must provide at least one of `data` or `schema`.
|
||||
schema: pyarrow.Schema or LanceModel; optional
|
||||
The schema of the table.
|
||||
data: The data to initialize the table, *optional*
|
||||
User must provide at least one of `data` or `schema`.
|
||||
Acceptable types are:
|
||||
|
||||
- dict or list-of-dict
|
||||
|
||||
- pandas.DataFrame
|
||||
|
||||
- pyarrow.Table or pyarrow.RecordBatch
|
||||
schema: The schema of the table, *optional*
|
||||
Acceptable types are:
|
||||
|
||||
- pyarrow.Schema
|
||||
|
||||
- [LanceModel][lancedb.pydantic.LanceModel]
|
||||
mode: str; default "create"
|
||||
The mode to use when creating the table. Can be either "create" or "overwrite".
|
||||
The mode to use when creating the table.
|
||||
Can be either "create" or "overwrite".
|
||||
By default, if the table already exists, an exception is raised.
|
||||
If you want to overwrite the table, use mode="overwrite".
|
||||
on_bad_vectors: str, default "error"
|
||||
@@ -150,7 +162,8 @@ class DBConnection(ABC):
|
||||
... for i in range(5):
|
||||
... yield pa.RecordBatch.from_arrays(
|
||||
... [
|
||||
... pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
|
||||
... pa.array([[3.1, 4.1], [5.9, 26.5]],
|
||||
... pa.list_(pa.float32(), 2)),
|
||||
... pa.array(["foo", "bar"]),
|
||||
... pa.array([10.0, 20.0]),
|
||||
... ],
|
||||
@@ -250,7 +263,7 @@ class LanceDBConnection(DBConnection):
|
||||
return self._uri
|
||||
|
||||
def table_names(self) -> list[str]:
|
||||
"""Get the names of all tables in the database.
|
||||
"""Get the names of all tables in the database. The names are sorted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -274,6 +287,7 @@ class LanceDBConnection(DBConnection):
|
||||
for file_info in paths
|
||||
if file_info.extension == "lance"
|
||||
]
|
||||
tables.sort()
|
||||
return tables
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
||||
@@ -11,16 +11,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||
from .cohere import CohereEmbeddingFunction
|
||||
from .functions import (
|
||||
EmbeddingFunction,
|
||||
EmbeddingFunctionConfig,
|
||||
EmbeddingFunctionRegistry,
|
||||
OpenAIEmbeddings,
|
||||
OpenClipEmbeddings,
|
||||
SentenceTransformerEmbeddings,
|
||||
TextEmbeddingFunction,
|
||||
register,
|
||||
)
|
||||
from .open_clip import OpenClipEmbeddings
|
||||
from .openai import OpenAIEmbeddings
|
||||
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||
from .sentence_transformers import SentenceTransformerEmbeddings
|
||||
from .utils import with_embeddings
|
||||
|
||||
138
python/lancedb/embeddings/base.py
Normal file
138
python/lancedb/embeddings/base.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from .utils import TEXT
|
||||
|
||||
|
||||
class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
An ABC for embedding functions.
|
||||
|
||||
All concrete embedding functions must implement the following:
|
||||
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||
2. get_source_embeddings() which returns a list of embeddings for the source column
|
||||
For text data, the two will be the same. For multi-modal data, the source column
|
||||
might be images and the vector column might be text.
|
||||
3. ndims method which returns the number of dimensions of the vector column
|
||||
"""
|
||||
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
"""
|
||||
Create an instance of the embedding function
|
||||
"""
|
||||
return cls(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database
|
||||
"""
|
||||
pass
|
||||
|
||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
elif isinstance(texts, pa.Array):
|
||||
texts = texts.to_pylist()
|
||||
elif isinstance(texts, pa.ChunkedArray):
|
||||
texts = texts.combine_chunks().to_pylist()
|
||||
return texts
|
||||
|
||||
@classmethod
|
||||
def safe_import(cls, module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : str
|
||||
The name of the module to import
|
||||
mitigation : Optional[str]
|
||||
The package(s) to install to mitigate the error.
|
||||
If not provided then the module name will be used.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(module)
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
def safe_model_dump(self):
|
||||
from ..pydantic import PYDANTIC_VERSION
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return dict(self)
|
||||
return self.model_dump()
|
||||
|
||||
@abstractmethod
|
||||
def ndims(self):
|
||||
"""
|
||||
Return the dimensions of the vector column
|
||||
"""
|
||||
pass
|
||||
|
||||
def SourceField(self, **kwargs):
|
||||
"""
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the source column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||
|
||||
def VectorField(self, **kwargs):
|
||||
"""
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the target vector column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
This model encapsulates the configuration for a embedding function
|
||||
in a lancedb table. It holds the embedding function, the source column,
|
||||
and the vector column
|
||||
"""
|
||||
|
||||
vector_column: str
|
||||
source_column: str
|
||||
function: EmbeddingFunction
|
||||
|
||||
|
||||
class TextEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
A callable ABC for embedding functions that take text as input
|
||||
"""
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Generate the embeddings for the given texts
|
||||
"""
|
||||
pass
|
||||
@@ -16,7 +16,8 @@ from typing import ClassVar, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .functions import TextEmbeddingFunction, register
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help
|
||||
|
||||
|
||||
|
||||
@@ -1,578 +0,0 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import concurrent.futures
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import urllib.error
|
||||
import urllib.parse as urlparse
|
||||
import urllib.request
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from cachetools import cached
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class EmbeddingFunctionRegistry:
|
||||
"""
|
||||
This is a singleton class used to register embedding functions
|
||||
and fetch them by name. It also handles serializing and deserializing.
|
||||
You can implement your own embedding function by subclassing EmbeddingFunction
|
||||
or TextEmbeddingFunction and registering it with the registry.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
||||
>>> @registry.register("my-embedding-function")
|
||||
... class MyEmbeddingFunction(EmbeddingFunction):
|
||||
... def ndims(self) -> int:
|
||||
... return 128
|
||||
...
|
||||
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
... return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
...
|
||||
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
||||
...
|
||||
>>> registry.get("my-embedding-function")
|
||||
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
return __REGISTRY__
|
||||
|
||||
def __init__(self):
|
||||
self._functions = {}
|
||||
|
||||
def register(self, alias: str = None):
|
||||
"""
|
||||
This creates a decorator that can be used to register
|
||||
an EmbeddingFunction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
alias : Optional[str]
|
||||
a human friendly name for the embedding function. If not
|
||||
provided, the class name will be used.
|
||||
"""
|
||||
|
||||
# This is a decorator for a class that inherits from BaseModel
|
||||
# It adds the class to the registry
|
||||
def decorator(cls):
|
||||
if not issubclass(cls, EmbeddingFunction):
|
||||
raise TypeError("Must be a subclass of EmbeddingFunction")
|
||||
if cls.__name__ in self._functions:
|
||||
raise KeyError(f"{cls.__name__} was already registered")
|
||||
key = alias or cls.__name__
|
||||
self._functions[key] = cls
|
||||
cls.__embedding_function_registry_alias__ = alias
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the registry to its initial state
|
||||
"""
|
||||
self._functions = {}
|
||||
|
||||
def get(self, name: str):
|
||||
"""
|
||||
Fetch an embedding function class by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the embedding function to fetch
|
||||
Either the alias or the class name if no alias was provided
|
||||
during registration
|
||||
"""
|
||||
return self._functions[name]
|
||||
|
||||
def parse_functions(
|
||||
self, metadata: Optional[Dict[bytes, bytes]]
|
||||
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
||||
"""
|
||||
Parse the metadata from an arrow table and
|
||||
return a mapping of the vector column to the
|
||||
embedding function and source column
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metadata : Optional[Dict[bytes, bytes]]
|
||||
The metadata from an arrow table. Note that
|
||||
the keys and values are bytes (pyarrow api)
|
||||
|
||||
Returns
|
||||
-------
|
||||
functions : dict
|
||||
A mapping of vector column name to embedding function.
|
||||
An empty dict is returned if input is None or does not
|
||||
contain b"embedding_functions".
|
||||
"""
|
||||
if metadata is None or b"embedding_functions" not in metadata:
|
||||
return {}
|
||||
serialized = metadata[b"embedding_functions"]
|
||||
raw_list = json.loads(serialized.decode("utf-8"))
|
||||
return {
|
||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||
vector_column=obj["vector_column"],
|
||||
source_column=obj["source_column"],
|
||||
function=self.get(obj["name"])(**obj["model"]),
|
||||
)
|
||||
for obj in raw_list
|
||||
}
|
||||
|
||||
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
|
||||
"""
|
||||
Convert the given embedding function and source / vector column configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
func = conf.function
|
||||
name = getattr(
|
||||
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
||||
)
|
||||
json_data = func.safe_model_dump()
|
||||
return {
|
||||
"name": name,
|
||||
"model": json_data,
|
||||
"source_column": conf.source_column,
|
||||
"vector_column": conf.vector_column,
|
||||
}
|
||||
|
||||
def get_table_metadata(self, func_list):
|
||||
"""
|
||||
Convert a list of embedding functions and source / vector configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
if func_list is None or len(func_list) == 0:
|
||||
return None
|
||||
json_data = [self.function_to_metadata(func) for func in func_list]
|
||||
# Note that metadata dictionary values must be bytes
|
||||
# so we need to json dump then utf8 encode
|
||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||
return {"embedding_functions": metadata}
|
||||
|
||||
|
||||
# Global instance
|
||||
__REGISTRY__ = EmbeddingFunctionRegistry()
|
||||
|
||||
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
IMAGES = Union[
|
||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||
]
|
||||
|
||||
|
||||
class EmbeddingFunction(BaseModel, ABC):
|
||||
"""
|
||||
An ABC for embedding functions.
|
||||
|
||||
All concrete embedding functions must implement the following:
|
||||
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||
2. get_source_embeddings() which returns a list of embeddings for the source column
|
||||
For text data, the two will be the same. For multi-modal data, the source column
|
||||
might be images and the vector column might be text.
|
||||
3. ndims method which returns the number of dimensions of the vector column
|
||||
"""
|
||||
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
"""
|
||||
Create an instance of the embedding function
|
||||
"""
|
||||
return cls(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database
|
||||
"""
|
||||
pass
|
||||
|
||||
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
elif isinstance(texts, pa.Array):
|
||||
texts = texts.to_pylist()
|
||||
elif isinstance(texts, pa.ChunkedArray):
|
||||
texts = texts.combine_chunks().to_pylist()
|
||||
return texts
|
||||
|
||||
@classmethod
|
||||
def safe_import(cls, module: str, mitigation=None):
|
||||
"""
|
||||
Import the specified module. If the module is not installed,
|
||||
raise an ImportError with a helpful message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
module : str
|
||||
The name of the module to import
|
||||
mitigation : Optional[str]
|
||||
The package(s) to install to mitigate the error.
|
||||
If not provided then the module name will be used.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(module)
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install {mitigation or module}")
|
||||
|
||||
def safe_model_dump(self):
|
||||
from ..pydantic import PYDANTIC_VERSION
|
||||
|
||||
if PYDANTIC_VERSION.major < 2:
|
||||
return dict(self)
|
||||
return self.model_dump()
|
||||
|
||||
@abstractmethod
|
||||
def ndims(self):
|
||||
"""
|
||||
Return the dimensions of the vector column
|
||||
"""
|
||||
pass
|
||||
|
||||
def SourceField(self, **kwargs):
|
||||
"""
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the source column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||
|
||||
def VectorField(self, **kwargs):
|
||||
"""
|
||||
Creates a pydantic Field that can automatically annotate
|
||||
the target vector column for this embedding function
|
||||
"""
|
||||
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
This model encapsulates the configuration for a embedding function
|
||||
in a lancedb table. It holds the embedding function, the source column,
|
||||
and the vector column
|
||||
"""
|
||||
|
||||
vector_column: str
|
||||
source_column: str
|
||||
function: EmbeddingFunction
|
||||
|
||||
|
||||
class TextEmbeddingFunction(EmbeddingFunction):
|
||||
"""
|
||||
A callable ABC for embedding functions that take text as input
|
||||
"""
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Generate the embeddings for the given texts
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
|
||||
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
||||
|
||||
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the sentence-transformers library
|
||||
|
||||
https://huggingface.co/sentence-transformers
|
||||
"""
|
||||
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._ndims = None
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||
return self._ndims
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
return self.embedding_model.encode(
|
||||
list(texts),
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=self.normalize,
|
||||
).tolist()
|
||||
|
||||
@classmethod
|
||||
@cached(cache={})
|
||||
def get_embedding_model(cls, name, device):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the model to load
|
||||
device : str
|
||||
The device to load the model on
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
sentence_transformers = cls.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||
|
||||
|
||||
@register("openai")
|
||||
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenAI API
|
||||
|
||||
https://platform.openai.com/docs/guides/embeddings
|
||||
"""
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return 1536
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
openai = self.safe_import("openai")
|
||||
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||
return [v["embedding"] for v in rs]
|
||||
|
||||
|
||||
@register("open-clip")
|
||||
class OpenClipEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenClip API
|
||||
For multi-modal text-to-image search
|
||||
|
||||
https://github.com/mlfoundations/open_clip
|
||||
"""
|
||||
|
||||
name: str = "ViT-B-32"
|
||||
pretrained: str = "laion2b_s34b_b79k"
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
_model = PrivateAttr()
|
||||
_preprocess = PrivateAttr()
|
||||
_tokenizer = PrivateAttr()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
model.to(self.device)
|
||||
self._model, self._preprocess = model, preprocess
|
||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||
self._ndims = None
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||
return self._ndims
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : Union[str, PIL.Image.Image]
|
||||
The query to embed. A query can be either text or an image.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
with torch.no_grad():
|
||||
text_features = self._model.encode_text(text.to(self.device))
|
||||
if self.normalize:
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
return text_features.cpu().numpy().squeeze()
|
||||
|
||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(images, (str, bytes)):
|
||||
images = [images]
|
||||
elif isinstance(images, pa.Array):
|
||||
images = images.to_pylist()
|
||||
elif isinstance(images, pa.ChunkedArray):
|
||||
images = images.combine_chunks().to_pylist()
|
||||
return images
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, images: IMAGES, *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given images
|
||||
"""
|
||||
images = self.sanitize_input(images)
|
||||
embeddings = []
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
j = min(i + self.batch_size, len(images))
|
||||
batch = images[i:j]
|
||||
embeddings.extend(self._parallel_get(batch))
|
||||
return embeddings
|
||||
|
||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||
"""
|
||||
Issue concurrent requests to retrieve the image data
|
||||
"""
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
for image in images
|
||||
]
|
||||
return [future.result() for future in tqdm(futures)]
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate the embedding for a single image
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : Union[str, bytes, PIL.Image.Image]
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
|
||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||
"""
|
||||
encode a single image tensor and optionally normalize the output
|
||||
"""
|
||||
image_features = self._model.encode_image(image_tensor.to(self.device))
|
||||
if self.normalize:
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features.cpu().numpy().squeeze()
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
URL to download from
|
||||
"""
|
||||
try:
|
||||
with urllib.request.urlopen(url) as conn:
|
||||
return conn.read()
|
||||
except (socket.gaierror, urllib.error.URLError) as err:
|
||||
raise ConnectionError("could not download {} due to {}".format(url, err))
|
||||
163
python/lancedb/embeddings/open_clip.py
Normal file
163
python/lancedb/embeddings/open_clip.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import concurrent.futures
|
||||
import io
|
||||
import os
|
||||
import urllib.parse as urlparse
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from pydantic import PrivateAttr
|
||||
from tqdm import tqdm
|
||||
|
||||
from .base import EmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import IMAGES, url_retrieve
|
||||
|
||||
|
||||
@register("open-clip")
|
||||
class OpenClipEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenClip API
|
||||
For multi-modal text-to-image search
|
||||
|
||||
https://github.com/mlfoundations/open_clip
|
||||
"""
|
||||
|
||||
name: str = "ViT-B-32"
|
||||
pretrained: str = "laion2b_s34b_b79k"
|
||||
device: str = "cpu"
|
||||
batch_size: int = 64
|
||||
normalize: bool = True
|
||||
_model = PrivateAttr()
|
||||
_preprocess = PrivateAttr()
|
||||
_tokenizer = PrivateAttr()
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
open_clip = self.safe_import("open_clip", "open-clip")
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
self.name, pretrained=self.pretrained
|
||||
)
|
||||
model.to(self.device)
|
||||
self._model, self._preprocess = model, preprocess
|
||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||
self._ndims = None
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||
return self._ndims
|
||||
|
||||
def compute_query_embeddings(
|
||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query : Union[str, PIL.Image.Image]
|
||||
The query to embed. A query can be either text or an image.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
return [self.generate_text_embeddings(query)]
|
||||
else:
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(query, PIL.Image.Image):
|
||||
return [self.generate_image_embedding(query)]
|
||||
else:
|
||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||
|
||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||
torch = self.safe_import("torch")
|
||||
text = self.sanitize_input(text)
|
||||
text = self._tokenizer(text)
|
||||
text.to(self.device)
|
||||
with torch.no_grad():
|
||||
text_features = self._model.encode_text(text.to(self.device))
|
||||
if self.normalize:
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
return text_features.cpu().numpy().squeeze()
|
||||
|
||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(images, (str, bytes)):
|
||||
images = [images]
|
||||
elif isinstance(images, pa.Array):
|
||||
images = images.to_pylist()
|
||||
elif isinstance(images, pa.ChunkedArray):
|
||||
images = images.combine_chunks().to_pylist()
|
||||
return images
|
||||
|
||||
def compute_source_embeddings(
|
||||
self, images: IMAGES, *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given images
|
||||
"""
|
||||
images = self.sanitize_input(images)
|
||||
embeddings = []
|
||||
for i in range(0, len(images), self.batch_size):
|
||||
j = min(i + self.batch_size, len(images))
|
||||
batch = images[i:j]
|
||||
embeddings.extend(self._parallel_get(batch))
|
||||
return embeddings
|
||||
|
||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||
"""
|
||||
Issue concurrent requests to retrieve the image data
|
||||
"""
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self.generate_image_embedding, image)
|
||||
for image in images
|
||||
]
|
||||
return [future.result() for future in tqdm(futures)]
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate the embedding for a single image
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : Union[str, bytes, PIL.Image.Image]
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
torch = self.safe_import("torch")
|
||||
# TODO handle retry and errors for https
|
||||
image = self._to_pil(image)
|
||||
image = self._preprocess(image).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
return self._encode_and_normalize_image(image)
|
||||
|
||||
def _to_pil(self, image: Union[str, bytes]):
|
||||
PIL = self.safe_import("PIL", "pillow")
|
||||
if isinstance(image, bytes):
|
||||
return PIL.Image.open(io.BytesIO(image))
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
return image
|
||||
elif isinstance(image, str):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
if parsed.scheme == "file":
|
||||
return PIL.Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||
elif parsed.scheme.startswith("http"):
|
||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||
else:
|
||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||
|
||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||
"""
|
||||
encode a single image tensor and optionally normalize the output
|
||||
"""
|
||||
image_features = self._model.encode_image(image_tensor.to(self.device))
|
||||
if self.normalize:
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
return image_features.cpu().numpy().squeeze()
|
||||
37
python/lancedb/embeddings/openai.py
Normal file
37
python/lancedb/embeddings/openai.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
|
||||
|
||||
@register("openai")
|
||||
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the OpenAI API
|
||||
|
||||
https://platform.openai.com/docs/guides/embeddings
|
||||
"""
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return 1536
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
openai = self.safe_import("openai")
|
||||
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||
return [v["embedding"] for v in rs]
|
||||
186
python/lancedb/embeddings/registry.py
Normal file
186
python/lancedb/embeddings/registry.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .base import EmbeddingFunction, EmbeddingFunctionConfig
|
||||
|
||||
|
||||
class EmbeddingFunctionRegistry:
|
||||
"""
|
||||
This is a singleton class used to register embedding functions
|
||||
and fetch them by name. It also handles serializing and deserializing.
|
||||
You can implement your own embedding function by subclassing EmbeddingFunction
|
||||
or TextEmbeddingFunction and registering it with the registry.
|
||||
|
||||
NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
Examples
|
||||
--------
|
||||
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
||||
>>> @registry.register("my-embedding-function")
|
||||
... class MyEmbeddingFunction(EmbeddingFunction):
|
||||
... def ndims(self) -> int:
|
||||
... return 128
|
||||
...
|
||||
... def compute_query_embeddings(self, query: str, *args, **kwargs):
|
||||
... return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
...
|
||||
... def compute_source_embeddings(self, texts, *args, **kwargs):
|
||||
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
||||
...
|
||||
>>> registry.get("my-embedding-function")
|
||||
<class 'lancedb.embeddings.registry.MyEmbeddingFunction'>
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
return __REGISTRY__
|
||||
|
||||
def __init__(self):
|
||||
self._functions = {}
|
||||
|
||||
def register(self, alias: str = None):
|
||||
"""
|
||||
This creates a decorator that can be used to register
|
||||
an EmbeddingFunction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
alias : Optional[str]
|
||||
a human friendly name for the embedding function. If not
|
||||
provided, the class name will be used.
|
||||
"""
|
||||
|
||||
# This is a decorator for a class that inherits from BaseModel
|
||||
# It adds the class to the registry
|
||||
def decorator(cls):
|
||||
if not issubclass(cls, EmbeddingFunction):
|
||||
raise TypeError("Must be a subclass of EmbeddingFunction")
|
||||
if cls.__name__ in self._functions:
|
||||
raise KeyError(f"{cls.__name__} was already registered")
|
||||
key = alias or cls.__name__
|
||||
self._functions[key] = cls
|
||||
cls.__embedding_function_registry_alias__ = alias
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the registry to its initial state
|
||||
"""
|
||||
self._functions = {}
|
||||
|
||||
def get(self, name: str):
|
||||
"""
|
||||
Fetch an embedding function class by name
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the embedding function to fetch
|
||||
Either the alias or the class name if no alias was provided
|
||||
during registration
|
||||
"""
|
||||
return self._functions[name]
|
||||
|
||||
def parse_functions(
|
||||
self, metadata: Optional[Dict[bytes, bytes]]
|
||||
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
||||
"""
|
||||
Parse the metadata from an arrow table and
|
||||
return a mapping of the vector column to the
|
||||
embedding function and source column
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metadata : Optional[Dict[bytes, bytes]]
|
||||
The metadata from an arrow table. Note that
|
||||
the keys and values are bytes (pyarrow api)
|
||||
|
||||
Returns
|
||||
-------
|
||||
functions : dict
|
||||
A mapping of vector column name to embedding function.
|
||||
An empty dict is returned if input is None or does not
|
||||
contain b"embedding_functions".
|
||||
"""
|
||||
if metadata is None or b"embedding_functions" not in metadata:
|
||||
return {}
|
||||
serialized = metadata[b"embedding_functions"]
|
||||
raw_list = json.loads(serialized.decode("utf-8"))
|
||||
return {
|
||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||
vector_column=obj["vector_column"],
|
||||
source_column=obj["source_column"],
|
||||
function=self.get(obj["name"])(**obj["model"]),
|
||||
)
|
||||
for obj in raw_list
|
||||
}
|
||||
|
||||
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
|
||||
"""
|
||||
Convert the given embedding function and source / vector column configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
func = conf.function
|
||||
name = getattr(
|
||||
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
||||
)
|
||||
json_data = func.safe_model_dump()
|
||||
return {
|
||||
"name": name,
|
||||
"model": json_data,
|
||||
"source_column": conf.source_column,
|
||||
"vector_column": conf.vector_column,
|
||||
}
|
||||
|
||||
def get_table_metadata(self, func_list):
|
||||
"""
|
||||
Convert a list of embedding functions and source / vector configs
|
||||
into a config dictionary that can be serialized into arrow metadata
|
||||
"""
|
||||
if func_list is None or len(func_list) == 0:
|
||||
return None
|
||||
json_data = [self.function_to_metadata(func) for func in func_list]
|
||||
# Note that metadata dictionary values must be bytes
|
||||
# so we need to json dump then utf8 encode
|
||||
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||
return {"embedding_functions": metadata}
|
||||
|
||||
|
||||
# Global instance
|
||||
__REGISTRY__ = EmbeddingFunctionRegistry()
|
||||
|
||||
|
||||
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
|
||||
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
||||
|
||||
|
||||
def get_registry():
|
||||
"""
|
||||
Utility function to get the global instance of the registry
|
||||
|
||||
Returns
|
||||
-------
|
||||
EmbeddingFunctionRegistry
|
||||
The global registry instance
|
||||
|
||||
Examples
|
||||
--------
|
||||
from lancedb.embeddings import get_registry
|
||||
|
||||
registry = get_registry()
|
||||
openai = registry.get("openai").create()
|
||||
"""
|
||||
return __REGISTRY__.get_instance()
|
||||
77
python/lancedb/embeddings/sentence_transformers.py
Normal file
77
python/lancedb/embeddings/sentence_transformers.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from cachetools import cached
|
||||
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
|
||||
|
||||
@register("sentence-transformers")
|
||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the sentence-transformers library
|
||||
|
||||
https://huggingface.co/sentence-transformers
|
||||
"""
|
||||
|
||||
name: str = "all-MiniLM-L6-v2"
|
||||
device: str = "cpu"
|
||||
normalize: bool = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._ndims = None
|
||||
|
||||
@property
|
||||
def embedding_model(self):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
"""
|
||||
return self.__class__.get_embedding_model(self.name, self.device)
|
||||
|
||||
def ndims(self):
|
||||
if self._ndims is None:
|
||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||
return self._ndims
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
return self.embedding_model.encode(
|
||||
list(texts),
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=self.normalize,
|
||||
).tolist()
|
||||
|
||||
@classmethod
|
||||
@cached(cache={})
|
||||
def get_embedding_model(cls, name, device):
|
||||
"""
|
||||
Get the sentence-transformers embedding model specified by the
|
||||
name and device. This is cached so that the model is only loaded
|
||||
once per process.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the model to load
|
||||
device : str
|
||||
The device to load the model on
|
||||
|
||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||
"""
|
||||
sentence_transformers = cls.safe_import(
|
||||
"sentence_transformers", "sentence-transformers"
|
||||
)
|
||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
||||
@@ -12,8 +12,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import socket
|
||||
import sys
|
||||
from typing import Callable, Union
|
||||
import urllib.error
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
@@ -24,7 +26,12 @@ from ..util import safe_import_pandas
|
||||
from ..utils.general import LOGGER
|
||||
|
||||
pd = safe_import_pandas()
|
||||
|
||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||
IMAGES = Union[
|
||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||
]
|
||||
|
||||
|
||||
def with_embeddings(
|
||||
@@ -155,6 +162,20 @@ class FunctionWrapper:
|
||||
yield from _chunker(arr)
|
||||
|
||||
|
||||
def url_retrieve(url: str):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
url: str
|
||||
URL to download from
|
||||
"""
|
||||
try:
|
||||
with urllib.request.urlopen(url) as conn:
|
||||
return conn.read()
|
||||
except (socket.gaierror, urllib.error.URLError) as err:
|
||||
raise ConnectionError("could not download {} due to {}".format(url, err))
|
||||
|
||||
|
||||
def api_key_not_found_help(provider):
|
||||
LOGGER.error(f"Could not find API key for {provider}.")
|
||||
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
|
||||
|
||||
@@ -19,6 +19,7 @@ import inspect
|
||||
import sys
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
|
||||
|
||||
import numpy as np
|
||||
@@ -159,6 +160,10 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
|
||||
return pa.bool_()
|
||||
elif py_type == bytes:
|
||||
return pa.binary()
|
||||
elif py_type == date:
|
||||
return pa.date32()
|
||||
elif py_type == datetime:
|
||||
return pa.timestamp("us")
|
||||
raise TypeError(
|
||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
|
||||
)
|
||||
@@ -322,7 +327,12 @@ class LanceModel(pydantic.BaseModel):
|
||||
for vec, func in vec_and_function:
|
||||
for source, field_info in cls.safe_get_fields().items():
|
||||
src_func = get_extras(field_info, "source_column_for")
|
||||
if src_func == func:
|
||||
if src_func is func:
|
||||
# note we can't use == here since the function is a pydantic
|
||||
# model so two instances of the same function are ==, so if you
|
||||
# have multiple vector columns from multiple sources, both will
|
||||
# be mapped to the same source column
|
||||
# GH594
|
||||
configs.append(
|
||||
EmbeddingFunctionConfig(
|
||||
source_column=source, vector_column=vec, function=func
|
||||
|
||||
@@ -30,7 +30,40 @@ pd = safe_import_pandas()
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""A Query"""
|
||||
"""The LanceDB Query
|
||||
|
||||
Attributes
|
||||
----------
|
||||
vector : List[float]
|
||||
the vector to search for
|
||||
filter : Optional[str]
|
||||
sql filter to refine the query with, optional
|
||||
prefilter : bool
|
||||
if True then apply the filter before vector search
|
||||
k : int
|
||||
top k results to return
|
||||
metric : str
|
||||
the distance metric between a pair of vectors,
|
||||
|
||||
can support L2 (default), Cosine and Dot.
|
||||
[metric definitions][search]
|
||||
columns : Optional[List[str]]
|
||||
which columns to return in the results
|
||||
nprobes : int
|
||||
The number of probes used - optional
|
||||
|
||||
- A higher number makes search more accurate but also slower.
|
||||
|
||||
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||
tuning advice.
|
||||
refine_factor : Optional[int]
|
||||
Refine the results by reading extra elements and re-ranking them in memory - optional
|
||||
|
||||
- A higher number makes search more accurate but also slower.
|
||||
|
||||
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||
tuning advice.
|
||||
"""
|
||||
|
||||
vector_column: str = VECTOR_COLUMN_NAME
|
||||
|
||||
@@ -61,6 +94,10 @@ class Query(pydantic.BaseModel):
|
||||
|
||||
|
||||
class LanceQueryBuilder(ABC):
|
||||
"""Build LanceDB query based on specific query type:
|
||||
vector or full text search.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -133,11 +170,11 @@ class LanceQueryBuilder(ABC):
|
||||
deprecated_in="0.3.1",
|
||||
removed_in="0.4.0",
|
||||
current_version=__version__,
|
||||
details="Use the bar function instead",
|
||||
details="Use to_pandas() instead",
|
||||
)
|
||||
def to_df(self) -> "pd.DataFrame":
|
||||
"""
|
||||
Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.
|
||||
*Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.*
|
||||
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
@@ -226,13 +263,20 @@ class LanceQueryBuilder(ABC):
|
||||
self._columns = columns
|
||||
return self
|
||||
|
||||
def where(self, where) -> LanceQueryBuilder:
|
||||
def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause.
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
prefilter: bool, default False
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
This feature is **EXPERIMENTAL** and may be removed and modified
|
||||
without warning in the future.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -240,13 +284,12 @@ class LanceQueryBuilder(ABC):
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
self._prefilter = prefilter
|
||||
return self
|
||||
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
A builder for nearest neighbor queries for LanceDB.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
@@ -302,7 +345,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
they exist) at the expense of latency.
|
||||
|
||||
See discussion in [Querying an ANN Index][../querying-an-ann-index] for
|
||||
See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||
tuning advice.
|
||||
|
||||
Parameters
|
||||
@@ -369,14 +412,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause.
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
prefilter: bool, default False
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
This feature is **EXPERIMENTAL** and may be removed and modified
|
||||
without warning in the future. Currently this is only supported
|
||||
in OSS and can only be used with a table that does not have an ANN
|
||||
index.
|
||||
without warning in the future.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -389,6 +432,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
"""A builder for full text search for LanceDB."""
|
||||
|
||||
def __init__(self, table: "lancedb.table.Table", query: str):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
|
||||
@@ -151,10 +151,15 @@ class RestfulLanceDBClient:
|
||||
return await deserialize(resp)
|
||||
|
||||
@_check_not_closed
|
||||
async def list_tables(self):
|
||||
async def list_tables(self, limit: int, page_token: str):
|
||||
"""List all tables in the database."""
|
||||
json = await self.get("/v1/table/", {})
|
||||
return json["tables"]
|
||||
try:
|
||||
json = await self.get(
|
||||
"/v1/table/", {"limit": limit, "page_token": page_token}
|
||||
)
|
||||
return json["tables"]
|
||||
except StopAsyncIteration:
|
||||
return []
|
||||
|
||||
@_check_not_closed
|
||||
async def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
from typing import Iterator, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pyarrow as pa
|
||||
@@ -52,10 +52,27 @@ class RemoteDBConnection(DBConnection):
|
||||
def __repr__(self) -> str:
|
||||
return f"RemoveConnect(name={self.db_name})"
|
||||
|
||||
def table_names(self) -> List[str]:
|
||||
"""List the names of all tables in the database."""
|
||||
result = self._loop.run_until_complete(self._client.list_tables())
|
||||
return result
|
||||
def table_names(self, last_token: str, limit=10) -> Iterator[str]:
|
||||
"""List the names of all tables in the database.
|
||||
Parameters
|
||||
----------
|
||||
last_token: str
|
||||
The last token to start the new page.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An iterator of table names.
|
||||
"""
|
||||
while True:
|
||||
result = self._loop.run_until_complete(
|
||||
self._client.list_tables(limit, last_token)
|
||||
)
|
||||
if len(result) > 0:
|
||||
last_token = result[len(result) - 1]
|
||||
else:
|
||||
break
|
||||
for item in result:
|
||||
yield result
|
||||
|
||||
def open_table(self, name: str) -> Table:
|
||||
"""Open a Lance Table in the database.
|
||||
@@ -87,7 +104,11 @@ class RemoteDBConnection(DBConnection):
|
||||
raise ValueError("Either data or schema must be provided.")
|
||||
if data is not None:
|
||||
data = _sanitize_data(
|
||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
data,
|
||||
schema,
|
||||
metadata=None,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
else:
|
||||
if schema is None:
|
||||
@@ -122,3 +143,8 @@ class RemoteDBConnection(DBConnection):
|
||||
f"/v1/table/{name}/drop/",
|
||||
)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the database."""
|
||||
self._loop.close()
|
||||
await self._client.close()
|
||||
|
||||
@@ -29,8 +29,7 @@ from lance.dataset import CleanupStats, ReaderLike
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
from .embeddings.functions import EmbeddingFunctionConfig
|
||||
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||
from .pydantic import LanceModel
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
@@ -150,13 +149,13 @@ class Table(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def schema(self) -> pa.Schema:
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of
|
||||
this Table
|
||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||
of this Table
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pandas(self):
|
||||
def to_pandas(self) -> "pd.DataFrame":
|
||||
"""Return the table as a pandas DataFrame.
|
||||
|
||||
Returns
|
||||
@@ -192,17 +191,18 @@ class Table(ABC):
|
||||
The distance metric to use when creating the index.
|
||||
Valid values are "L2", "cosine", or "dot".
|
||||
L2 is euclidean distance.
|
||||
num_partitions: int
|
||||
num_partitions: int, default 256
|
||||
The number of IVF partitions to use when creating the index.
|
||||
Default is 256.
|
||||
num_sub_vectors: int
|
||||
num_sub_vectors: int, default 96
|
||||
The number of PQ sub-vectors to use when creating the index.
|
||||
Default is 96.
|
||||
vector_column_name: str, default "vector"
|
||||
The vector column name to create the index.
|
||||
replace: bool, default True
|
||||
If True, replace the existing index if it exists.
|
||||
If False, raise an error if duplicate index exists.
|
||||
- If True, replace the existing index if it exists.
|
||||
|
||||
- If False, raise an error if duplicate index exists.
|
||||
accelerator: str, default None
|
||||
If set, use the given accelerator to create the index.
|
||||
Only support "cuda" for now.
|
||||
@@ -221,8 +221,14 @@ class Table(ABC):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: list-of-dict, dict, pd.DataFrame
|
||||
The data to insert into the table.
|
||||
data: DATA
|
||||
The data to insert into the table. Acceptable types are:
|
||||
|
||||
- dict or list-of-dict
|
||||
|
||||
- pandas.DataFrame
|
||||
|
||||
- pyarrow.Table or pyarrow.RecordBatch
|
||||
mode: str
|
||||
The mode to use when writing the data. Valid values are
|
||||
"append" and "overwrite".
|
||||
@@ -243,31 +249,70 @@ class Table(ABC):
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
and [full-text search][experimental-full-text-search].
|
||||
|
||||
All query options are defined in [Query][lancedb.query.Query].
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> data = [
|
||||
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
|
||||
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
|
||||
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
|
||||
... ]
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query, vector_column_name="vector")
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
caption original_width vector _distance
|
||||
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
|
||||
1 test 3000 [0.3, 6.2, 2.6] 23.089996
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str, list, np.ndarray, PIL.Image.Image, default None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
query: list/np.ndarray/str/PIL.Image.Image, default None
|
||||
The targetted vector to search for.
|
||||
|
||||
- *default None*.
|
||||
Acceptable types are: list, np.ndarray, PIL.Image.Image
|
||||
|
||||
- If None then the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
vector_column_name: str, default "vector"
|
||||
vector_column_name: str
|
||||
The name of the vector column to search.
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If `query` is a string, then the query type is "vector" if the
|
||||
*default "vector"*
|
||||
query_type: str
|
||||
*default "auto"*.
|
||||
Acceptable types are: "vector", "fts", or "auto"
|
||||
|
||||
- If "auto" then the query type is inferred from the query;
|
||||
|
||||
- If `query` is a list/np.ndarray then the query type is
|
||||
"vector";
|
||||
|
||||
- If `query` is a PIL.Image.Image then either do vector search,
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
|
||||
- If `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
A query builder object representing the query.
|
||||
Once executed, the query returns selected columns, the vector,
|
||||
and also the "_distance" column which is the distance between the query
|
||||
Once executed, the query returns
|
||||
|
||||
- selected columns
|
||||
|
||||
- the vector
|
||||
|
||||
- and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -286,14 +331,19 @@ class Table(ABC):
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The SQL where clause to use when deleting rows. For example, 'x = 2'
|
||||
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
||||
The SQL where clause to use when deleting rows.
|
||||
|
||||
- For example, 'x = 2' or 'x IN (1, 2, 3)'.
|
||||
|
||||
The filter must not be empty, or it will error.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]}, {"x": 2, "vector": [3, 4]}, {"x": 3, "vector": [5, 6]}
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
@@ -378,7 +428,8 @@ class LanceTable(Table):
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table = db.create_table("my_table",
|
||||
... [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.to_pandas()
|
||||
@@ -425,7 +476,8 @@ class LanceTable(Table):
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table = db.create_table("my_table", [
|
||||
... {"vector": [1.1, 0.9], "type": "vector"}])
|
||||
>>> table.version
|
||||
2
|
||||
>>> table.to_pandas()
|
||||
@@ -670,14 +722,39 @@ class LanceTable(Table):
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
and [full-text search][search].
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> data = [
|
||||
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
|
||||
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
|
||||
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
|
||||
... ]
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> query = [0.4, 1.4, 2.4]
|
||||
>>> (table.search(query, vector_column_name="vector")
|
||||
... .where("original_width > 1000", prefilter=True)
|
||||
... .select(["caption", "original_width"])
|
||||
... .limit(2)
|
||||
... .to_pandas())
|
||||
caption original_width vector _distance
|
||||
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
|
||||
1 test 3000 [0.3, 6.2, 2.6] 23.089996
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str, list, np.ndarray, a PIL Image or None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
query: list/np.ndarray/str/PIL.Image.Image, default None
|
||||
The targetted vector to search for.
|
||||
|
||||
- *default None*.
|
||||
Acceptable types are: list, np.ndarray, PIL.Image.Image
|
||||
|
||||
- If None then the select/[where][sql]/limit clauses are applied
|
||||
to filter the table
|
||||
vector_column_name: str, default "vector"
|
||||
The name of the vector column to search.
|
||||
query_type: str, default "auto"
|
||||
@@ -686,7 +763,7 @@ class LanceTable(Table):
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If the query is a string, then the query type is "vector" if the
|
||||
If the `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions, else the query type is "fts"
|
||||
|
||||
Returns
|
||||
@@ -721,7 +798,9 @@ class LanceTable(Table):
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]}, {"x": 2, "vector": [3, 4]}, {"x": 3, "vector": [5, 6]}
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
@@ -741,7 +820,8 @@ class LanceTable(Table):
|
||||
The data to insert into the table.
|
||||
At least one of `data` or `schema` must be provided.
|
||||
schema: pa.Schema or LanceModel, optional
|
||||
The schema of the table. If not provided, the schema is inferred from the data.
|
||||
The schema of the table. If not provided,
|
||||
the schema is inferred from the data.
|
||||
At least one of `data` or `schema` must be provided.
|
||||
mode: str, default "create"
|
||||
The mode to use when writing the data. Valid values are
|
||||
@@ -812,7 +892,8 @@ class LanceTable(Table):
|
||||
file_info = fs.get_file_info(path)
|
||||
if file_info.type != pa.fs.FileType.Directory:
|
||||
raise FileNotFoundError(
|
||||
f"Table {name} does not exist. Please first call db.create_table({name}, data)"
|
||||
f"Table {name} does not exist."
|
||||
f"Please first call db.create_table({name}, data)"
|
||||
)
|
||||
return tbl
|
||||
|
||||
@@ -839,7 +920,9 @@ class LanceTable(Table):
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]}, {"x": 2, "vector": [3, 4]}, {"x": 3, "vector": [5, 6]}
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
@@ -873,12 +956,6 @@ class LanceTable(Table):
|
||||
|
||||
def _execute_query(self, query: Query) -> pa.Table:
|
||||
ds = self.to_lance()
|
||||
if query.prefilter:
|
||||
for idx in ds.list_indices():
|
||||
if query.vector_column in idx["fields"]:
|
||||
raise NotImplementedError(
|
||||
"Prefiltering for indexed vector column is coming soon."
|
||||
)
|
||||
return ds.to_table(
|
||||
columns=query.columns,
|
||||
filter=query.filter,
|
||||
@@ -1020,7 +1097,8 @@ def _sanitize_vector_column(
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
if pa.types.is_list(data[vector_column_name].type):
|
||||
# if it's a variable size list array we make sure the dimensions are all the same
|
||||
# if it's a variable size list array,
|
||||
# we make sure the dimensions are all the same
|
||||
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
|
||||
if has_jagged_ndims:
|
||||
data = _sanitize_jagged(
|
||||
|
||||
@@ -63,7 +63,8 @@ def set_sentry():
|
||||
"""
|
||||
if "exc_info" in hint:
|
||||
exc_type, exc_value, tb = hint["exc_info"]
|
||||
if "out of memory" in str(exc_value).lower():
|
||||
ignored_errors = ["out of memory", "no space left on device", "testing"]
|
||||
if any(error in str(exc_value).lower() for error in ignored_errors):
|
||||
return None
|
||||
|
||||
if is_git_dir():
|
||||
@@ -97,7 +98,7 @@ def set_sentry():
|
||||
dsn="https://c63ef8c64e05d1aa1a96513361f3ca2f@o4505950840946688.ingest.sentry.io/4505950933614592",
|
||||
debug=False,
|
||||
include_local_variables=False,
|
||||
traces_sample_rate=1.0,
|
||||
traces_sample_rate=0.5,
|
||||
environment="production", # 'dev' or 'production'
|
||||
before_send=before_send,
|
||||
ignore_errors=[KeyboardInterrupt, FileNotFoundError, bdb.BdbQuit],
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.3.1"
|
||||
version = "0.3.3"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.8.5",
|
||||
"pylance==0.8.10",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.1.0",
|
||||
@@ -52,7 +52,7 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
||||
dev = ["ruff", "pre-commit", "black"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
clip = ["torch", "pillow", "open-clip"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip", "cohere"]
|
||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere"]
|
||||
|
||||
[project.scripts]
|
||||
lancedb = "lancedb.cli.cli:cli"
|
||||
|
||||
@@ -150,6 +150,21 @@ def test_ingest_iterator(tmp_path):
|
||||
run_tests(PydanticSchema)
|
||||
|
||||
|
||||
def test_table_names(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
db.create_table("test2", data=data)
|
||||
db.create_table("test1", data=data)
|
||||
db.create_table("test3", data=data)
|
||||
assert db.table_names() == ["test1", "test2", "test3"]
|
||||
|
||||
|
||||
def test_create_mode(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
data = pd.DataFrame(
|
||||
@@ -287,3 +302,27 @@ def test_replace_index(tmp_path):
|
||||
num_sub_vectors=4,
|
||||
replace=True,
|
||||
)
|
||||
|
||||
|
||||
def test_prefilter_with_index(tmp_path):
|
||||
db = lancedb.connect(uri=tmp_path)
|
||||
data = [
|
||||
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
|
||||
for i in range(1000)
|
||||
]
|
||||
sample_key = data[100]["vector"]
|
||||
table = db.create_table(
|
||||
"test",
|
||||
data,
|
||||
)
|
||||
table.create_index(
|
||||
num_partitions=2,
|
||||
num_sub_vectors=4,
|
||||
)
|
||||
table = (
|
||||
table.search(sample_key)
|
||||
.where("price == 500", prefilter=True)
|
||||
.limit(5)
|
||||
.to_arrow()
|
||||
)
|
||||
assert table.num_rows == 1
|
||||
|
||||
@@ -19,7 +19,7 @@ import pytest
|
||||
import requests
|
||||
|
||||
import lancedb
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
# These are integration tests for embedding functions.
|
||||
@@ -31,12 +31,15 @@ from lancedb.pydantic import LanceModel, Vector
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||
def test_sentence_transformer(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
registry = get_registry()
|
||||
func = registry.get(alias).create()
|
||||
func2 = registry.get(alias).create()
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
text2: str = func2.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
vector2: Vector(func2.ndims()) = func2.VectorField()
|
||||
|
||||
table = db.create_table("words", schema=Words)
|
||||
table.add(
|
||||
@@ -50,7 +53,16 @@ def test_sentence_transformer(alias, tmp_path):
|
||||
"foo",
|
||||
"bar",
|
||||
"baz",
|
||||
]
|
||||
],
|
||||
"text2": [
|
||||
"to be or not to be",
|
||||
"that is the question",
|
||||
"for whether tis nobler",
|
||||
"in the mind to suffer",
|
||||
"the slings and arrows",
|
||||
"of outrageous fortune",
|
||||
"or to take arms",
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -62,6 +74,13 @@ def test_sentence_transformer(alias, tmp_path):
|
||||
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
||||
assert actual.text == expected.text
|
||||
assert actual.text == "hello world"
|
||||
assert not np.allclose(actual.vector, actual.vector2)
|
||||
|
||||
actual = (
|
||||
table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
|
||||
)
|
||||
assert actual.text != "hello world"
|
||||
assert not np.allclose(actual.vector, actual.vector2)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@@ -69,7 +88,7 @@ def test_openclip(tmp_path):
|
||||
from PIL import Image
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
registry = get_registry()
|
||||
func = registry.get("open-clip").create()
|
||||
|
||||
class Images(LanceModel):
|
||||
@@ -131,11 +150,7 @@ def test_openclip(tmp_path):
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
) # also skip if cohere not installed
|
||||
def test_cohere_embedding_function():
|
||||
cohere = (
|
||||
EmbeddingFunctionRegistry.get_instance()
|
||||
.get("cohere")
|
||||
.create(name="embed-multilingual-v2.0")
|
||||
)
|
||||
cohere = get_registry().get("cohere").create(name="embed-multilingual-v2.0")
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = cohere.SourceField()
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import date, datetime
|
||||
from typing import List, Optional
|
||||
|
||||
import pyarrow as pa
|
||||
@@ -40,10 +41,18 @@ def test_pydantic_to_arrow():
|
||||
li: List[int]
|
||||
opt: Optional[str] = None
|
||||
st: StructModel
|
||||
dt: date
|
||||
dtt: datetime
|
||||
# d: dict
|
||||
|
||||
m = TestModel(
|
||||
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
|
||||
id=1,
|
||||
s="hello",
|
||||
vec=[1.0, 2.0, 3.0],
|
||||
li=[2, 3, 4],
|
||||
st=StructModel(a="a", b=1.0),
|
||||
dt=date.today(),
|
||||
dtt=datetime.now(),
|
||||
)
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
@@ -62,6 +71,8 @@ def test_pydantic_to_arrow():
|
||||
),
|
||||
False,
|
||||
),
|
||||
pa.field("dt", pa.date32(), False),
|
||||
pa.field("dtt", pa.timestamp("us"), False),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
@@ -79,10 +90,18 @@ def test_pydantic_to_arrow_py38():
|
||||
li: List[int]
|
||||
opt: Optional[str] = None
|
||||
st: StructModel
|
||||
dt: date
|
||||
dtt: datetime
|
||||
# d: dict
|
||||
|
||||
m = TestModel(
|
||||
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
|
||||
id=1,
|
||||
s="hello",
|
||||
vec=[1.0, 2.0, 3.0],
|
||||
li=[2, 3, 4],
|
||||
st=StructModel(a="a", b=1.0),
|
||||
dt=date.today(),
|
||||
dtt=datetime.now(),
|
||||
)
|
||||
|
||||
schema = pydantic_to_schema(TestModel)
|
||||
@@ -101,6 +120,8 @@ def test_pydantic_to_arrow_py38():
|
||||
),
|
||||
False,
|
||||
),
|
||||
pa.field("dt", pa.date32(), False),
|
||||
pa.field("dtt", pa.timestamp("us"), False),
|
||||
]
|
||||
)
|
||||
assert schema == expect_schema
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb-node"
|
||||
version = "0.3.2"
|
||||
version = "0.3.5"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
edition = "2018"
|
||||
|
||||
@@ -70,7 +70,6 @@ fn get_index_params_builder(
|
||||
.map(|mt| {
|
||||
let metric_type = mt.unwrap();
|
||||
index_builder.metric_type(metric_type);
|
||||
pq_params.metric_type = metric_type;
|
||||
});
|
||||
|
||||
let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;
|
||||
|
||||
@@ -239,6 +239,8 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
|
||||
cx.export_function("tableDelete", JsTable::js_delete)?;
|
||||
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
|
||||
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
|
||||
cx.export_function("tableListIndices", JsTable::js_list_indices)?;
|
||||
cx.export_function("tableIndexStats", JsTable::js_index_stats)?;
|
||||
cx.export_function(
|
||||
"tableCreateVectorIndex",
|
||||
index::vector::table_create_vector_index,
|
||||
|
||||
@@ -247,7 +247,7 @@ impl JsTable {
|
||||
}
|
||||
|
||||
rt.spawn(async move {
|
||||
let stats = table.compact_files(options).await;
|
||||
let stats = table.compact_files(options, None).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let stats = stats.or_throw(&mut cx)?;
|
||||
@@ -276,4 +276,91 @@ impl JsTable {
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let channel = cx.channel();
|
||||
let table = js_table.table.clone();
|
||||
|
||||
rt.spawn(async move {
|
||||
let indices = table.load_indices().await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let indices = indices.or_throw(&mut cx)?;
|
||||
|
||||
let output = JsArray::new(&mut cx, indices.len() as u32);
|
||||
for (i, index) in indices.iter().enumerate() {
|
||||
let js_index = JsObject::new(&mut cx);
|
||||
let index_name = cx.string(index.index_name.clone());
|
||||
js_index.set(&mut cx, "name", index_name)?;
|
||||
|
||||
let index_uuid = cx.string(index.index_uuid.clone());
|
||||
js_index.set(&mut cx, "uuid", index_uuid)?;
|
||||
|
||||
let js_index_columns = JsArray::new(&mut cx, index.columns.len() as u32);
|
||||
for (j, column) in index.columns.iter().enumerate() {
|
||||
let js_column = cx.string(column.clone());
|
||||
js_index_columns.set(&mut cx, j as u32, js_column)?;
|
||||
}
|
||||
js_index.set(&mut cx, "columns", js_index_columns)?;
|
||||
|
||||
output.set(&mut cx, i as u32, js_index)?;
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
})
|
||||
});
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||
let rt = runtime(&mut cx)?;
|
||||
let (deferred, promise) = cx.promise();
|
||||
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let channel = cx.channel();
|
||||
let table = js_table.table.clone();
|
||||
|
||||
rt.spawn(async move {
|
||||
let load_stats = futures::try_join!(
|
||||
table.count_indexed_rows(&index_uuid),
|
||||
table.count_unindexed_rows(&index_uuid)
|
||||
);
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let (indexed_rows, unindexed_rows) = load_stats.or_throw(&mut cx)?;
|
||||
|
||||
let output = JsObject::new(&mut cx);
|
||||
|
||||
match indexed_rows {
|
||||
Some(x) => {
|
||||
let i = cx.number(x as f64);
|
||||
output.set(&mut cx, "numIndexedRows", i)?;
|
||||
}
|
||||
None => {
|
||||
let null = cx.null();
|
||||
output.set(&mut cx, "numIndexedRows", null)?;
|
||||
}
|
||||
};
|
||||
|
||||
match unindexed_rows {
|
||||
Some(x) => {
|
||||
let i = cx.number(x as f64);
|
||||
output.set(&mut cx, "numUnindexedRows", i)?;
|
||||
}
|
||||
None => {
|
||||
let null = cx.null();
|
||||
output.set(&mut cx, "numUnindexedRows", null)?;
|
||||
}
|
||||
};
|
||||
|
||||
Ok(output)
|
||||
})
|
||||
});
|
||||
|
||||
Ok(promise)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "vectordb"
|
||||
version = "0.3.2"
|
||||
version = "0.3.5"
|
||||
edition = "2021"
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license = "Apache-2.0"
|
||||
|
||||
@@ -161,7 +161,7 @@ impl Database {
|
||||
///
|
||||
/// * A [Vec<String>] with all table names.
|
||||
pub async fn table_names(&self) -> Result<Vec<String>> {
|
||||
let f = self
|
||||
let mut f = self
|
||||
.object_store
|
||||
.read_dir(self.base_path.clone())
|
||||
.await?
|
||||
@@ -175,7 +175,8 @@ impl Database {
|
||||
is_lance.unwrap_or(false)
|
||||
})
|
||||
.filter_map(|p| p.file_stem().and_then(|s| s.to_str().map(String::from)))
|
||||
.collect();
|
||||
.collect::<Vec<String>>();
|
||||
f.sort();
|
||||
Ok(f)
|
||||
}
|
||||
|
||||
@@ -312,8 +313,8 @@ mod tests {
|
||||
let db = Database::connect(uri).await.unwrap();
|
||||
let tables = db.table_names().await.unwrap();
|
||||
assert_eq!(tables.len(), 2);
|
||||
assert!(tables.contains(&String::from("table1")));
|
||||
assert!(tables.contains(&String::from("table2")));
|
||||
assert!(tables[0].eq(&String::from("table1")));
|
||||
assert!(tables[1].eq(&String::from("table2")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use lance::format::{Index, Manifest};
|
||||
use lance::index::vector::ivf::IvfBuildParams;
|
||||
use lance::index::vector::pq::PQBuildParams;
|
||||
use lance::index::vector::VectorIndexParams;
|
||||
@@ -98,7 +99,11 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
|
||||
let ivf_params = self.ivf_params.clone().unwrap_or_default();
|
||||
let pq_params = self.pq_params.clone().unwrap_or_default();
|
||||
|
||||
VectorIndexParams::with_ivf_pq_params(pq_params.metric_type, ivf_params, pq_params)
|
||||
VectorIndexParams::with_ivf_pq_params(
|
||||
self.metric_type.unwrap_or(MetricType::L2),
|
||||
ivf_params,
|
||||
pq_params,
|
||||
)
|
||||
}
|
||||
|
||||
fn get_replace(&self) -> bool {
|
||||
@@ -106,6 +111,27 @@ impl VectorIndexBuilder for IvfPQIndexBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VectorIndex {
|
||||
pub columns: Vec<String>,
|
||||
pub index_name: String,
|
||||
pub index_uuid: String,
|
||||
}
|
||||
|
||||
impl VectorIndex {
|
||||
pub fn new_from_format(manifest: &Manifest, index: &Index) -> VectorIndex {
|
||||
let fields = index
|
||||
.fields
|
||||
.iter()
|
||||
.map(|i| manifest.schema.fields[*i as usize].name.clone())
|
||||
.collect();
|
||||
VectorIndex {
|
||||
columns: fields,
|
||||
index_name: index.name.clone(),
|
||||
index_uuid: index.uuid.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -158,7 +184,6 @@ mod tests {
|
||||
pq_params.max_iters = 1;
|
||||
pq_params.num_bits = 8;
|
||||
pq_params.num_sub_vectors = 50;
|
||||
pq_params.metric_type = MetricType::Cosine;
|
||||
pq_params.max_opq_iters = 2;
|
||||
index_builder.ivf_params(ivf_params);
|
||||
index_builder.pq_params(pq_params);
|
||||
@@ -176,7 +201,6 @@ mod tests {
|
||||
assert_eq!(pq_params.max_iters, 1);
|
||||
assert_eq!(pq_params.num_bits, 8);
|
||||
assert_eq!(pq_params.num_sub_vectors, 50);
|
||||
assert_eq!(pq_params.metric_type, MetricType::Cosine);
|
||||
assert_eq!(pq_params.max_opq_iters, 2);
|
||||
} else {
|
||||
assert!(false, "Expected second stage to be pq")
|
||||
|
||||
@@ -57,7 +57,7 @@ trait PrimaryOnly {
|
||||
|
||||
impl PrimaryOnly for Path {
|
||||
fn primary_only(&self) -> bool {
|
||||
self.to_string().contains("manifest")
|
||||
self.filename().unwrap_or("") == "_latest.manifest"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,8 +118,10 @@ impl ObjectStore for MirroringObjectStore {
|
||||
self.primary.head(location).await
|
||||
}
|
||||
|
||||
// garbage collection on secondary will happen async from other means
|
||||
async fn delete(&self, location: &Path) -> Result<()> {
|
||||
if !location.primary_only() {
|
||||
self.secondary.delete(location).await?;
|
||||
}
|
||||
self.primary.delete(location).await
|
||||
}
|
||||
|
||||
@@ -132,7 +134,7 @@ impl ObjectStore for MirroringObjectStore {
|
||||
}
|
||||
|
||||
async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
|
||||
if from.primary_only() {
|
||||
if to.primary_only() {
|
||||
self.primary.copy(from, to).await
|
||||
} else {
|
||||
self.secondary.copy(from, to).await?;
|
||||
@@ -142,6 +144,9 @@ impl ObjectStore for MirroringObjectStore {
|
||||
}
|
||||
|
||||
async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
|
||||
if !to.primary_only() {
|
||||
self.secondary.copy(from, to).await?;
|
||||
}
|
||||
self.primary.copy_if_not_exists(from, to).await
|
||||
}
|
||||
}
|
||||
@@ -379,7 +384,7 @@ mod test {
|
||||
let primary_f = primary_elem.unwrap().unwrap();
|
||||
// hit manifest, skip, _versions contains all the manifest and should not exist on secondary
|
||||
let primary_raw_path = primary_f.file_name().to_str().unwrap();
|
||||
if primary_raw_path.contains("manifest") || primary_raw_path.contains("_versions") {
|
||||
if primary_raw_path.contains("_latest.manifest") {
|
||||
primary_elem = primary_iter.next();
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -18,14 +18,16 @@ use std::sync::Arc;
|
||||
use arrow_array::{Float32Array, RecordBatchReader};
|
||||
use arrow_schema::SchemaRef;
|
||||
use lance::dataset::cleanup::RemovalStats;
|
||||
use lance::dataset::optimize::{compact_files, CompactionMetrics, CompactionOptions};
|
||||
use lance::dataset::optimize::{
|
||||
compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
|
||||
};
|
||||
use lance::dataset::{Dataset, WriteParams};
|
||||
use lance::index::IndexType;
|
||||
use lance::index::{DatasetIndexExt, IndexType};
|
||||
use lance::io::object_store::WrappingObjectStore;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::index::vector::VectorIndexBuilder;
|
||||
use crate::index::vector::{VectorIndex, VectorIndexBuilder};
|
||||
use crate::query::Query;
|
||||
use crate::utils::{PatchReadParam, PatchWriteParam};
|
||||
use crate::WriteMode;
|
||||
@@ -153,6 +155,22 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn checkout_latest(&self) -> Result<Self> {
|
||||
let latest_version_id = self.dataset.latest_version_id().await?;
|
||||
let dataset = if latest_version_id == self.dataset.version().version {
|
||||
self.dataset.clone()
|
||||
} else {
|
||||
Arc::new(self.dataset.checkout_version(latest_version_id).await?)
|
||||
};
|
||||
|
||||
Ok(Table {
|
||||
name: self.name.clone(),
|
||||
uri: self.uri.clone(),
|
||||
dataset,
|
||||
store_wrapper: self.store_wrapper.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn get_table_name(uri: &str) -> Result<String> {
|
||||
let path = Path::new(uri);
|
||||
let name = path
|
||||
@@ -222,8 +240,6 @@ impl Table {
|
||||
|
||||
/// Create index on the table.
|
||||
pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> {
|
||||
use lance::index::DatasetIndexExt;
|
||||
|
||||
let mut dataset = self.dataset.as_ref().clone();
|
||||
dataset
|
||||
.create_index(
|
||||
@@ -241,6 +257,14 @@ impl Table {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn optimize_indices(&mut self) -> Result<()> {
|
||||
let mut dataset = self.dataset.as_ref().clone();
|
||||
|
||||
dataset.optimize_indices().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Insert records into this Table
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -337,12 +361,45 @@ impl Table {
|
||||
/// for faster reads.
|
||||
///
|
||||
/// This calls into [lance::dataset::optimize::compact_files].
|
||||
pub async fn compact_files(&mut self, options: CompactionOptions) -> Result<CompactionMetrics> {
|
||||
pub async fn compact_files(
|
||||
&mut self,
|
||||
options: CompactionOptions,
|
||||
remap_options: Option<Arc<dyn IndexRemapperOptions>>,
|
||||
) -> Result<CompactionMetrics> {
|
||||
let mut dataset = self.dataset.as_ref().clone();
|
||||
let metrics = compact_files(&mut dataset, options, None).await?;
|
||||
let metrics = compact_files(&mut dataset, options, remap_options).await?;
|
||||
self.dataset = Arc::new(dataset);
|
||||
Ok(metrics)
|
||||
}
|
||||
|
||||
pub fn count_fragments(&self) -> usize {
|
||||
self.dataset.count_fragments()
|
||||
}
|
||||
|
||||
pub fn count_deleted_rows(&self) -> usize {
|
||||
self.dataset.count_deleted_rows()
|
||||
}
|
||||
|
||||
pub fn num_small_files(&self, max_rows_per_group: usize) -> usize {
|
||||
self.dataset.num_small_files(max_rows_per_group)
|
||||
}
|
||||
|
||||
pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
|
||||
Ok(self.dataset.count_indexed_rows(index_uuid).await?)
|
||||
}
|
||||
|
||||
pub async fn count_unindexed_rows(&self, index_uuid: &str) -> Result<Option<usize>> {
|
||||
Ok(self.dataset.count_unindexed_rows(index_uuid).await?)
|
||||
}
|
||||
|
||||
pub async fn load_indices(&self) -> Result<Vec<VectorIndex>> {
|
||||
let (indices, mf) =
|
||||
futures::try_join!(self.dataset.load_indices(), self.dataset.latest_manifest())?;
|
||||
Ok(indices
|
||||
.iter()
|
||||
.map(|i| VectorIndex::new_from_format(&mf, i))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user