mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-23 13:29:57 +00:00
Compare commits
93 Commits
python-v0.
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38321fa226 | ||
|
|
22749c3fa2 | ||
|
|
123a49df77 | ||
|
|
a57aa4b142 | ||
|
|
d8e3e54226 | ||
|
|
ccfdf4853a | ||
|
|
87e5d86e90 | ||
|
|
1cf8a3e4e0 | ||
|
|
5372843281 | ||
|
|
54677b8f0b | ||
|
|
ebcf9bf6ae | ||
|
|
797514bcbf | ||
|
|
1c872ce501 | ||
|
|
479f471c14 | ||
|
|
ae0d2f2599 | ||
|
|
1e8678f11a | ||
|
|
662968559d | ||
|
|
9d895801f2 | ||
|
|
80613a40fd | ||
|
|
d43ef7f11e | ||
|
|
554e068917 | ||
|
|
567734dd6e | ||
|
|
1589499f89 | ||
|
|
682e95fa83 | ||
|
|
1ad5e7f2f0 | ||
|
|
ddb3ef4ce5 | ||
|
|
ef20b2a138 | ||
|
|
2e0f251bfd | ||
|
|
2cb91e818d | ||
|
|
2835c76336 | ||
|
|
8068a2bbc3 | ||
|
|
24111d543a | ||
|
|
7eec2b8f9a | ||
|
|
b2b70ea399 | ||
|
|
e50a3c1783 | ||
|
|
b517134309 | ||
|
|
6fb539b5bf | ||
|
|
f37fe120fd | ||
|
|
2e115acb9a | ||
|
|
27a638362d | ||
|
|
22a6695d7a | ||
|
|
57eff82ee7 | ||
|
|
7732f7d41c | ||
|
|
5ca98c326f | ||
|
|
b55db397eb | ||
|
|
c04d72ac8a | ||
|
|
28b02fb72a | ||
|
|
f3cf986777 | ||
|
|
c73fcc8898 | ||
|
|
cd9debc3b7 | ||
|
|
26a97ba997 | ||
|
|
ce19fedb08 | ||
|
|
14e8e48de2 | ||
|
|
c30faf6083 | ||
|
|
64a4f025bb | ||
|
|
6dc968e7d3 | ||
|
|
06b5b69f1e | ||
|
|
6bd3a838fc | ||
|
|
f36fea8f20 | ||
|
|
0a30591729 | ||
|
|
0ed39b6146 | ||
|
|
a8c7f80073 | ||
|
|
0293bbe142 | ||
|
|
7372656369 | ||
|
|
d46bc5dd6e | ||
|
|
86efb11572 | ||
|
|
bb01ad5290 | ||
|
|
1b8cda0941 | ||
|
|
bc85a749a3 | ||
|
|
02c35d3457 | ||
|
|
345c136cfb | ||
|
|
043e388254 | ||
|
|
fe64fc4671 | ||
|
|
6d66404506 | ||
|
|
eff94ecea8 | ||
|
|
7dfb555fea | ||
|
|
f762a669e7 | ||
|
|
0bdc7140dd | ||
|
|
8f6e955b24 | ||
|
|
1096da09da | ||
|
|
683824f1e9 | ||
|
|
db7bdefe77 | ||
|
|
e41894b071 | ||
|
|
e1ae2bcbd8 | ||
|
|
ababc3f8ec | ||
|
|
a1377afcaa | ||
|
|
a26c8f3316 | ||
|
|
88d8d7249e | ||
|
|
0eb7c9ea0c | ||
|
|
1db66c6980 | ||
|
|
c58da8fc8a | ||
|
|
448c4a835d | ||
|
|
850f80de99 |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.2.6
|
current_version = 0.3.8
|
||||||
commit = True
|
commit = True
|
||||||
message = Bump version: {current_version} → {new_version}
|
message = Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
4
.github/workflows/node.yml
vendored
4
.github/workflows/node.yml
vendored
@@ -11,6 +11,10 @@ on:
|
|||||||
- .github/workflows/node.yml
|
- .github/workflows/node.yml
|
||||||
- docker-compose.yml
|
- docker-compose.yml
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
env:
|
env:
|
||||||
# Disable full debug symbol generation to speed up CI build and keep memory down
|
# Disable full debug symbol generation to speed up CI build and keep memory down
|
||||||
# "1" means line tables only, which is useful for panic tracebacks.
|
# "1" means line tables only, which is useful for panic tracebacks.
|
||||||
|
|||||||
2
.github/workflows/npm-publish.yml
vendored
2
.github/workflows/npm-publish.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
|||||||
node/vectordb-*.tgz
|
node/vectordb-*.tgz
|
||||||
|
|
||||||
node-macos:
|
node-macos:
|
||||||
runs-on: macos-12
|
runs-on: macos-13
|
||||||
# Only runs on tags that matches the make-release action
|
# Only runs on tags that matches the make-release action
|
||||||
if: startsWith(github.ref, 'refs/tags/v')
|
if: startsWith(github.ref, 'refs/tags/v')
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
20
.github/workflows/python.yml
vendored
20
.github/workflows/python.yml
vendored
@@ -8,6 +8,11 @@ on:
|
|||||||
paths:
|
paths:
|
||||||
- python/**
|
- python/**
|
||||||
- .github/workflows/python.yml
|
- .github/workflows/python.yml
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
linux:
|
linux:
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
@@ -32,18 +37,19 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip install -e .[tests]
|
pip install -e .[tests]
|
||||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||||
pip install pytest pytest-mock black isort
|
pip install pytest pytest-mock ruff
|
||||||
- name: Black
|
- name: Lint
|
||||||
run: black --check --diff --no-color --quiet .
|
run: ruff format --check .
|
||||||
- name: isort
|
|
||||||
run: isort --check --diff --quiet .
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||||
- name: doctest
|
- name: doctest
|
||||||
run: pytest --doctest-modules lancedb
|
run: pytest --doctest-modules lancedb
|
||||||
mac:
|
mac:
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
runs-on: "macos-12"
|
strategy:
|
||||||
|
matrix:
|
||||||
|
mac-runner: [ "macos-13", "macos-13-xlarge" ]
|
||||||
|
runs-on: "${{ matrix.mac-runner }}"
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
@@ -62,8 +68,6 @@ jobs:
|
|||||||
pip install -e .[tests]
|
pip install -e .[tests]
|
||||||
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
pip install tantivy@git+https://github.com/quickwit-oss/tantivy-py#164adc87e1a033117001cf70e38c82a53014d985
|
||||||
pip install pytest pytest-mock black
|
pip install pytest pytest-mock black
|
||||||
- name: Black
|
|
||||||
run: black --check --diff --no-color --quiet .
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: pytest -m "not slow" -x -v --durations=30 tests
|
run: pytest -m "not slow" -x -v --durations=30 tests
|
||||||
pydantic1x:
|
pydantic1x:
|
||||||
|
|||||||
9
.github/workflows/rust.yml
vendored
9
.github/workflows/rust.yml
vendored
@@ -10,6 +10,10 @@ on:
|
|||||||
- rust/**
|
- rust/**
|
||||||
- .github/workflows/rust.yml
|
- .github/workflows/rust.yml
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
env:
|
env:
|
||||||
# This env var is used by Swatinem/rust-cache@v2 for the cache
|
# This env var is used by Swatinem/rust-cache@v2 for the cache
|
||||||
# key, so we set it to make sure it is always consistent.
|
# key, so we set it to make sure it is always consistent.
|
||||||
@@ -44,8 +48,11 @@ jobs:
|
|||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: cargo test --all-features
|
run: cargo test --all-features
|
||||||
macos:
|
macos:
|
||||||
runs-on: macos-12
|
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
mac-runner: [ "macos-13", "macos-13-xlarge" ]
|
||||||
|
runs-on: "${{ matrix.mac-runner }}"
|
||||||
defaults:
|
defaults:
|
||||||
run:
|
run:
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
29
Cargo.toml
29
Cargo.toml
@@ -5,21 +5,24 @@ exclude = ["python"]
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
lance = { "version" = "=0.8.1", "features" = ["dynamodb"] }
|
lance = { "version" = "=0.8.17", "features" = ["dynamodb"] }
|
||||||
lance-linalg = { "version" = "=0.8.1" }
|
lance-index = { "version" = "=0.8.17" }
|
||||||
|
lance-linalg = { "version" = "=0.8.17" }
|
||||||
|
lance-testing = { "version" = "=0.8.17" }
|
||||||
# Note that this one does not include pyarrow
|
# Note that this one does not include pyarrow
|
||||||
arrow = { version = "43.0.0", optional = false }
|
arrow = { version = "47.0.0", optional = false }
|
||||||
arrow-array = "43.0"
|
arrow-array = "47.0"
|
||||||
arrow-data = "43.0"
|
arrow-data = "47.0"
|
||||||
arrow-ipc = "43.0"
|
arrow-ipc = "47.0"
|
||||||
arrow-ord = "43.0"
|
arrow-ord = "47.0"
|
||||||
arrow-schema = "43.0"
|
arrow-schema = "47.0"
|
||||||
arrow-arith = "43.0"
|
arrow-arith = "47.0"
|
||||||
arrow-cast = "43.0"
|
arrow-cast = "47.0"
|
||||||
half = { "version" = "=2.2.1", default-features = false, features = [
|
chrono = "0.4.23"
|
||||||
"num-traits"
|
half = { "version" = "=2.3.1", default-features = false, features = [
|
||||||
|
"num-traits",
|
||||||
] }
|
] }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
object_store = "0.6.1"
|
object_store = "0.7.1"
|
||||||
snafu = "0.7.4"
|
snafu = "0.7.4"
|
||||||
url = "2"
|
url = "2"
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ The key features of LanceDB include:
|
|||||||
|
|
||||||
* Zero-copy, automatic versioning, manage versions of your data without needing extra infrastructure.
|
* Zero-copy, automatic versioning, manage versions of your data without needing extra infrastructure.
|
||||||
|
|
||||||
|
* GPU support in building vector index(*).
|
||||||
|
|
||||||
* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
|
* Ecosystem integrations with [LangChain 🦜️🔗](https://python.langchain.com/en/latest/modules/indexes/vectorstores/examples/lanecdb.html), [LlamaIndex 🦙](https://gpt-index.readthedocs.io/en/latest/examples/vector_stores/LanceDBIndexDemo.html), Apache-Arrow, Pandas, Polars, DuckDB and more on the way.
|
||||||
|
|
||||||
LanceDB's core is written in Rust 🦀 and is built using <a href="https://github.com/lancedb/lance">Lance</a>, an open-source columnar format designed for performant ML workloads.
|
LanceDB's core is written in Rust 🦀 and is built using <a href="https://github.com/lancedb/lance">Lance</a>, an open-source columnar format designed for performant ML workloads.
|
||||||
@@ -52,8 +54,7 @@ const table = await db.createTable('vectors',
|
|||||||
[{ id: 1, vector: [0.1, 0.2], item: "foo", price: 10 },
|
[{ id: 1, vector: [0.1, 0.2], item: "foo", price: 10 },
|
||||||
{ id: 2, vector: [1.1, 1.2], item: "bar", price: 50 }])
|
{ id: 2, vector: [1.1, 1.2], item: "bar", price: 50 }])
|
||||||
|
|
||||||
const query = table.search([0.1, 0.3]);
|
const query = table.search([0.1, 0.3]).limit(2);
|
||||||
query.limit = 20;
|
|
||||||
const results = await query.execute();
|
const results = await query.execute();
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -70,7 +71,7 @@ db = lancedb.connect(uri)
|
|||||||
table = db.create_table("my_table",
|
table = db.create_table("my_table",
|
||||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
|
||||||
result = table.search([100, 100]).limit(2).to_df()
|
result = table.search([100, 100]).limit(2).to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
## Blogs, Tutorials & Videos
|
## Blogs, Tutorials & Videos
|
||||||
|
|||||||
26
docs/README.md
Normal file
26
docs/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# LanceDB Documentation
|
||||||
|
|
||||||
|
LanceDB docs are deployed to https://lancedb.github.io/lancedb/.
|
||||||
|
|
||||||
|
Docs is built and deployed automatically by [Github Actions](.github/workflows/docs.yml)
|
||||||
|
whenever a commit is pushed to the `main` branch. So it is possible for the docs to show
|
||||||
|
unreleased features.
|
||||||
|
|
||||||
|
## Building the docs
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
1. Install LanceDB. From LanceDB repo root: `pip install -e python`
|
||||||
|
2. Install dependencies. From LanceDB repo root: `pip install -r docs/requirements.txt`
|
||||||
|
3. Make sure you have node and npm setup
|
||||||
|
4. Make sure protobuf and libssl are installed
|
||||||
|
|
||||||
|
### Building node module and create markdown files
|
||||||
|
|
||||||
|
See [Javascript docs README](docs/src/javascript/README.md)
|
||||||
|
|
||||||
|
### Build docs
|
||||||
|
From LanceDB repo root:
|
||||||
|
|
||||||
|
Run: `PYTHONPATH=. mkdocs build -f docs/mkdocs.yml`
|
||||||
|
|
||||||
|
If successful, you should see a `docs/site` directory that you can verify locally.
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
site_name: LanceDB Docs
|
site_name: LanceDB Docs
|
||||||
|
site_url: https://lancedb.github.io/lancedb/
|
||||||
repo_url: https://github.com/lancedb/lancedb
|
repo_url: https://github.com/lancedb/lancedb
|
||||||
edit_uri: https://github.com/lancedb/lancedb/tree/main/docs/src
|
edit_uri: https://github.com/lancedb/lancedb/tree/main/docs/src
|
||||||
repo_name: lancedb/lancedb
|
repo_name: lancedb/lancedb
|
||||||
@@ -21,6 +22,7 @@ theme:
|
|||||||
- navigation.tracking
|
- navigation.tracking
|
||||||
- navigation.instant
|
- navigation.instant
|
||||||
- navigation.indexes
|
- navigation.indexes
|
||||||
|
- navigation.expand
|
||||||
icon:
|
icon:
|
||||||
repo: fontawesome/brands/github
|
repo: fontawesome/brands/github
|
||||||
custom_dir: overrides
|
custom_dir: overrides
|
||||||
@@ -36,7 +38,7 @@ plugins:
|
|||||||
docstring_style: numpy
|
docstring_style: numpy
|
||||||
rendering:
|
rendering:
|
||||||
heading_level: 4
|
heading_level: 4
|
||||||
show_source: false
|
show_source: true
|
||||||
show_symbol_type_in_heading: true
|
show_symbol_type_in_heading: true
|
||||||
show_signature_annotations: true
|
show_signature_annotations: true
|
||||||
show_root_heading: true
|
show_root_heading: true
|
||||||
@@ -68,11 +70,18 @@ nav:
|
|||||||
- 🏢 Home: index.md
|
- 🏢 Home: index.md
|
||||||
- 💡 Basics: basic.md
|
- 💡 Basics: basic.md
|
||||||
- 📚 Guides:
|
- 📚 Guides:
|
||||||
- Tables: guides/tables.md
|
- Create Ingest Update Delete: guides/tables.md
|
||||||
- Vector Search: search.md
|
- Vector Search: search.md
|
||||||
- SQL filters: sql.md
|
- SQL filters: sql.md
|
||||||
- Indexing: ann_indexes.md
|
- Indexing: ann_indexes.md
|
||||||
- 🧬 Embeddings: embedding.md
|
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||||
|
- 🧬 Embeddings:
|
||||||
|
- embeddings/index.md
|
||||||
|
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
||||||
|
- Available Functions: embeddings/default_embedding_functions.md
|
||||||
|
- Create Custom Embedding Functions: embeddings/api.md
|
||||||
|
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
|
||||||
|
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||||
- 🔍 Python full-text search: fts.md
|
- 🔍 Python full-text search: fts.md
|
||||||
- 🔌 Integrations:
|
- 🔌 Integrations:
|
||||||
- integrations/index.md
|
- integrations/index.md
|
||||||
@@ -96,13 +105,22 @@ nav:
|
|||||||
- Serverless Website Chatbot: examples/serverless_website_chatbot.md
|
- Serverless Website Chatbot: examples/serverless_website_chatbot.md
|
||||||
- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
|
- YouTube Transcript Search: examples/youtube_transcript_bot_with_nodejs.md
|
||||||
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
|
- TransformersJS Embedding Search: examples/transformerjs_embedding_search_nodejs.md
|
||||||
|
- ⚙️ CLI & Config: cli_config.md
|
||||||
|
|
||||||
- Basics: basic.md
|
- Basics: basic.md
|
||||||
- Guides:
|
- Guides:
|
||||||
- Tables: guides/tables.md
|
- Create Ingest Update Delete: guides/tables.md
|
||||||
- Vector Search: search.md
|
- Vector Search: search.md
|
||||||
- SQL filters: sql.md
|
- SQL filters: sql.md
|
||||||
- Indexing: ann_indexes.md
|
- Indexing: ann_indexes.md
|
||||||
- Embeddings: embedding.md
|
- Versioning & Reproducibility: notebooks/reproducibility.ipynb
|
||||||
|
- Embeddings:
|
||||||
|
- embeddings/index.md
|
||||||
|
- Ingest Embedding Functions: embeddings/embedding_functions.md
|
||||||
|
- Available Functions: embeddings/default_embedding_functions.md
|
||||||
|
- Create Custom Embedding Functions: embeddings/api.md
|
||||||
|
- Example - Multi-lingual semantic search: notebooks/multi_lingual_example.ipynb
|
||||||
|
- Example - MultiModal CLIP Embeddings: notebooks/DisappearingEmbeddingFunction.ipynb
|
||||||
- Python full-text search: fts.md
|
- Python full-text search: fts.md
|
||||||
- Integrations:
|
- Integrations:
|
||||||
- integrations/index.md
|
- integrations/index.md
|
||||||
|
|||||||
@@ -68,6 +68,44 @@ a single PQ code.
|
|||||||
<figcaption>IVF_PQ index with <code>num_partitions=2, num_sub_vectors=4</code></figcaption>
|
<figcaption>IVF_PQ index with <code>num_partitions=2, num_sub_vectors=4</code></figcaption>
|
||||||
</figure>
|
</figure>
|
||||||
|
|
||||||
|
### Use GPU to build vector index
|
||||||
|
|
||||||
|
Lance Python SDK has experimental GPU support for creating IVF index.
|
||||||
|
Using GPU for index creation requires [PyTorch>2.0](https://pytorch.org/) being installed.
|
||||||
|
|
||||||
|
You can specify the GPU device to train IVF partitions via
|
||||||
|
|
||||||
|
- **accelerator**: Specify to ``cuda`` or ``mps`` (on Apple Silicon) to enable GPU training.
|
||||||
|
|
||||||
|
=== "Linux"
|
||||||
|
|
||||||
|
<!-- skip-test -->
|
||||||
|
``` { .python .copy }
|
||||||
|
# Create index using CUDA on Nvidia GPUs.
|
||||||
|
tbl.create_index(
|
||||||
|
num_partitions=256,
|
||||||
|
num_sub_vectors=96,
|
||||||
|
accelerator="cuda"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "Macos"
|
||||||
|
|
||||||
|
<!-- skip-test -->
|
||||||
|
```python
|
||||||
|
# Create index using MPS on Apple Silicon.
|
||||||
|
tbl.create_index(
|
||||||
|
num_partitions=256,
|
||||||
|
num_sub_vectors=96,
|
||||||
|
accelerator="mps"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Trouble shootings:
|
||||||
|
|
||||||
|
If you see ``AssertionError: Torch not compiled with CUDA enabled``, you need to [install
|
||||||
|
PyTorch with CUDA support](https://pytorch.org/get-started/locally/).
|
||||||
|
|
||||||
|
|
||||||
## Querying an ANN Index
|
## Querying an ANN Index
|
||||||
|
|
||||||
@@ -91,7 +129,7 @@ There are a couple of parameters that can be used to fine-tune the search:
|
|||||||
.limit(2) \
|
.limit(2) \
|
||||||
.nprobes(20) \
|
.nprobes(20) \
|
||||||
.refine_factor(10) \
|
.refine_factor(10) \
|
||||||
.to_df()
|
.to_pandas()
|
||||||
```
|
```
|
||||||
```
|
```
|
||||||
vector item _distance
|
vector item _distance
|
||||||
@@ -118,7 +156,7 @@ You can further filter the elements returned by a search using a where clause.
|
|||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
```python
|
```python
|
||||||
tbl.search(np.random.random((1536))).where("item != 'item 1141'").to_df()
|
tbl.search(np.random.random((1536))).where("item != 'item 1141'").to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Javascript"
|
||||||
@@ -135,7 +173,7 @@ You can select the columns returned by the query using a select clause.
|
|||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
```python
|
```python
|
||||||
tbl.search(np.random.random((1536))).select(["vector"]).to_df()
|
tbl.search(np.random.random((1536))).select(["vector"]).to_pandas()
|
||||||
```
|
```
|
||||||
```
|
```
|
||||||
vector _distance
|
vector _distance
|
||||||
|
|||||||
BIN
docs/src/assets/dog_clip_output.png
Normal file
BIN
docs/src/assets/dog_clip_output.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 342 KiB |
BIN
docs/src/assets/embedding_intro.png
Normal file
BIN
docs/src/assets/embedding_intro.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 245 KiB |
BIN
docs/src/assets/embeddings_api.png
Normal file
BIN
docs/src/assets/embeddings_api.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 83 KiB |
@@ -146,7 +146,7 @@ Once you've embedded the query, you can find its nearest neighbors using the fol
|
|||||||
|
|
||||||
=== "Python"
|
=== "Python"
|
||||||
```python
|
```python
|
||||||
tbl.search([100, 100]).limit(2).to_df()
|
tbl.search([100, 100]).limit(2).to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
This returns a pandas DataFrame with the results.
|
This returns a pandas DataFrame with the results.
|
||||||
|
|||||||
37
docs/src/cli_config.md
Normal file
37
docs/src/cli_config.md
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
|
||||||
|
## LanceDB CLI
|
||||||
|
Once lanceDB is installed, you can access the CLI using `lancedb` command on the console
|
||||||
|
```
|
||||||
|
lancedb
|
||||||
|
```
|
||||||
|
This lists out all the various command-line options available. You can get the usage or help for a particular command
|
||||||
|
```
|
||||||
|
lancedb {command} --help
|
||||||
|
```
|
||||||
|
|
||||||
|
## LanceDB config
|
||||||
|
LanceDB uses a global config file to store certain settings. These settings are configurable using the lanceDB cli.
|
||||||
|
To view your config settings, you can use:
|
||||||
|
```
|
||||||
|
lancedb config
|
||||||
|
```
|
||||||
|
These config parameters can be tuned using the cli.
|
||||||
|
```
|
||||||
|
lancedb {config_name} --{argument}
|
||||||
|
```
|
||||||
|
|
||||||
|
## LanceDB Opt-in Diagnostics
|
||||||
|
When enabled, LanceDB will send anonymous events to help us improve LanceDB. These diagnostics are used only for error reporting and no data is collected. Error & stats allow us to automate certain aspects of bug reporting, prioritization of fixes and feature requests.
|
||||||
|
These diagnostics are opt-in and can be enabled or disabled using the `lancedb diagnostics` command. These are enabled by default.
|
||||||
|
Get usage help.
|
||||||
|
```
|
||||||
|
lancedb diagnostics --help
|
||||||
|
```
|
||||||
|
Disable diagnostics
|
||||||
|
```
|
||||||
|
lancedb diagnostics --disabled
|
||||||
|
```
|
||||||
|
Enable diagnostics
|
||||||
|
```
|
||||||
|
lancedb diagnostics --enabled
|
||||||
|
```
|
||||||
213
docs/src/embeddings/api.md
Normal file
213
docs/src/embeddings/api.md
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
To use your own custom embedding function, you need to follow these 2 simple steps.
|
||||||
|
1. Create your embedding function by implementing the `EmbeddingFunction` interface
|
||||||
|
2. Register your embedding function in the global `EmbeddingFunctionRegistry`.
|
||||||
|
|
||||||
|
Let us see how this looks like in action.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
`EmbeddingFunction` & `EmbeddingFunctionRegistry` handle low-level details for serializing schema and model information as metadata. To build a custom embdding function, you don't need to worry about those details and simply focus on setting up the model.
|
||||||
|
|
||||||
|
## `TextEmbeddingFunction` Interface
|
||||||
|
|
||||||
|
There is another optional layer of abstraction provided in form of `TextEmbeddingFunction`. You can use this if your model isn't multi-modal in nature and only operates on text. In such case both source and vector fields will have the same pathway for vectorization, so you simply just need to setup the model and rest is handled by `TextEmbeddingFunction`. You can read more about the class and its attributes in the class reference.
|
||||||
|
|
||||||
|
|
||||||
|
Let's implement `SentenceTransformerEmbeddings` class. All you need to do is implement the `generate_embeddings()` and `ndims` function to handle the input types you expect and register the class in the global `EmbeddingFunctionRegistry`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.embeddings import register
|
||||||
|
|
||||||
|
@register("sentence-transformers")
|
||||||
|
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||||
|
name: str = "all-MiniLM-L6-v2"
|
||||||
|
# set more default instance vars like device, etc.
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._ndims = None
|
||||||
|
|
||||||
|
def generate_embeddings(self, texts):
|
||||||
|
return self._embedding_model().encode(list(texts), ...).tolist()
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self._ndims is None:
|
||||||
|
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
|
@cached(cache={})
|
||||||
|
def _embedding_model(self):
|
||||||
|
return sentence_transformers.SentenceTransformer(name)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
This is a stripped down version of our implementation of `SentenceTransformerEmbeddings` that removes certain optimizations and defaul settings.
|
||||||
|
|
||||||
|
Now you can use this embedding function to create your table schema and that's it! you can then ingest data and run queries without manually vectorizing the inputs.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
stransformer = registry.get("sentence-transformers").create()
|
||||||
|
|
||||||
|
class TextModelSchema(LanceModel):
|
||||||
|
vector: Vector(stransformer.ndims) = stransformer.VectorField()
|
||||||
|
text: str = stransformer.SourceField()
|
||||||
|
|
||||||
|
tbl = db.create_table("table", schema=TextModelSchema)
|
||||||
|
|
||||||
|
tbl.add(pd.DataFrame({"text": ["halo", "world"]}))
|
||||||
|
result = tbl.search("world").limit(5)
|
||||||
|
```
|
||||||
|
|
||||||
|
NOTE:
|
||||||
|
|
||||||
|
You can always implement the `EmbeddingFunction` interface directly if you want or need to, `TextEmbeddingFunction` just makes it much simpler and faster for you to do so, by setting up the boiler plat for text-specific use case
|
||||||
|
|
||||||
|
## Multi-modal embedding function example
|
||||||
|
You can also use the `EmbeddingFunction` interface to implement more complex workflows such as multi-modal embedding function support. LanceDB implements `OpenClipEmeddingFunction` class that suppports multi-modal seach. Here's the implementation that you can use as a reference to build your own multi-modal embedding functions.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@register("open-clip")
|
||||||
|
class OpenClipEmbeddings(EmbeddingFunction):
|
||||||
|
name: str = "ViT-B-32"
|
||||||
|
pretrained: str = "laion2b_s34b_b79k"
|
||||||
|
device: str = "cpu"
|
||||||
|
batch_size: int = 64
|
||||||
|
normalize: bool = True
|
||||||
|
_model = PrivateAttr()
|
||||||
|
_preprocess = PrivateAttr()
|
||||||
|
_tokenizer = PrivateAttr()
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
open_clip = self.safe_import("open_clip", "open-clip") # EmbeddingFunction util to import external libs and raise if not found
|
||||||
|
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||||
|
self.name, pretrained=self.pretrained
|
||||||
|
)
|
||||||
|
model.to(self.device)
|
||||||
|
self._model, self._preprocess = model, preprocess
|
||||||
|
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||||
|
self._ndims = None
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self._ndims is None:
|
||||||
|
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
|
def compute_query_embeddings(
|
||||||
|
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||||
|
) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : Union[str, PIL.Image.Image]
|
||||||
|
The query to embed. A query can be either text or an image.
|
||||||
|
"""
|
||||||
|
if isinstance(query, str):
|
||||||
|
return [self.generate_text_embeddings(query)]
|
||||||
|
else:
|
||||||
|
PIL = self.safe_import("PIL", "pillow")
|
||||||
|
if isinstance(query, PIL.Image.Image):
|
||||||
|
return [self.generate_image_embedding(query)]
|
||||||
|
else:
|
||||||
|
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||||
|
|
||||||
|
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||||
|
torch = self.safe_import("torch")
|
||||||
|
text = self.sanitize_input(text)
|
||||||
|
text = self._tokenizer(text)
|
||||||
|
text.to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
text_features = self._model.encode_text(text.to(self.device))
|
||||||
|
if self.normalize:
|
||||||
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
return text_features.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||||
|
"""
|
||||||
|
Sanitize the input to the embedding function.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, bytes)):
|
||||||
|
images = [images]
|
||||||
|
elif isinstance(images, pa.Array):
|
||||||
|
images = images.to_pylist()
|
||||||
|
elif isinstance(images, pa.ChunkedArray):
|
||||||
|
images = images.combine_chunks().to_pylist()
|
||||||
|
return images
|
||||||
|
|
||||||
|
def compute_source_embeddings(
|
||||||
|
self, images: IMAGES, *args, **kwargs
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given images
|
||||||
|
"""
|
||||||
|
images = self.sanitize_input(images)
|
||||||
|
embeddings = []
|
||||||
|
for i in range(0, len(images), self.batch_size):
|
||||||
|
j = min(i + self.batch_size, len(images))
|
||||||
|
batch = images[i:j]
|
||||||
|
embeddings.extend(self._parallel_get(batch))
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Issue concurrent requests to retrieve the image data
|
||||||
|
"""
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(self.generate_image_embedding, image)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
return [future.result() for future in futures]
|
||||||
|
|
||||||
|
def generate_image_embedding(
|
||||||
|
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate the embedding for a single image
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
image : Union[str, bytes, PIL.Image.Image]
|
||||||
|
The image to embed. If the image is a str, it is treated as a uri.
|
||||||
|
If the image is bytes, it is treated as the raw image bytes.
|
||||||
|
"""
|
||||||
|
torch = self.safe_import("torch")
|
||||||
|
# TODO handle retry and errors for https
|
||||||
|
image = self._to_pil(image)
|
||||||
|
image = self._preprocess(image).unsqueeze(0)
|
||||||
|
with torch.no_grad():
|
||||||
|
return self._encode_and_normalize_image(image)
|
||||||
|
|
||||||
|
def _to_pil(self, image: Union[str, bytes]):
|
||||||
|
PIL = self.safe_import("PIL", "pillow")
|
||||||
|
if isinstance(image, bytes):
|
||||||
|
return PIL.Image.open(io.BytesIO(image))
|
||||||
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
return image
|
||||||
|
elif isinstance(image, str):
|
||||||
|
parsed = urlparse.urlparse(image)
|
||||||
|
# TODO handle drive letter on windows.
|
||||||
|
if parsed.scheme == "file":
|
||||||
|
return PIL.Image.open(parsed.path)
|
||||||
|
elif parsed.scheme == "":
|
||||||
|
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||||
|
elif parsed.scheme.startswith("http"):
|
||||||
|
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||||
|
|
||||||
|
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||||
|
"""
|
||||||
|
encode a single image tensor and optionally normalize the output
|
||||||
|
"""
|
||||||
|
image_features = self._model.encode_image(image_tensor)
|
||||||
|
if self.normalize:
|
||||||
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
return image_features.cpu().numpy().squeeze()
|
||||||
|
```
|
||||||
208
docs/src/embeddings/default_embedding_functions.md
Normal file
208
docs/src/embeddings/default_embedding_functions.md
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
There are various Embedding functions available out of the box with lancedb. We're working on supporting other popular embedding APIs.
|
||||||
|
|
||||||
|
## Text Embedding Functions
|
||||||
|
Here are the text embedding functions registered by default.
|
||||||
|
Embedding functions have an inbuilt rate limit handler wrapper for source and query embedding function calls that retries with exponential backoff.
|
||||||
|
Each `EmbeddingFunction` implementation automatically takes `max_retries` as an argument, which has the default value of 7.
|
||||||
|
|
||||||
|
### Sentence Transformers
|
||||||
|
Here are the parameters that you can set when registering a `sentence-transformers` object, and their default values:
|
||||||
|
|
||||||
|
| Parameter | Type | Default Value | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `name` | `str` | `"all-MiniLM-L6-v2"` | The name of the model. |
|
||||||
|
| `device` | `str` | `"cpu"` | The device to run the model on. Can be `"cpu"` or `"gpu"`. |
|
||||||
|
| `normalize` | `bool` | `True` | Whether to normalize the input text before feeding it to the model. |
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
db = lancedb.connect("/tmp/db")
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
func = registry.get("sentence-transformers").create(device="cpu")
|
||||||
|
|
||||||
|
class Words(LanceModel):
|
||||||
|
text: str = func.SourceField()
|
||||||
|
vector: Vector(func.ndims()) = func.VectorField()
|
||||||
|
|
||||||
|
table = db.create_table("words", schema=Words)
|
||||||
|
table.add(
|
||||||
|
[
|
||||||
|
{"text": "hello world"}
|
||||||
|
{"text": "goodbye world"}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
query = "greetings"
|
||||||
|
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||||
|
print(actual.text)
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAIEmbeddings
|
||||||
|
LanceDB has OpenAI embeddings function in the registry by default. It is registered as `openai` and here are the parameters that you can customize when creating the instances
|
||||||
|
|
||||||
|
| Parameter | Type | Default Value | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `name` | `str` | `"text-embedding-ada-002"` | The name of the model. |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
db = lancedb.connect("/tmp/db")
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
func = registry.get("openai").create()
|
||||||
|
|
||||||
|
class Words(LanceModel):
|
||||||
|
text: str = func.SourceField()
|
||||||
|
vector: Vector(func.ndims()) = func.VectorField()
|
||||||
|
|
||||||
|
table = db.create_table("words", schema=Words)
|
||||||
|
table.add(
|
||||||
|
[
|
||||||
|
{"text": "hello world"}
|
||||||
|
{"text": "goodbye world"}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
query = "greetings"
|
||||||
|
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||||
|
print(actual.text)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Instructor Embeddings
|
||||||
|
Instructor is an instruction-finetuned text embedding model that can generate text embeddings tailored to any task (e.g., classification, retrieval, clustering, text evaluation, etc.) and domains (e.g., science, finance, etc.) by simply providing the task instruction, without any finetuning
|
||||||
|
|
||||||
|
If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
|
||||||
|
|
||||||
|
Represent the `domain` `text_type` for `task_objective`:
|
||||||
|
|
||||||
|
* `domain` is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
|
||||||
|
* `text_type` is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
|
||||||
|
* `task_objective` is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.
|
||||||
|
|
||||||
|
More information about the model can be found here - https://github.com/xlang-ai/instructor-embedding
|
||||||
|
|
||||||
|
| Argument | Type | Default | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `name` | `str` | "hkunlp/instructor-base" | The name of the model to use |
|
||||||
|
| `batch_size` | `int` | `32` | The batch size to use when generating embeddings |
|
||||||
|
| `device` | `str` | `"cpu"` | The device to use when generating embeddings |
|
||||||
|
| `show_progress_bar` | `bool` | `True` | Whether to show a progress bar when generating embeddings |
|
||||||
|
| `normalize_embeddings` | `bool` | `True` | Whether to normalize the embeddings |
|
||||||
|
| `quantize` | `bool` | `False` | Whether to quantize the model |
|
||||||
|
| `source_instruction` | `str` | `"represent the docuement for retreival"` | The instruction for the source column |
|
||||||
|
| `query_instruction` | `str` | `"represent the document for retreiving the most similar documents"` | The instruction for the query |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
import lancedb
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction
|
||||||
|
|
||||||
|
instructor = get_registry().get("instructor").create(
|
||||||
|
source_instruction="represent the docuement for retreival",
|
||||||
|
query_instruction="represent the document for retreiving the most similar documents"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Schema(LanceModel):
|
||||||
|
vector: Vector(instructor.ndims()) = instructor.VectorField()
|
||||||
|
text: str = instructor.SourceField()
|
||||||
|
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||||
|
|
||||||
|
texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
|
||||||
|
{"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
|
||||||
|
{"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]
|
||||||
|
|
||||||
|
tbl.add(texts)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-modal embedding functions
|
||||||
|
Multi-modal embedding functions allow you to query your table using both images and text.
|
||||||
|
|
||||||
|
### OpenClipEmbeddings
|
||||||
|
We support CLIP model embeddings using the open source alternative, open-clip, which supports various customizations. It is registered as `open-clip` and supports the following customizations.
|
||||||
|
|
||||||
|
|
||||||
|
| Parameter | Type | Default Value | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `name` | `str` | `"ViT-B-32"` | The name of the model. |
|
||||||
|
| `pretrained` | `str` | `"laion2b_s34b_b79k"` | The name of the pretrained model to load. |
|
||||||
|
| `device` | `str` | `"cpu"` | The device to run the model on. Can be `"cpu"` or `"gpu"`. |
|
||||||
|
| `batch_size` | `int` | `64` | The number of images to process in a batch. |
|
||||||
|
| `normalize` | `bool` | `True` | Whether to normalize the input images before feeding them to the model. |
|
||||||
|
|
||||||
|
|
||||||
|
This embedding function supports ingesting images as both bytes and urls. You can query them using both text and other images.
|
||||||
|
|
||||||
|
NOTE:
|
||||||
|
LanceDB supports ingesting images directly from accessible links.
|
||||||
|
|
||||||
|
|
||||||
|
```python
|
||||||
|
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
func = registry.get("open-clip").create()
|
||||||
|
|
||||||
|
class Images(LanceModel):
|
||||||
|
label: str
|
||||||
|
image_uri: str = func.SourceField() # image uri as the source
|
||||||
|
image_bytes: bytes = func.SourceField() # image bytes as the source
|
||||||
|
vector: Vector(func.ndims()) = func.VectorField() # vector column
|
||||||
|
vec_from_bytes: Vector(func.ndims()) = func.VectorField() # Another vector column
|
||||||
|
|
||||||
|
table = db.create_table("images", schema=Images)
|
||||||
|
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
|
||||||
|
uris = [
|
||||||
|
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
|
||||||
|
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
|
||||||
|
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
|
||||||
|
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
|
||||||
|
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
|
||||||
|
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
|
||||||
|
]
|
||||||
|
# get each uri as bytes
|
||||||
|
image_bytes = [requests.get(uri).content for uri in uris]
|
||||||
|
table.add(
|
||||||
|
[{"label": labels, "image_uri": uris, "image_bytes": image_bytes}]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
Now we can search using text from both the default vector column and the custom vector column
|
||||||
|
```python
|
||||||
|
|
||||||
|
# text search
|
||||||
|
actual = table.search("man's best friend").limit(1).to_pydantic(Images)[0]
|
||||||
|
print(actual.label) # prints "dog"
|
||||||
|
|
||||||
|
frombytes = (
|
||||||
|
table.search("man's best friend", vector_column_name="vec_from_bytes")
|
||||||
|
.limit(1)
|
||||||
|
.to_pydantic(Images)[0]
|
||||||
|
)
|
||||||
|
print(frombytes.label)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Because we're using a multi-modal embedding function, we can also search using images
|
||||||
|
|
||||||
|
```python
|
||||||
|
# image search
|
||||||
|
query_image_uri = "http://farm1.staticflickr.com/200/467715466_ed4a31801f_z.jpg"
|
||||||
|
image_bytes = requests.get(query_image_uri).content
|
||||||
|
query_image = Image.open(io.BytesIO(image_bytes))
|
||||||
|
actual = table.search(query_image).limit(1).to_pydantic(Images)[0]
|
||||||
|
print(actual.label == "dog")
|
||||||
|
|
||||||
|
# image search using a custom vector column
|
||||||
|
other = (
|
||||||
|
table.search(query_image, vector_column_name="vec_from_bytes")
|
||||||
|
.limit(1)
|
||||||
|
.to_pydantic(Images)[0]
|
||||||
|
)
|
||||||
|
print(actual.label)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
If you have any questions about the embeddings API, supported models, or see a relevant model missing, please raise an issue.
|
||||||
95
docs/src/embeddings/embedding_functions.md
Normal file
95
docs/src/embeddings/embedding_functions.md
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
Representing multi-modal data as vector embeddings is becoming a standard practice. Embedding functions can themselves be thought of as a part of the processing pipeline that each request (input) has to be passed through. After initial setup these components are not expected to change for a particular project.
|
||||||
|
|
||||||
|
This is the main motivation behind our new embedding functions API, which allows you to simply set it up once so the table remembers it, effectively making the **embedding functions disappear in the background** so you don't have to worry about modelling and can simply focus on the DB aspects of VectorDB.
|
||||||
|
|
||||||
|
|
||||||
|
You can simply follow these steps and forget about the details of your embedding functions as long as you don't intend to change it.
|
||||||
|
|
||||||
|
### Step 1 - Define the embedding function
|
||||||
|
We have some pre-defined embedding functions in the global registry, with more coming soon. Let's look at an implementation of CLIP as an example.
|
||||||
|
```
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
clip = registry.get("open-clip").create()
|
||||||
|
|
||||||
|
```
|
||||||
|
You can also define your own embedding function by implementing the `EmbeddingFunction` abstract base interface. It subclasses the Pydantic Model, which can be utilized to write complex schemas simply, as we'll see next!
|
||||||
|
|
||||||
|
### Step 2 - Define the Data Model or Schema
|
||||||
|
Our embedding function from the previous section abstracts away all the details about the models and dimensions required to define the schema. You can simply set a field as **source** or **vector** column. Here's how
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Pets(LanceModel):
|
||||||
|
vector: Vector(clip.ndims) = clip.VectorField()
|
||||||
|
image_uri: str = clip.SourceField()
|
||||||
|
|
||||||
|
```
|
||||||
|
`VectorField` tells LanceDB to use the clip embedding function to generate query embeddings for `vector` column & `SourceField` tells that when adding data, automatically use the embedding function to encode `image_uri`.
|
||||||
|
|
||||||
|
|
||||||
|
### Step 3 - Create LanceDB Table
|
||||||
|
Now that we have chosen/defined our embedding function and the schema, we can create the table
|
||||||
|
|
||||||
|
```python
|
||||||
|
db = lancedb.connect("~/lancedb")
|
||||||
|
table = db.create_table("pets", schema=Pets)
|
||||||
|
|
||||||
|
```
|
||||||
|
That's it! We have ingested all the information needed to embed source and query inputs. We can now forget about the model and dimension details and start to build our VectorDB.
|
||||||
|
|
||||||
|
### Step 4 - Ingest lots of data and run vector search!
|
||||||
|
Now you can just add the data and it'll be vectorized automatically
|
||||||
|
|
||||||
|
```python
|
||||||
|
table.add([{"image_uri": u} for u in uris])
|
||||||
|
```
|
||||||
|
|
||||||
|
Our OpenCLIP query embedding function supports querying via both text and images.
|
||||||
|
|
||||||
|
```python
|
||||||
|
result = table.search("dog")
|
||||||
|
```
|
||||||
|
|
||||||
|
Let's query an image
|
||||||
|
|
||||||
|
```python
|
||||||
|
p = Path("path/to/images/samoyed_100.jpg")
|
||||||
|
query_image = Image.open(p)
|
||||||
|
table.search(query_image)
|
||||||
|
|
||||||
|
```
|
||||||
|
### Rate limit Handling
|
||||||
|
The `EmbeddingFunction` class wraps the calls for source and query embedding generation inside a rate limit handler that retries the requests with exponential backoff after successive failures. By default the maximum number of retries is set to 7. You can tune it by setting it to a different number, or disable it by setting it to 0.
|
||||||
|
Example
|
||||||
|
----
|
||||||
|
|
||||||
|
```python
|
||||||
|
clip = registry.get("open-clip").create() # Defaults to 7 max retries
|
||||||
|
clip = registry.get("open-clip").create(max_retries=10) # Increase max retries to 10
|
||||||
|
clip = registry.get("open-clip").create(max_retries=0) # Retries disabled
|
||||||
|
```
|
||||||
|
|
||||||
|
NOTE:
|
||||||
|
Embedding functions can also fail due to other errors that have nothing to do with rate limits. This is why the error is also logged.
|
||||||
|
|
||||||
|
### A little fun with PyDantic
|
||||||
|
LanceDB is integrated with Pydantic. In fact, we've used the integration in the above example to define the schema. It is also used behind the scenes by the embedding function API to ingest useful information as table metadata.
|
||||||
|
You can also use it for adding utility operations in the schema. For example, in our multi-modal example, you can search images using text or another image. Let us define a utility function to plot the image.
|
||||||
|
```python
|
||||||
|
class Pets(LanceModel):
|
||||||
|
vector: Vector(clip.ndims) = clip.VectorField()
|
||||||
|
image_uri: str = clip.SourceField()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image(self):
|
||||||
|
return Image.open(self.image_uri)
|
||||||
|
```
|
||||||
|
Now, you can convert your search results to a Pydantic model and use this property.
|
||||||
|
|
||||||
|
```python
|
||||||
|
rs = table.search(query_image).limit(3).to_pydantic(Pets)
|
||||||
|
rs[2].image
|
||||||
|
```
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
Now that you have the basic idea about LanceDB embedding functions, let us now dive deeper into the API that you can use to implement your own embedding functions!
|
||||||
@@ -1,13 +1,20 @@
|
|||||||
# Embedding Functions
|
# Embedding
|
||||||
|
|
||||||
Embeddings are high dimensional floating-point vector representations of your data or query.
|
Embeddings are high dimensional floating-point vector representations of your data or query. Anything can be embedded using some embedding model or function. Position of embedding in a high dimensional vector space has semantic significance to a degree that depends on the type of modal and training. These embeddings when projected in a 2-D space generally group similar entities close-by forming groups.
|
||||||
Anything can be embedded using some embedding model or function.
|
|
||||||
For a given embedding function, the output will always have the same number of dimensions.
|
|
||||||
|
|
||||||
## Creating an embedding function
|

|
||||||
|
|
||||||
Any function that takes as input a batch (list) of data and outputs a batch (list) of embeddings
|
# Creating an embedding function
|
||||||
can be used by LanceDB as an embedding function. The input and output batch sizes should be the same.
|
|
||||||
|
LanceDB supports two major ways of vectorizing your data: explicit and implicit.
|
||||||
|
|
||||||
|
1. By manually embedding the data before ingesting in the table
|
||||||
|
2. By automatically embedding the data and query as they come, by ingesting embedding function information in the table itself! Covered in [Next Section](embedding_functions.md)
|
||||||
|
|
||||||
|
Whatever workflow you prefer, we have the tools to support you.
|
||||||
|
## Explicit Vectorization
|
||||||
|
|
||||||
|
In this workflow, you can create your embedding function and vectorize your data using lancedb's `with_embedding` function. Let's look at some examples.
|
||||||
|
|
||||||
### HuggingFace example
|
### HuggingFace example
|
||||||
|
|
||||||
@@ -118,7 +125,7 @@ belong in the same latent space and your results will be nonsensical.
|
|||||||
```python
|
```python
|
||||||
query = "What's the best pizza topping?"
|
query = "What's the best pizza topping?"
|
||||||
query_vector = embed_func([query])[0]
|
query_vector = embed_func([query])[0]
|
||||||
tbl.search(query_vector).limit(10).to_df()
|
tbl.search(query_vector).limit(10).to_pandas()
|
||||||
```
|
```
|
||||||
|
|
||||||
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
The above snippet returns a pandas DataFrame with the 10 closest vectors to the query.
|
||||||
@@ -134,9 +141,9 @@ belong in the same latent space and your results will be nonsensical.
|
|||||||
The above snippet returns an array of records with the 10 closest vectors to the query.
|
The above snippet returns an array of records with the 10 closest vectors to the query.
|
||||||
|
|
||||||
|
|
||||||
## Roadmap
|
## Implicit vectorization / Ingesting embedding functions
|
||||||
|
Representing multi-modal data as vector embeddings is becoming a standard practice. Embedding functions can themselves be thought of as a part of the processing pipeline that each request (input) has to be passed through. After initial setup these components are not expected to change for a particular project.
|
||||||
|
|
||||||
In the near future, we'll be integrating the embedding functions deeper into LanceDB<br/>.
|
This is the main motivation behind our new embedding functions API, which allows you to simply set it up once so the table remembers it, effectively making the **embedding functions disappear in the background** so you don't have to worry about modelling and can simply focus on the DB aspects of VectorDB.
|
||||||
The goal is that you just have to configure the function once when you create the table,
|
|
||||||
and then you'll never have to deal with embeddings / vectors after that unless you want to.
|
Learn more in the Next Section
|
||||||
We'll also integrate more popular models and APIs.
|
|
||||||
@@ -80,14 +80,14 @@ def handler(event, context):
|
|||||||
# Shape of SIFT is (128,1M), d=float32
|
# Shape of SIFT is (128,1M), d=float32
|
||||||
query_vector = np.array(event['query_vector'], dtype=np.float32)
|
query_vector = np.array(event['query_vector'], dtype=np.float32)
|
||||||
|
|
||||||
rs = table.search(query_vector).limit(2).to_df()
|
rs = table.search(query_vector).limit(2).to_list()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"statusCode": status_code,
|
"statusCode": status_code,
|
||||||
"headers": {
|
"headers": {
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
},
|
},
|
||||||
"body": rs.to_json()
|
"body": json.dumps(rs)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,13 @@ table.create_fts_index("text")
|
|||||||
To search:
|
To search:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
df = table.search("puppy").limit(10).select(["text"]).to_df()
|
table.search("puppy").limit(10).select(["text"]).to_list()
|
||||||
|
```
|
||||||
|
|
||||||
|
Which returns a list of dictionaries:
|
||||||
|
|
||||||
|
```python
|
||||||
|
[{'text': 'Frodo was a happy puppy', 'score': 0.6931471824645996}]
|
||||||
```
|
```
|
||||||
|
|
||||||
LanceDB automatically looks for an FTS index if the input is str.
|
LanceDB automatically looks for an FTS index if the input is str.
|
||||||
|
|||||||
@@ -251,8 +251,9 @@ After a table has been created, you can always add more data to it using
|
|||||||
### Adding Pandas DataFrame
|
### Adding Pandas DataFrame
|
||||||
|
|
||||||
```python
|
```python
|
||||||
df = pd.DataFrame([{"vector": [1.3, 1.4], "item": "fizz", "price": 100.0},
|
df = pd.DataFrame({
|
||||||
{"vector": [9.5, 56.2], "item": "buzz", "price": 200.0}])
|
"vector": [[1.3, 1.4], [9.5, 56.2]], "item": ["fizz", "buzz"], "price": [100.0, 200.0]
|
||||||
|
})
|
||||||
tbl.add(df)
|
tbl.add(df)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -261,17 +262,12 @@ After a table has been created, you can always add more data to it using
|
|||||||
### Adding to table using Iterator
|
### Adding to table using Iterator
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
def make_batches():
|
def make_batches():
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
yield pd.DataFrame(
|
yield [
|
||||||
{
|
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||||
"vector": [[3.1, 4.1], [1, 1]],
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
|
||||||
"item": ["foo", "bar"],
|
]
|
||||||
"price": [10.0, 20.0],
|
|
||||||
})
|
|
||||||
|
|
||||||
tbl.add(make_batches())
|
tbl.add(make_batches())
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -306,9 +302,10 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
import lancedb
|
import lancedb
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
data = [{"x": 1, "vector": [1, 2]},
|
||||||
|
{"x": 2, "vector": [3, 4]},
|
||||||
|
{"x": 3, "vector": [5, 6]}]
|
||||||
db = lancedb.connect("./.lancedb")
|
db = lancedb.connect("./.lancedb")
|
||||||
table = db.create_table("my_table", data)
|
table = db.create_table("my_table", data)
|
||||||
table.to_pandas()
|
table.to_pandas()
|
||||||
@@ -364,6 +361,48 @@ Use the `delete()` method on tables to delete rows from a table. To choose which
|
|||||||
await tbl.countRows() // Returns 1
|
await tbl.countRows() // Returns 1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Updating a Table [Experimental]
|
||||||
|
EXPERIMENTAL: Update rows in the table (not threadsafe).
|
||||||
|
|
||||||
|
This can be used to update zero to all rows depending on how many rows match the where clause.
|
||||||
|
|
||||||
|
| Parameter | Type | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `where` | `str` | The SQL where clause to use when updating rows. For example, `'x = 2'` or `'x IN (1, 2, 3)'`. The filter must not be empty, or it will error. |
|
||||||
|
| `values` | `dict` | The values to update. The keys are the column names and the values are the values to set. |
|
||||||
|
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
```python
|
||||||
|
import lancedb
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
# Create a lancedb connection
|
||||||
|
db = lancedb.connect("./.lancedb")
|
||||||
|
|
||||||
|
# Create a table from a pandas DataFrame
|
||||||
|
data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||||
|
table = db.create_table("my_table", data)
|
||||||
|
|
||||||
|
# Update the table where x = 2
|
||||||
|
table.update(where="x = 2", values={"vector": [10, 10]})
|
||||||
|
|
||||||
|
# Get the updated table as a pandas DataFrame
|
||||||
|
df = table.to_pandas()
|
||||||
|
|
||||||
|
# Print the DataFrame
|
||||||
|
print(df)
|
||||||
|
```
|
||||||
|
|
||||||
|
Output
|
||||||
|
```shell
|
||||||
|
x vector
|
||||||
|
0 1 [1.0, 2.0]
|
||||||
|
1 3 [5.0, 6.0]
|
||||||
|
2 2 [10.0, 10.0]
|
||||||
|
```
|
||||||
|
|
||||||
## What's Next?
|
## What's Next?
|
||||||
|
|
||||||
Learn how to Query your tables and create indices
|
Learn how to Query your tables and create indices
|
||||||
@@ -36,7 +36,7 @@ LanceDB's core is written in Rust 🦀 and is built using <a href="https://githu
|
|||||||
table = db.create_table("my_table",
|
table = db.create_table("my_table",
|
||||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}])
|
||||||
result = table.search([100, 100]).limit(2).to_df()
|
result = table.search([100, 100]).limit(2).to_list()
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Javascript"
|
=== "Javascript"
|
||||||
@@ -67,7 +67,7 @@ LanceDB's core is written in Rust 🦀 and is built using <a href="https://githu
|
|||||||
|
|
||||||
## Documentation Quick Links
|
## Documentation Quick Links
|
||||||
* [`Basic Operations`](basic.md) - basic functionality of LanceDB.
|
* [`Basic Operations`](basic.md) - basic functionality of LanceDB.
|
||||||
* [`Embedding Functions`](embedding.md) - functions for working with embeddings.
|
* [`Embedding Functions`](embeddings/index.md) - functions for working with embeddings.
|
||||||
* [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
|
* [`Indexing`](ann_indexes.md) - create vector indexes to speed up queries.
|
||||||
* [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
|
* [`Full text search`](fts.md) - [EXPERIMENTAL] full-text search API
|
||||||
* [`Ecosystem Integrations`](python/integration.md) - integrating LanceDB with python data tooling ecosystem.
|
* [`Ecosystem Integrations`](python/integration.md) - integrating LanceDB with python data tooling ecosystem.
|
||||||
|
|||||||
764
docs/src/notebooks/DisappearingEmbeddingFunction.ipynb
Normal file
764
docs/src/notebooks/DisappearingEmbeddingFunction.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -144,7 +144,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"# Pre-processing and loading the documentation\n",
|
"# Pre-processing and loading the documentation\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Next, let's pre-process and load the documentation. To make sure we don't need to do this repeatedly if we were updating code, we're caching it using pickle so we can retrieve it again (this could take a few minutes to run the first time yyou do it). We'll also add some more metadata to the docs here such as the title and version of the code:"
|
"Next, let's pre-process and load the documentation. To make sure we don't need to do this repeatedly if we were updating code, we're caching it using pickle so we can retrieve it again (this could take a few minutes to run the first time you do it). We'll also add some more metadata to the docs here such as the title and version of the code:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -255,7 +255,7 @@
|
|||||||
"id": "28d93b85",
|
"id": "28d93b85",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"And thats it! We're all setup. The next step is to run some queries, let's try a few:"
|
"And that's it! We're all set up. The next step is to run some queries, let's try a few:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
604
docs/src/notebooks/multi_lingual_example.ipynb
Normal file
604
docs/src/notebooks/multi_lingual_example.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -19,11 +19,11 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
|
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip available: \u001B[0m\u001B[31;49m22.3.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m23.1.2\u001B[0m\n",
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.2\u001b[0m\n",
|
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip available: \u001B[0m\u001B[31;49m22.3.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m23.1.2\u001B[0m\n",
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
|
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -39,6 +39,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import io\n",
|
"import io\n",
|
||||||
|
"\n",
|
||||||
"import PIL\n",
|
"import PIL\n",
|
||||||
"import duckdb\n",
|
"import duckdb\n",
|
||||||
"import lancedb"
|
"import lancedb"
|
||||||
@@ -158,18 +159,18 @@
|
|||||||
" \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
|
" \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
|
||||||
" \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
|
" \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
|
||||||
" f\"embedding = embed_func('{query}')\\n\"\n",
|
" f\"embedding = embed_func('{query}')\\n\"\n",
|
||||||
" \"tbl.search(embedding).limit(9).to_df()\"\n",
|
" \"tbl.search(embedding).limit(9).to_pandas()\"\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" return (_extract(tbl.search(emb).limit(9).to_df()), code)\n",
|
" return (_extract(tbl.search(emb).limit(9).to_pandas()), code)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def find_image_keywords(query):\n",
|
"def find_image_keywords(query):\n",
|
||||||
" code = (\n",
|
" code = (\n",
|
||||||
" \"import lancedb\\n\"\n",
|
" \"import lancedb\\n\"\n",
|
||||||
" \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
|
" \"db = lancedb.connect('~/datasets/demo')\\n\"\n",
|
||||||
" \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
|
" \"tbl = db.open_table('diffusiondb')\\n\\n\"\n",
|
||||||
" f\"tbl.search('{query}').limit(9).to_df()\"\n",
|
" f\"tbl.search('{query}').limit(9).to_pandas()\"\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" return (_extract(tbl.search(query).limit(9).to_df()), code)\n",
|
" return (_extract(tbl.search(query).limit(9).to_pandas()), code)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def find_image_sql(query):\n",
|
"def find_image_sql(query):\n",
|
||||||
" code = (\n",
|
" code = (\n",
|
||||||
|
|||||||
1189
docs/src/notebooks/reproducibility.ipynb
Normal file
1189
docs/src/notebooks/reproducibility.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -114,13 +114,10 @@
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import pandas as pd\n",
|
"data = [\n",
|
||||||
"\n",
|
" {\"vector\": [1.1, 1.2], \"lat\": 45.5, \"long\": -122.7},\n",
|
||||||
"data = pd.DataFrame({\n",
|
" {\"vector\": [0.2, 1.8], \"lat\": 40.1, \"long\": -74.1},\n",
|
||||||
" \"vector\": [[1.1, 1.2], [0.2, 1.8]],\n",
|
"]\n",
|
||||||
" \"lat\": [45.5, 40.1],\n",
|
|
||||||
" \"long\": [-122.7, -74.1]\n",
|
|
||||||
"})\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"db.create_table(\"table2\", data)\n",
|
"db.create_table(\"table2\", data)\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -572,9 +569,11 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"df = pd.DataFrame([{\"vector\": [1.3, 1.4], \"item\": \"fizz\", \"price\": 100.0},\n",
|
"data = [\n",
|
||||||
" {\"vector\": [9.5, 56.2], \"item\": \"buzz\", \"price\": 200.0}])\n",
|
" {\"vector\": [1.3, 1.4], \"item\": \"fizz\", \"price\": 100.0},\n",
|
||||||
"tbl.add(df)"
|
" {\"vector\": [9.5, 56.2], \"item\": \"buzz\", \"price\": 200.0}\n",
|
||||||
|
"]\n",
|
||||||
|
"tbl.add(data)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -596,17 +595,12 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"\n",
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"\n",
|
|
||||||
"def make_batches():\n",
|
"def make_batches():\n",
|
||||||
" for i in range(5):\n",
|
" for i in range(5):\n",
|
||||||
" yield pd.DataFrame(\n",
|
" yield [\n",
|
||||||
" {\n",
|
" {\"vector\": [3.1, 4.1], \"item\": \"foo\", \"price\": 10.0},\n",
|
||||||
" \"vector\": [[3.1, 4.1], [1, 1]],\n",
|
" {\"vector\": [1, 1], \"item\": \"bar\", \"price\": 20.0},\n",
|
||||||
" \"item\": [\"foo\", \"bar\"],\n",
|
" ]\n",
|
||||||
" \"price\": [10.0, 20.0],\n",
|
|
||||||
" })\n",
|
|
||||||
"tbl.add(make_batches())"
|
"tbl.add(make_batches())"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -27,11 +27,11 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n",
|
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m23.0\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m23.1.1\u001B[0m\n",
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
|
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.1.1\u001b[0m\n",
|
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m23.0\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m23.1.1\u001B[0m\n",
|
||||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
|
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -184,7 +184,7 @@
|
|||||||
"df = (contextualize(data.to_pandas())\n",
|
"df = (contextualize(data.to_pandas())\n",
|
||||||
" .groupby(\"title\").text_col(\"text\")\n",
|
" .groupby(\"title\").text_col(\"text\")\n",
|
||||||
" .window(20).stride(4)\n",
|
" .window(20).stride(4)\n",
|
||||||
" .to_df())\n",
|
" .to_pandas())\n",
|
||||||
"df.head(1)"
|
"df.head(1)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -603,7 +603,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Use LanceDB to get top 3 most relevant context\n",
|
"# Use LanceDB to get top 3 most relevant context\n",
|
||||||
"context = tbl.search(emb).limit(3).to_df()"
|
"context = tbl.search(emb).limit(3).to_pandas()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ to lazily generate data:
|
|||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import lancedb
|
|
||||||
|
|
||||||
def make_batches() -> Iterable[pa.RecordBatch]:
|
def make_batches() -> Iterable[pa.RecordBatch]:
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
@@ -74,7 +73,7 @@ table = db.open_table("pd_table")
|
|||||||
|
|
||||||
query_vector = [100, 100]
|
query_vector = [100, 100]
|
||||||
# Pandas DataFrame
|
# Pandas DataFrame
|
||||||
df = table.search(query_vector).limit(1).to_df()
|
df = table.search(query_vector).limit(1).to_pandas()
|
||||||
print(df)
|
print(df)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -89,12 +88,12 @@ If you have more complex criteria, you can always apply the filter to the result
|
|||||||
```python
|
```python
|
||||||
|
|
||||||
# Apply the filter via LanceDB
|
# Apply the filter via LanceDB
|
||||||
results = table.search([100, 100]).where("price < 15").to_df()
|
results = table.search([100, 100]).where("price < 15").to_pandas()
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results["item"].iloc[0] == "foo"
|
assert results["item"].iloc[0] == "foo"
|
||||||
|
|
||||||
# Apply the filter via Pandas
|
# Apply the filter via Pandas
|
||||||
df = results = table.search([100, 100]).to_df()
|
df = results = table.search([100, 100]).to_pandas()
|
||||||
results = df[df.price < 15]
|
results = df[df.price < 15]
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results["item"].iloc[0] == "foo"
|
assert results["item"].iloc[0] == "foo"
|
||||||
|
|||||||
@@ -11,15 +11,13 @@ pip install duckdb lancedb
|
|||||||
We will re-use [the dataset created previously](./arrow.md):
|
We will re-use [the dataset created previously](./arrow.md):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import pandas as pd
|
|
||||||
import lancedb
|
import lancedb
|
||||||
|
|
||||||
db = lancedb.connect("data/sample-lancedb")
|
db = lancedb.connect("data/sample-lancedb")
|
||||||
data = pd.DataFrame({
|
data = [
|
||||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||||
"item": ["foo", "bar"],
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0}
|
||||||
"price": [10.0, 20.0]
|
]
|
||||||
})
|
|
||||||
table = db.create_table("pd_table", data=data)
|
table = db.create_table("pd_table", data=data)
|
||||||
arrow_table = table.to_arrow()
|
arrow_table = table.to_arrow()
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -22,21 +22,19 @@ pip install lancedb
|
|||||||
|
|
||||||
::: lancedb.query.LanceQueryBuilder
|
::: lancedb.query.LanceQueryBuilder
|
||||||
|
|
||||||
::: lancedb.query.LanceFtsQueryBuilder
|
|
||||||
|
|
||||||
## Embeddings
|
## Embeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.EmbeddingFunctionRegistry
|
::: lancedb.embeddings.registry.EmbeddingFunctionRegistry
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.EmbeddingFunction
|
::: lancedb.embeddings.base.EmbeddingFunction
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.TextEmbeddingFunction
|
::: lancedb.embeddings.base.TextEmbeddingFunction
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.SentenceTransformerEmbeddings
|
::: lancedb.embeddings.sentence_transformers.SentenceTransformerEmbeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.OpenAIEmbeddings
|
::: lancedb.embeddings.openai.OpenAIEmbeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.functions.OpenClipEmbeddings
|
::: lancedb.embeddings.open_clip.OpenClipEmbeddings
|
||||||
|
|
||||||
::: lancedb.embeddings.with_embeddings
|
::: lancedb.embeddings.with_embeddings
|
||||||
|
|
||||||
@@ -56,7 +54,7 @@ pip install lancedb
|
|||||||
|
|
||||||
## Utilities
|
## Utilities
|
||||||
|
|
||||||
::: lancedb.vector
|
::: lancedb.schema.vector
|
||||||
|
|
||||||
## Integrations
|
## Integrations
|
||||||
|
|
||||||
|
|||||||
1
docs/src/robots.txt
Normal file
1
docs/src/robots.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
User-agent: *
|
||||||
4
docs/src/scripts/posthog.js
Normal file
4
docs/src/scripts/posthog.js
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
window.addEventListener("DOMContentLoaded", (event) => {
|
||||||
|
!function(t,e){var o,n,p,r;e.__SV||(window.posthog=e,e._i=[],e.init=function(i,s,a){function g(t,e){var o=e.split(".");2==o.length&&(t=t[o[0]],e=o[1]),t[e]=function(){t.push([e].concat(Array.prototype.slice.call(arguments,0)))}}(p=t.createElement("script")).type="text/javascript",p.async=!0,p.src=s.api_host+"/static/array.js",(r=t.getElementsByTagName("script")[0]).parentNode.insertBefore(p,r);var u=e;for(void 0!==a?u=e[a]=[]:a="posthog",u.people=u.people||[],u.toString=function(t){var e="posthog";return"posthog"!==a&&(e+="."+a),t||(e+=" (stub)"),e},u.people.toString=function(){return u.toString(1)+".people (stub)"},o="capture identify alias people.set people.set_once set_config register register_once unregister opt_out_capturing has_opted_out_capturing opt_in_capturing reset isFeatureEnabled onFeatureFlags getFeatureFlag getFeatureFlagPayload reloadFeatureFlags group updateEarlyAccessFeatureEnrollment getEarlyAccessFeatures getActiveMatchingSurveys getSurveys".split(" "),n=0;n<o.length;n++)g(u,o[n]);e._i.push([i,s,a])},e.__SV=1)}(document,window.posthog||[]);
|
||||||
|
posthog.init('phc_oENDjGgHtmIDrV6puUiFem2RB4JA8gGWulfdulmMdZP',{api_host:'https://app.posthog.com'})
|
||||||
|
});
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
In a recommendation system or search engine, you can find similar products from
|
In a recommendation system or search engine, you can find similar products from
|
||||||
the one you searched.
|
the one you searched.
|
||||||
In LLM and other AI applications,
|
In LLM and other AI applications,
|
||||||
each data point can be [presented by the embeddings generated from some models](embedding.md),
|
each data point can be [presented by the embeddings generated from some models](embeddings/index.md),
|
||||||
it returns the most relevant features.
|
it returns the most relevant features.
|
||||||
|
|
||||||
A search in high-dimensional vector space, is to find `K-Nearest-Neighbors (KNN)` of the query vector.
|
A search in high-dimensional vector space, is to find `K-Nearest-Neighbors (KNN)` of the query vector.
|
||||||
@@ -67,7 +67,7 @@ await db_setup.createTable('my_vectors', data)
|
|||||||
|
|
||||||
df = tbl.search(np.random.random((1536))) \
|
df = tbl.search(np.random.random((1536))) \
|
||||||
.limit(10) \
|
.limit(10) \
|
||||||
.to_df()
|
.to_list()
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "JavaScript"
|
=== "JavaScript"
|
||||||
@@ -92,7 +92,7 @@ as well.
|
|||||||
df = tbl.search(np.random.random((1536))) \
|
df = tbl.search(np.random.random((1536))) \
|
||||||
.metric("cosine") \
|
.metric("cosine") \
|
||||||
.limit(10) \
|
.limit(10) \
|
||||||
.to_df()
|
.to_list()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ const excludedGlobs = [
|
|||||||
"../src/embedding.md",
|
"../src/embedding.md",
|
||||||
"../src/examples/*.md",
|
"../src/examples/*.md",
|
||||||
"../src/guides/tables.md",
|
"../src/guides/tables.md",
|
||||||
|
"../src/embeddings/*.md",
|
||||||
];
|
];
|
||||||
|
|
||||||
const nodePrefix = "javascript";
|
const nodePrefix = "javascript";
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ excluded_globs = [
|
|||||||
"../src/integrations/voxel51.md",
|
"../src/integrations/voxel51.md",
|
||||||
"../src/guides/tables.md",
|
"../src/guides/tables.md",
|
||||||
"../src/python/duckdb.md",
|
"../src/python/duckdb.md",
|
||||||
|
"../src/embeddings/*.md",
|
||||||
]
|
]
|
||||||
|
|
||||||
python_prefix = "py"
|
python_prefix = "py"
|
||||||
@@ -17,20 +18,31 @@ python_file = ".py"
|
|||||||
python_folder = "python"
|
python_folder = "python"
|
||||||
|
|
||||||
files = glob.glob(glob_string, recursive=True)
|
files = glob.glob(glob_string, recursive=True)
|
||||||
excluded_files = [f for excluded_glob in excluded_globs for f in glob.glob(excluded_glob, recursive=True)]
|
excluded_files = [
|
||||||
|
f
|
||||||
|
for excluded_glob in excluded_globs
|
||||||
|
for f in glob.glob(excluded_glob, recursive=True)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
|
def yield_lines(lines: Iterator[str], prefix: str, suffix: str):
|
||||||
in_code_block = False
|
in_code_block = False
|
||||||
# Python code has strict indentation
|
# Python code has strict indentation
|
||||||
strip_length = 0
|
strip_length = 0
|
||||||
|
skip_test = False
|
||||||
for line in lines:
|
for line in lines:
|
||||||
|
if "skip-test" in line:
|
||||||
|
skip_test = True
|
||||||
if line.strip().startswith(prefix + python_prefix):
|
if line.strip().startswith(prefix + python_prefix):
|
||||||
in_code_block = True
|
in_code_block = True
|
||||||
strip_length = len(line) - len(line.lstrip())
|
strip_length = len(line) - len(line.lstrip())
|
||||||
elif in_code_block and line.strip().startswith(suffix):
|
elif in_code_block and line.strip().startswith(suffix):
|
||||||
in_code_block = False
|
in_code_block = False
|
||||||
|
if not skip_test:
|
||||||
yield "\n"
|
yield "\n"
|
||||||
|
skip_test = False
|
||||||
elif in_code_block:
|
elif in_code_block:
|
||||||
|
if not skip_test:
|
||||||
yield line[strip_length:]
|
yield line[strip_length:]
|
||||||
|
|
||||||
for file in filter(lambda file: file not in excluded_files, files):
|
for file in filter(lambda file: file not in excluded_files, files):
|
||||||
@@ -38,7 +50,12 @@ for file in filter(lambda file: file not in excluded_files, files):
|
|||||||
lines = list(yield_lines(iter(f), "```", "```"))
|
lines = list(yield_lines(iter(f), "```", "```"))
|
||||||
|
|
||||||
if len(lines) > 0:
|
if len(lines) > 0:
|
||||||
out_path = Path(python_folder) / Path(file).name.strip(".md") / (Path(file).name.strip(".md") + python_file)
|
print(lines)
|
||||||
|
out_path = (
|
||||||
|
Path(python_folder)
|
||||||
|
/ Path(file).name.strip(".md")
|
||||||
|
/ (Path(file).name.strip(".md") + python_file)
|
||||||
|
)
|
||||||
print(out_path)
|
print(out_path)
|
||||||
out_path.parent.mkdir(exist_ok=True, parents=True)
|
out_path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
with open(out_path, "w") as out:
|
with open(out_path, "w") as out:
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
lancedb @ git+https://github.com/lancedb/lancedb.git#egg=subdir&subdirectory=python
|
-e ../../python
|
||||||
numpy
|
numpy
|
||||||
pandas
|
pandas
|
||||||
pylance
|
pylance
|
||||||
duckdb
|
duckdb
|
||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
torch
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ npm install vectordb
|
|||||||
|
|
||||||
This will download the appropriate native library for your platform. We currently
|
This will download the appropriate native library for your platform. We currently
|
||||||
support x86_64 Linux, aarch64 Linux, Intel MacOS, and ARM (M1/M2) MacOS. We do not
|
support x86_64 Linux, aarch64 Linux, Intel MacOS, and ARM (M1/M2) MacOS. We do not
|
||||||
yet support Windows or musl-based Linux (such as Alpine Linux).
|
yet support musl-based Linux (such as Alpine Linux).
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
|||||||
104
node/package-lock.json
generated
104
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.2.6",
|
"version": "0.3.8",
|
||||||
"lockfileVersion": 2,
|
"lockfileVersion": 2,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.2.6",
|
"version": "0.3.8",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
@@ -53,11 +53,11 @@
|
|||||||
"uuid": "^9.0.0"
|
"uuid": "^9.0.0"
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.2.6",
|
"@lancedb/vectordb-darwin-arm64": "0.3.8",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.2.6",
|
"@lancedb/vectordb-darwin-x64": "0.3.8",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.3.8",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
|
"@lancedb/vectordb-linux-x64-gnu": "0.3.8",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
|
"@lancedb/vectordb-win32-x64-msvc": "0.3.8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@apache-arrow/ts": {
|
"node_modules/@apache-arrow/ts": {
|
||||||
@@ -316,66 +316,6 @@
|
|||||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@lancedb/vectordb-darwin-arm64": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
|
|
||||||
"cpu": [
|
|
||||||
"arm64"
|
|
||||||
],
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"darwin"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"node_modules/@lancedb/vectordb-darwin-x64": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
|
|
||||||
"cpu": [
|
|
||||||
"x64"
|
|
||||||
],
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"darwin"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"node_modules/@lancedb/vectordb-linux-arm64-gnu": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
|
|
||||||
"cpu": [
|
|
||||||
"arm64"
|
|
||||||
],
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"linux"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"node_modules/@lancedb/vectordb-linux-x64-gnu": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
|
|
||||||
"cpu": [
|
|
||||||
"x64"
|
|
||||||
],
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"linux"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"node_modules/@lancedb/vectordb-win32-x64-msvc": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
|
|
||||||
"cpu": [
|
|
||||||
"x64"
|
|
||||||
],
|
|
||||||
"optional": true,
|
|
||||||
"os": [
|
|
||||||
"win32"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"node_modules/@neon-rs/cli": {
|
"node_modules/@neon-rs/cli": {
|
||||||
"version": "0.0.160",
|
"version": "0.0.160",
|
||||||
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
||||||
@@ -4868,36 +4808,6 @@
|
|||||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"@lancedb/vectordb-darwin-arm64": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-arm64/-/vectordb-darwin-arm64-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-9KCUvDmhVMuGIhleib/Gq43QhrRXjy2QJz21S85HDwL3DTH4J9n00A0V6eyLTBUyctnvMTcp3XZijosYUy1A8Q==",
|
|
||||||
"optional": true
|
|
||||||
},
|
|
||||||
"@lancedb/vectordb-darwin-x64": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-darwin-x64/-/vectordb-darwin-x64-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-WCYRFV9w13STgVYn4WSYne39mp+g8ET6TgMLvSSQBYJKp3xEggpSCtACetaDfmNpkml9DK/b5R95Jc7PBbmYgA==",
|
|
||||||
"optional": true
|
|
||||||
},
|
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-arm64-gnu/-/vectordb-linux-arm64-gnu-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-SE9OUgsOT6dG1q9v3nFr9ew+kwPTA4ktvNiHiyQstNz9BniuLNldF/Wtxzk/Z7DhbkPci4MfkR6RdsPTHBatHg==",
|
|
||||||
"optional": true
|
|
||||||
},
|
|
||||||
"@lancedb/vectordb-linux-x64-gnu": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-linux-x64-gnu/-/vectordb-linux-x64-gnu-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-hvUsRQbaJiQnSjjKHIRhJM/eObJOqDJUXcpzz1fWw/MMSoy/CFaQwf9Uen2IWTgcngGkJAkeEKG7N5GxQxVbBQ==",
|
|
||||||
"optional": true
|
|
||||||
},
|
|
||||||
"@lancedb/vectordb-win32-x64-msvc": {
|
|
||||||
"version": "0.2.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/@lancedb/vectordb-win32-x64-msvc/-/vectordb-win32-x64-msvc-0.2.6.tgz",
|
|
||||||
"integrity": "sha512-XPIzbBPt28nsAa7INuyvYMZyJ78bgLfxjSyazlydzO10orIBHvR+sjcrdnCK4l48YmvPXcSYnKxlKMa1oUeIWQ==",
|
|
||||||
"optional": true
|
|
||||||
},
|
|
||||||
"@neon-rs/cli": {
|
"@neon-rs/cli": {
|
||||||
"version": "0.0.160",
|
"version": "0.0.160",
|
||||||
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
"resolved": "https://registry.npmjs.org/@neon-rs/cli/-/cli-0.0.160.tgz",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.2.6",
|
"version": "0.3.8",
|
||||||
"description": " Serverless, low-latency vector database for AI applications",
|
"description": " Serverless, low-latency vector database for AI applications",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"types": "dist/index.d.ts",
|
"types": "dist/index.d.ts",
|
||||||
@@ -81,10 +81,10 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"optionalDependencies": {
|
"optionalDependencies": {
|
||||||
"@lancedb/vectordb-darwin-arm64": "0.2.6",
|
"@lancedb/vectordb-darwin-arm64": "0.3.8",
|
||||||
"@lancedb/vectordb-darwin-x64": "0.2.6",
|
"@lancedb/vectordb-darwin-x64": "0.3.8",
|
||||||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.6",
|
"@lancedb/vectordb-linux-arm64-gnu": "0.3.8",
|
||||||
"@lancedb/vectordb-linux-x64-gnu": "0.2.6",
|
"@lancedb/vectordb-linux-x64-gnu": "0.3.8",
|
||||||
"@lancedb/vectordb-win32-x64-msvc": "0.2.6"
|
"@lancedb/vectordb-win32-x64-msvc": "0.3.8"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import {
|
|||||||
Utf8,
|
Utf8,
|
||||||
type Vector,
|
type Vector,
|
||||||
FixedSizeList,
|
FixedSizeList,
|
||||||
vectorFromArray, type Schema, Table as ArrowTable
|
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter
|
||||||
} from 'apache-arrow'
|
} from 'apache-arrow'
|
||||||
import { type EmbeddingFunction } from './index'
|
import { type EmbeddingFunction } from './index'
|
||||||
|
|
||||||
@@ -77,7 +77,9 @@ function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
|
|||||||
|
|
||||||
// Creates the Arrow Type for a Vector column with dimension `dim`
|
// Creates the Arrow Type for a Vector column with dimension `dim`
|
||||||
function newVectorType (dim: number): FixedSizeList<Float32> {
|
function newVectorType (dim: number): FixedSizeList<Float32> {
|
||||||
const children = new Field<Float32>('item', new Float32())
|
// Somewhere we always default to have the elements nullable, so we need to set it to true
|
||||||
|
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
|
||||||
|
const children = new Field<Float32>('item', new Float32(), true)
|
||||||
return new FixedSizeList(dim, children)
|
return new FixedSizeList(dim, children)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,6 +90,13 @@ export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown
|
|||||||
return Buffer.from(await writer.toUint8Array())
|
return Buffer.from(await writer.toUint8Array())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts an Array of records into Arrow IPC stream format
|
||||||
|
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||||
|
const table = await convertToTable(data, embeddings)
|
||||||
|
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||||
|
return Buffer.from(await writer.toUint8Array())
|
||||||
|
}
|
||||||
|
|
||||||
// Converts an Arrow Table into Arrow IPC format
|
// Converts an Arrow Table into Arrow IPC format
|
||||||
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||||
if (embeddings !== undefined) {
|
if (embeddings !== undefined) {
|
||||||
@@ -105,6 +114,23 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
|
|||||||
return Buffer.from(await writer.toUint8Array())
|
return Buffer.from(await writer.toUint8Array())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts an Arrow Table into Arrow IPC stream format
|
||||||
|
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||||
|
if (embeddings !== undefined) {
|
||||||
|
const source = table.getChild(embeddings.sourceColumn)
|
||||||
|
|
||||||
|
if (source === null) {
|
||||||
|
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const vectors = await embeddings.embed(source.toArray() as T[])
|
||||||
|
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||||
|
table = table.assign(new ArrowTable({ vector: column }))
|
||||||
|
}
|
||||||
|
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||||
|
return Buffer.from(await writer.toUint8Array())
|
||||||
|
}
|
||||||
|
|
||||||
// Creates an empty Arrow Table
|
// Creates an empty Arrow Table
|
||||||
export function createEmptyTable (schema: Schema): ArrowTable {
|
export function createEmptyTable (schema: Schema): ArrowTable {
|
||||||
return new ArrowTable(schema)
|
return new ArrowTable(schema)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import { Query } from './query'
|
|||||||
import { isEmbeddingFunction } from './embedding/embedding_function'
|
import { isEmbeddingFunction } from './embedding/embedding_function'
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
// eslint-disable-next-line @typescript-eslint/no-var-requires
|
||||||
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete } = require('../native.js')
|
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableCleanupOldVersions, tableCompactFiles, tableListIndices, tableIndexStats } = require('../native.js')
|
||||||
|
|
||||||
export { Query }
|
export { Query }
|
||||||
export type { EmbeddingFunction }
|
export type { EmbeddingFunction }
|
||||||
@@ -260,6 +260,27 @@ export interface Table<T = number[]> {
|
|||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
delete: (filter: string) => Promise<void>
|
delete: (filter: string) => Promise<void>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List the indicies on this table.
|
||||||
|
*/
|
||||||
|
listIndices: () => Promise<VectorIndex[]>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get statistics about an index.
|
||||||
|
*/
|
||||||
|
indexStats: (indexUuid: string) => Promise<IndexStats>
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface VectorIndex {
|
||||||
|
columns: string[]
|
||||||
|
name: string
|
||||||
|
uuid: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface IndexStats {
|
||||||
|
numIndexedRows: number | null
|
||||||
|
numUnindexedRows: number | null
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -459,6 +480,119 @@ export class LocalTable<T = number[]> implements Table<T> {
|
|||||||
async delete (filter: string): Promise<void> {
|
async delete (filter: string): Promise<void> {
|
||||||
return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable })
|
return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clean up old versions of the table, freeing disk space.
|
||||||
|
*
|
||||||
|
* @param olderThan The minimum age in minutes of the versions to delete. If not
|
||||||
|
* provided, defaults to two weeks.
|
||||||
|
* @param deleteUnverified Because they may be part of an in-progress
|
||||||
|
* transaction, uncommitted files newer than 7 days old are
|
||||||
|
* not deleted by default. This means that failed transactions
|
||||||
|
* can leave around data that takes up disk space for up to
|
||||||
|
* 7 days. You can override this safety mechanism by setting
|
||||||
|
* this option to `true`, only if you promise there are no
|
||||||
|
* in progress writes while you run this operation. Failure to
|
||||||
|
* uphold this promise can lead to corrupted tables.
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
async cleanupOldVersions (olderThan?: number, deleteUnverified?: boolean): Promise<CleanupStats> {
|
||||||
|
return tableCleanupOldVersions.call(this._tbl, olderThan, deleteUnverified)
|
||||||
|
.then((res: { newTable: any, metrics: CleanupStats }) => {
|
||||||
|
this._tbl = res.newTable
|
||||||
|
return res.metrics
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run the compaction process on the table.
|
||||||
|
*
|
||||||
|
* This can be run after making several small appends to optimize the table
|
||||||
|
* for faster reads.
|
||||||
|
*
|
||||||
|
* @param options Advanced options configuring compaction. In most cases, you
|
||||||
|
* can omit this arguments, as the default options are sensible
|
||||||
|
* for most tables.
|
||||||
|
* @returns Metrics about the compaction operation.
|
||||||
|
*/
|
||||||
|
async compactFiles (options?: CompactionOptions): Promise<CompactionMetrics> {
|
||||||
|
const optionsArg = options ?? {}
|
||||||
|
return tableCompactFiles.call(this._tbl, optionsArg)
|
||||||
|
.then((res: { newTable: any, metrics: CompactionMetrics }) => {
|
||||||
|
this._tbl = res.newTable
|
||||||
|
return res.metrics
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async listIndices (): Promise<VectorIndex[]> {
|
||||||
|
return tableListIndices.call(this._tbl)
|
||||||
|
}
|
||||||
|
|
||||||
|
async indexStats (indexUuid: string): Promise<IndexStats> {
|
||||||
|
return tableIndexStats.call(this._tbl, indexUuid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CleanupStats {
|
||||||
|
/**
|
||||||
|
* The number of bytes removed from disk.
|
||||||
|
*/
|
||||||
|
bytesRemoved: number
|
||||||
|
/**
|
||||||
|
* The number of old table versions removed.
|
||||||
|
*/
|
||||||
|
oldVersions: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CompactionOptions {
|
||||||
|
/**
|
||||||
|
* The number of rows per fragment to target. Fragments that have fewer rows
|
||||||
|
* will be compacted into adjacent fragments to produce larger fragments.
|
||||||
|
* Defaults to 1024 * 1024.
|
||||||
|
*/
|
||||||
|
targetRowsPerFragment?: number
|
||||||
|
/**
|
||||||
|
* The maximum number of rows per group. Defaults to 1024.
|
||||||
|
*/
|
||||||
|
maxRowsPerGroup?: number
|
||||||
|
/**
|
||||||
|
* If true, fragments that have rows that are deleted may be compacted to
|
||||||
|
* remove the deleted rows. This can improve the performance of queries.
|
||||||
|
* Default is true.
|
||||||
|
*/
|
||||||
|
materializeDeletions?: boolean
|
||||||
|
/**
|
||||||
|
* A number between 0 and 1, representing the proportion of rows that must be
|
||||||
|
* marked deleted before a fragment is a candidate for compaction to remove
|
||||||
|
* the deleted rows. Default is 10%.
|
||||||
|
*/
|
||||||
|
materializeDeletionsThreshold?: number
|
||||||
|
/**
|
||||||
|
* The number of threads to use for compaction. If not provided, defaults to
|
||||||
|
* the number of cores on the machine.
|
||||||
|
*/
|
||||||
|
numThreads?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CompactionMetrics {
|
||||||
|
/**
|
||||||
|
* The number of fragments that were removed.
|
||||||
|
*/
|
||||||
|
fragmentsRemoved: number
|
||||||
|
/**
|
||||||
|
* The number of new fragments that were created.
|
||||||
|
*/
|
||||||
|
fragmentsAdded: number
|
||||||
|
/**
|
||||||
|
* The number of files that were removed. Each fragment may have more than one
|
||||||
|
* file.
|
||||||
|
*/
|
||||||
|
filesRemoved: number
|
||||||
|
/**
|
||||||
|
* The number of files added. This is typically equal to the number of
|
||||||
|
* fragments added.
|
||||||
|
*/
|
||||||
|
filesAdded: number
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Config to build IVF_PQ index.
|
/// Config to build IVF_PQ index.
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ import * as chaiAsPromised from 'chai-as-promised'
|
|||||||
import { v4 as uuidv4 } from 'uuid'
|
import { v4 as uuidv4 } from 'uuid'
|
||||||
|
|
||||||
import * as lancedb from '../index'
|
import * as lancedb from '../index'
|
||||||
|
import { tmpdir } from 'os'
|
||||||
|
import * as fs from 'fs'
|
||||||
|
import * as path from 'path'
|
||||||
|
|
||||||
const assert = chai.assert
|
const assert = chai.assert
|
||||||
chai.use(chaiAsPromised)
|
chai.use(chaiAsPromised)
|
||||||
@@ -41,3 +44,137 @@ describe('LanceDB AWS Integration test', function () {
|
|||||||
assert.equal(await table.countRows(), 6)
|
assert.equal(await table.countRows(), 6)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('LanceDB Mirrored Store Integration test', function () {
|
||||||
|
it('s3://...?mirroredStore=... param is processed correctly', async function () {
|
||||||
|
this.timeout(600000)
|
||||||
|
|
||||||
|
const dir = tmpdir()
|
||||||
|
console.log(dir)
|
||||||
|
const conn = await lancedb.connect(`s3://lancedb-integtest?mirroredStore=${dir}`)
|
||||||
|
const data = Array(200).fill({ vector: Array(128).fill(1.0), id: 0 })
|
||||||
|
data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 1 }))
|
||||||
|
data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 2 }))
|
||||||
|
data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 3 }))
|
||||||
|
|
||||||
|
const tableName = uuidv4()
|
||||||
|
|
||||||
|
// try create table and check if it's mirrored
|
||||||
|
const t = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })
|
||||||
|
|
||||||
|
const mirroredPath = path.join(dir, `${tableName}.lance`)
|
||||||
|
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
// there should be three dirs
|
||||||
|
assert.equal(files.length, 3)
|
||||||
|
assert.isTrue(files[0].isDirectory())
|
||||||
|
assert.isTrue(files[1].isDirectory())
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.txn'))
|
||||||
|
})
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, '_versions'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.manifest'))
|
||||||
|
})
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.lance'))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// try create index and check if it's mirrored
|
||||||
|
await t.createIndex({ column: 'vector', type: 'ivf_pq' })
|
||||||
|
|
||||||
|
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
// there should be four dirs
|
||||||
|
assert.equal(files.length, 4)
|
||||||
|
assert.isTrue(files[0].isDirectory())
|
||||||
|
assert.isTrue(files[1].isDirectory())
|
||||||
|
assert.isTrue(files[2].isDirectory())
|
||||||
|
|
||||||
|
// Two TXs now
|
||||||
|
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 2)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.txn'))
|
||||||
|
assert.isTrue(files[1].name.endsWith('.txn'))
|
||||||
|
})
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.lance'))
|
||||||
|
})
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, '_indices'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].isDirectory())
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, '_indices', files[0].name), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].isFile())
|
||||||
|
assert.isTrue(files[0].name.endsWith('.idx'))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// try delete and check if it's mirrored
|
||||||
|
await t.delete('id = 0')
|
||||||
|
|
||||||
|
fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
// there should be five dirs
|
||||||
|
assert.equal(files.length, 5)
|
||||||
|
assert.isTrue(files[0].isDirectory())
|
||||||
|
assert.isTrue(files[1].isDirectory())
|
||||||
|
assert.isTrue(files[2].isDirectory())
|
||||||
|
assert.isTrue(files[3].isDirectory())
|
||||||
|
assert.isTrue(files[4].isDirectory())
|
||||||
|
|
||||||
|
// Three TXs now
|
||||||
|
fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 3)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.txn'))
|
||||||
|
assert.isTrue(files[1].name.endsWith('.txn'))
|
||||||
|
})
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.lance'))
|
||||||
|
})
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, '_indices'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].isDirectory())
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, '_indices', files[0].name), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].isFile())
|
||||||
|
assert.isTrue(files[0].name.endsWith('.idx'))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
fs.readdir(path.join(mirroredPath, '_deletions'), { withFileTypes: true }, (err, files) => {
|
||||||
|
if (err != null) throw err
|
||||||
|
assert.equal(files.length, 1)
|
||||||
|
assert.isTrue(files[0].name.endsWith('.arrow'))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -63,6 +63,9 @@ export class HttpLancedbClient {
|
|||||||
}
|
}
|
||||||
).catch((err) => {
|
).catch((err) => {
|
||||||
console.error('error: ', err)
|
console.error('error: ', err)
|
||||||
|
if (err.response === undefined) {
|
||||||
|
throw new Error(`Network Error: ${err.message as string}`)
|
||||||
|
}
|
||||||
return err.response
|
return err.response
|
||||||
})
|
})
|
||||||
if (response.status !== 200) {
|
if (response.status !== 200) {
|
||||||
@@ -86,13 +89,17 @@ export class HttpLancedbClient {
|
|||||||
{
|
{
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'x-api-key': this._apiKey()
|
'x-api-key': this._apiKey(),
|
||||||
|
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
||||||
},
|
},
|
||||||
params,
|
params,
|
||||||
timeout: 10000
|
timeout: 10000
|
||||||
}
|
}
|
||||||
).catch((err) => {
|
).catch((err) => {
|
||||||
console.error('error: ', err)
|
console.error('error: ', err)
|
||||||
|
if (err.response === undefined) {
|
||||||
|
throw new Error(`Network Error: ${err.message as string}`)
|
||||||
|
}
|
||||||
return err.response
|
return err.response
|
||||||
})
|
})
|
||||||
if (response.status !== 200) {
|
if (response.status !== 200) {
|
||||||
@@ -108,13 +115,18 @@ export class HttpLancedbClient {
|
|||||||
/**
|
/**
|
||||||
* Sent POST request.
|
* Sent POST request.
|
||||||
*/
|
*/
|
||||||
public async post (path: string, data?: any, params?: Record<string, string | number>): Promise<AxiosResponse> {
|
public async post (
|
||||||
|
path: string,
|
||||||
|
data?: any,
|
||||||
|
params?: Record<string, string | number>,
|
||||||
|
content?: string | undefined
|
||||||
|
): Promise<AxiosResponse> {
|
||||||
const response = await axios.post(
|
const response = await axios.post(
|
||||||
`${this._url}${path}`,
|
`${this._url}${path}`,
|
||||||
data,
|
data,
|
||||||
{
|
{
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': content ?? 'application/json',
|
||||||
'x-api-key': this._apiKey(),
|
'x-api-key': this._apiKey(),
|
||||||
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
||||||
},
|
},
|
||||||
@@ -123,6 +135,9 @@ export class HttpLancedbClient {
|
|||||||
}
|
}
|
||||||
).catch((err) => {
|
).catch((err) => {
|
||||||
console.error('error: ', err)
|
console.error('error: ', err)
|
||||||
|
if (err.response === undefined) {
|
||||||
|
throw new Error(`Network Error: ${err.message as string}`)
|
||||||
|
}
|
||||||
return err.response
|
return err.response
|
||||||
})
|
})
|
||||||
if (response.status !== 200) {
|
if (response.status !== 200) {
|
||||||
|
|||||||
@@ -14,12 +14,16 @@
|
|||||||
|
|
||||||
import {
|
import {
|
||||||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
||||||
type ConnectionOptions, type CreateTableOptions, type WriteOptions
|
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
|
||||||
|
type WriteOptions,
|
||||||
|
type IndexStats
|
||||||
} from '../index'
|
} from '../index'
|
||||||
import { Query } from '../query'
|
import { Query } from '../query'
|
||||||
|
|
||||||
import { Vector } from 'apache-arrow'
|
import { Vector, Table as ArrowTable } from 'apache-arrow'
|
||||||
import { HttpLancedbClient } from './client'
|
import { HttpLancedbClient } from './client'
|
||||||
|
import { isEmbeddingFunction } from '../embedding/embedding_function'
|
||||||
|
import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Remote connection.
|
* Remote connection.
|
||||||
@@ -66,8 +70,60 @@ export class RemoteConnection implements Connection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async createTable<T> (name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
|
async createTable<T> (nameOrOpts: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
|
||||||
throw new Error('Not implemented')
|
// Logic copied from LocatlConnection, refactor these to a base class + connectionImpl pattern
|
||||||
|
let schema
|
||||||
|
let embeddings: undefined | EmbeddingFunction<T>
|
||||||
|
let tableName: string
|
||||||
|
if (typeof nameOrOpts === 'string') {
|
||||||
|
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) {
|
||||||
|
embeddings = optsOrEmbedding
|
||||||
|
}
|
||||||
|
tableName = nameOrOpts
|
||||||
|
} else {
|
||||||
|
schema = nameOrOpts.schema
|
||||||
|
embeddings = nameOrOpts.embeddingFunction
|
||||||
|
tableName = nameOrOpts.name
|
||||||
|
}
|
||||||
|
|
||||||
|
let buffer: Buffer
|
||||||
|
|
||||||
|
function isEmpty (data: Array<Record<string, unknown>> | ArrowTable<any>): boolean {
|
||||||
|
if (data instanceof ArrowTable) {
|
||||||
|
return data.data.length === 0
|
||||||
|
}
|
||||||
|
return data.length === 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((data === undefined) || isEmpty(data)) {
|
||||||
|
if (schema === undefined) {
|
||||||
|
throw new Error('Either data or schema needs to defined')
|
||||||
|
}
|
||||||
|
buffer = await fromTableToStreamBuffer(createEmptyTable(schema))
|
||||||
|
} else if (data instanceof ArrowTable) {
|
||||||
|
buffer = await fromTableToStreamBuffer(data, embeddings)
|
||||||
|
} else {
|
||||||
|
// data is Array<Record<...>>
|
||||||
|
buffer = await fromRecordsToStreamBuffer(data, embeddings)
|
||||||
|
}
|
||||||
|
|
||||||
|
const res = await this._client.post(
|
||||||
|
`/v1/table/${tableName}/create/`,
|
||||||
|
buffer,
|
||||||
|
undefined,
|
||||||
|
'application/vnd.apache.arrow.stream'
|
||||||
|
)
|
||||||
|
if (res.status !== 200) {
|
||||||
|
throw new Error(`Server Error, status: ${res.status}, ` +
|
||||||
|
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||||
|
`message: ${res.statusText}: ${res.data}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embeddings === undefined) {
|
||||||
|
return new RemoteTable(this._client, tableName)
|
||||||
|
} else {
|
||||||
|
return new RemoteTable(this._client, tableName, embeddings)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async dropTable (name: string): Promise<void> {
|
async dropTable (name: string): Promise<void> {
|
||||||
@@ -141,11 +197,39 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async add (data: Array<Record<string, unknown>>): Promise<number> {
|
async add (data: Array<Record<string, unknown>>): Promise<number> {
|
||||||
throw new Error('Not implemented')
|
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
|
||||||
|
const res = await this._client.post(
|
||||||
|
`/v1/table/${this._name}/insert/`,
|
||||||
|
buffer,
|
||||||
|
{
|
||||||
|
mode: 'append'
|
||||||
|
},
|
||||||
|
'application/vnd.apache.arrow.stream'
|
||||||
|
)
|
||||||
|
if (res.status !== 200) {
|
||||||
|
throw new Error(`Server Error, status: ${res.status}, ` +
|
||||||
|
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||||
|
`message: ${res.statusText}: ${res.data}`)
|
||||||
|
}
|
||||||
|
return data.length
|
||||||
}
|
}
|
||||||
|
|
||||||
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
|
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
|
||||||
throw new Error('Not implemented')
|
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
|
||||||
|
const res = await this._client.post(
|
||||||
|
`/v1/table/${this._name}/insert/`,
|
||||||
|
buffer,
|
||||||
|
{
|
||||||
|
mode: 'overwrite'
|
||||||
|
},
|
||||||
|
'application/vnd.apache.arrow.stream'
|
||||||
|
)
|
||||||
|
if (res.status !== 200) {
|
||||||
|
throw new Error(`Server Error, status: ${res.status}, ` +
|
||||||
|
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||||
|
`message: ${res.statusText}: ${res.data}`)
|
||||||
|
}
|
||||||
|
return data.length
|
||||||
}
|
}
|
||||||
|
|
||||||
async createIndex (indexParams: VectorIndexParams): Promise<any> {
|
async createIndex (indexParams: VectorIndexParams): Promise<any> {
|
||||||
@@ -153,10 +237,28 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async countRows (): Promise<number> {
|
async countRows (): Promise<number> {
|
||||||
throw new Error('Not implemented')
|
const result = await this._client.post(`/v1/table/${this._name}/describe/`)
|
||||||
|
return result.data?.stats?.num_rows
|
||||||
}
|
}
|
||||||
|
|
||||||
async delete (filter: string): Promise<void> {
|
async delete (filter: string): Promise<void> {
|
||||||
throw new Error('Not implemented')
|
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
|
||||||
|
}
|
||||||
|
|
||||||
|
async listIndices (): Promise<VectorIndex[]> {
|
||||||
|
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
|
||||||
|
return results.data.indexes?.map((index: any) => ({
|
||||||
|
columns: index.columns,
|
||||||
|
name: index.index_name,
|
||||||
|
uuid: index.index_uuid
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async indexStats (indexUuid: string): Promise<IndexStats> {
|
||||||
|
const results = await this._client.post(`/v1/table/${this._name}/index/${indexUuid}/stats/`)
|
||||||
|
return {
|
||||||
|
numIndexedRows: results.data.num_indexed_rows,
|
||||||
|
numUnindexedRows: results.data.num_unindexed_rows
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import * as chai from 'chai'
|
|||||||
import * as chaiAsPromised from 'chai-as-promised'
|
import * as chaiAsPromised from 'chai-as-promised'
|
||||||
|
|
||||||
import * as lancedb from '../index'
|
import * as lancedb from '../index'
|
||||||
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index'
|
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions, type LocalTable } from '../index'
|
||||||
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
|
import { FixedSizeList, Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray, Float32 } from 'apache-arrow'
|
||||||
|
|
||||||
const expect = chai.expect
|
const expect = chai.expect
|
||||||
@@ -282,7 +282,8 @@ describe('LanceDB client', function () {
|
|||||||
)
|
)
|
||||||
const table = await con.createTable({ name: 'vectors', schema })
|
const table = await con.createTable({ name: 'vectors', schema })
|
||||||
await table.add([{ vector: Array(128).fill(0.1) }])
|
await table.add([{ vector: Array(128).fill(0.1) }])
|
||||||
await table.delete('vector IS NOT NULL')
|
// https://github.com/lancedb/lance/issues/1635
|
||||||
|
await table.delete('true')
|
||||||
const result = await table.search(Array(128).fill(0.1)).execute()
|
const result = await table.search(Array(128).fill(0.1)).execute()
|
||||||
assert.isEmpty(result)
|
assert.isEmpty(result)
|
||||||
})
|
})
|
||||||
@@ -328,6 +329,24 @@ describe('LanceDB client', function () {
|
|||||||
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
|
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
|
||||||
await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
|
await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('should be able to list index and stats', async function () {
|
||||||
|
const uri = await createTestDB(32, 300)
|
||||||
|
const con = await lancedb.connect(uri)
|
||||||
|
const table = await con.openTable('vectors')
|
||||||
|
await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
|
||||||
|
|
||||||
|
const indices = await table.listIndices()
|
||||||
|
expect(indices).to.have.lengthOf(1)
|
||||||
|
expect(indices[0].name).to.equal('vector_idx')
|
||||||
|
expect(indices[0].uuid).to.not.be.equal(undefined)
|
||||||
|
expect(indices[0].columns).to.have.lengthOf(1)
|
||||||
|
expect(indices[0].columns[0]).to.equal('vector')
|
||||||
|
|
||||||
|
const stats = await table.indexStats(indices[0].uuid)
|
||||||
|
expect(stats.numIndexedRows).to.equal(300)
|
||||||
|
expect(stats.numUnindexedRows).to.equal(0)
|
||||||
|
}).timeout(50_000)
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('when using a custom embedding function', function () {
|
describe('when using a custom embedding function', function () {
|
||||||
@@ -378,6 +397,40 @@ describe('LanceDB client', function () {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('Remote LanceDB client', function () {
|
||||||
|
describe('when the server is not reachable', function () {
|
||||||
|
it('produces a network error', async function () {
|
||||||
|
const con = await lancedb.connect({
|
||||||
|
uri: 'db://test-1234',
|
||||||
|
region: 'asdfasfasfdf',
|
||||||
|
apiKey: 'some-api-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
// GET
|
||||||
|
try {
|
||||||
|
await con.tableNames()
|
||||||
|
} catch (err) {
|
||||||
|
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||||
|
}
|
||||||
|
|
||||||
|
// POST
|
||||||
|
try {
|
||||||
|
await con.createTable({ name: 'vectors', schema: new Schema([]) })
|
||||||
|
} catch (err) {
|
||||||
|
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search
|
||||||
|
const table = await con.openTable('vectors')
|
||||||
|
try {
|
||||||
|
await table.search([0.1, 0.3]).execute()
|
||||||
|
} catch (err) {
|
||||||
|
expect(err).to.have.property('message', 'Network Error: getaddrinfo ENOTFOUND test-1234.asdfasfasfdf.api.lancedb.com')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
describe('Query object', function () {
|
describe('Query object', function () {
|
||||||
it('sets custom parameters', async function () {
|
it('sets custom parameters', async function () {
|
||||||
const query = new Query([0.1, 0.3])
|
const query = new Query([0.1, 0.3])
|
||||||
@@ -446,3 +499,45 @@ describe('WriteOptions', function () {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('Compact and cleanup', function () {
|
||||||
|
it('can cleanup after compaction', async function () {
|
||||||
|
const dir = await track().mkdir('lancejs')
|
||||||
|
const con = await lancedb.connect(dir)
|
||||||
|
|
||||||
|
const data = [
|
||||||
|
{ price: 10, name: 'foo', vector: [1, 2, 3] },
|
||||||
|
{ price: 50, name: 'bar', vector: [4, 5, 6] }
|
||||||
|
]
|
||||||
|
const table = await con.createTable('t1', data) as LocalTable
|
||||||
|
|
||||||
|
const newData = [
|
||||||
|
{ price: 30, name: 'baz', vector: [7, 8, 9] }
|
||||||
|
]
|
||||||
|
await table.add(newData)
|
||||||
|
|
||||||
|
const compactionMetrics = await table.compactFiles({
|
||||||
|
numThreads: 2
|
||||||
|
})
|
||||||
|
assert.equal(compactionMetrics.fragmentsRemoved, 2)
|
||||||
|
assert.equal(compactionMetrics.fragmentsAdded, 1)
|
||||||
|
assert.equal(await table.countRows(), 3)
|
||||||
|
|
||||||
|
await table.cleanupOldVersions()
|
||||||
|
assert.equal(await table.countRows(), 3)
|
||||||
|
|
||||||
|
// should have no effect, but this validates the arguments are parsed.
|
||||||
|
await table.compactFiles({
|
||||||
|
targetRowsPerFragment: 1024 * 10,
|
||||||
|
maxRowsPerGroup: 1024,
|
||||||
|
materializeDeletions: true,
|
||||||
|
materializeDeletionsThreshold: 0.5,
|
||||||
|
numThreads: 2
|
||||||
|
})
|
||||||
|
|
||||||
|
const cleanupMetrics = await table.cleanupOldVersions(0, true)
|
||||||
|
assert.isAtLeast(cleanupMetrics.bytesRemoved, 1)
|
||||||
|
assert.isAtLeast(cleanupMetrics.oldVersions, 1)
|
||||||
|
assert.equal(await table.countRows(), 3)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.3.0
|
current_version = 0.3.4
|
||||||
commit = True
|
commit = True
|
||||||
message = [python] Bump version: {current_version} → {new_version}
|
message = [python] Bump version: {current_version} → {new_version}
|
||||||
tag = True
|
tag = True
|
||||||
|
|||||||
1
python/LICENSE
Symbolic link
1
python/LICENSE
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../LICENSE
|
||||||
@@ -16,7 +16,7 @@ pip install lancedb
|
|||||||
import lancedb
|
import lancedb
|
||||||
db = lancedb.connect('<PATH_TO_LANCEDB_DATASET>')
|
db = lancedb.connect('<PATH_TO_LANCEDB_DATASET>')
|
||||||
table = db.open_table('my_table')
|
table = db.open_table('my_table')
|
||||||
results = table.search([0.1, 0.3]).limit(20).to_df()
|
results = table.search([0.1, 0.3]).limit(20).to_list()
|
||||||
print(results)
|
print(results)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -14,12 +14,14 @@
|
|||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .db import URI, DBConnection, LanceDBConnection
|
|
||||||
from .remote.db import RemoteDBConnection
|
|
||||||
from .schema import vector
|
|
||||||
|
|
||||||
__version__ = importlib.metadata.version("lancedb")
|
__version__ = importlib.metadata.version("lancedb")
|
||||||
|
|
||||||
|
from .common import URI
|
||||||
|
from .db import DBConnection, LanceDBConnection
|
||||||
|
from .remote.db import RemoteDBConnection
|
||||||
|
from .schema import vector # noqa: F401
|
||||||
|
from .utils import sentry_log # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def connect(
|
def connect(
|
||||||
uri: URI,
|
uri: URI,
|
||||||
|
|||||||
12
python/lancedb/cli/__init__.py
Normal file
12
python/lancedb/cli/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
46
python/lancedb/cli/cli.py
Normal file
46
python/lancedb/cli/cli.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from lancedb.utils import CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
@click.version_option(help="LanceDB command line interface entry point")
|
||||||
|
def cli():
|
||||||
|
"LanceDB command line interface"
|
||||||
|
|
||||||
|
|
||||||
|
diagnostics_help = """
|
||||||
|
Enable or disable LanceDB diagnostics. When enabled, LanceDB will send anonymous events to help us improve LanceDB.
|
||||||
|
These diagnostics are used only for error reporting and no data is collected. You can find more about diagnosis on
|
||||||
|
our docs: https://lancedb.github.io/lancedb/cli_config/
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command(help=diagnostics_help)
|
||||||
|
@click.option("--enabled/--disabled", default=True)
|
||||||
|
def diagnostics(enabled):
|
||||||
|
CONFIG.update({"diagnostics": True if enabled else False})
|
||||||
|
click.echo("LanceDB diagnostics is %s" % ("enabled" if enabled else "disabled"))
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command(help="Show current LanceDB configuration")
|
||||||
|
def config():
|
||||||
|
# TODO: pretty print as table with colors and formatting
|
||||||
|
click.echo("Current LanceDB configuration:")
|
||||||
|
cfg = CONFIG.copy()
|
||||||
|
cfg.pop("uuid") # Don't show uuid as it is not configurable
|
||||||
|
for item, amount in cfg.items():
|
||||||
|
click.echo("{} ({})".format(item, amount))
|
||||||
@@ -1,4 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -38,3 +40,26 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
|||||||
|
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
return 10
|
return 10
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitedAPI:
|
||||||
|
rate_limit = 0.1 # 1 request per 0.1 second
|
||||||
|
last_request_time = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_request():
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
if current_time - RateLimitedAPI.last_request_time < RateLimitedAPI.rate_limit:
|
||||||
|
raise Exception("Rate limit exceeded. Please try again later.")
|
||||||
|
|
||||||
|
# Simulate a successful request
|
||||||
|
RateLimitedAPI.last_request_time = current_time
|
||||||
|
return "Request successful"
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register("test-rate-limited")
|
||||||
|
class MockRateLimitedEmbeddingFunction(MockTextEmbeddingFunction):
|
||||||
|
def generate_embeddings(self, texts):
|
||||||
|
RateLimitedAPI.make_request()
|
||||||
|
return [self._compute_one_embedding(row) for row in texts]
|
||||||
|
|||||||
@@ -12,6 +12,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import deprecation
|
||||||
|
|
||||||
|
from . import __version__
|
||||||
from .exceptions import MissingColumnError, MissingValueError
|
from .exceptions import MissingColumnError, MissingValueError
|
||||||
from .util import safe_import_pandas
|
from .util import safe_import_pandas
|
||||||
|
|
||||||
@@ -43,7 +46,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
|||||||
this how many tokens, but depending on the input data, it could be sentences,
|
this how many tokens, but depending on the input data, it could be sentences,
|
||||||
paragraphs, messages, etc.
|
paragraphs, messages, etc.
|
||||||
|
|
||||||
>>> contextualize(data).window(3).stride(1).text_col('token').to_df()
|
>>> contextualize(data).window(3).stride(1).text_col('token').to_pandas()
|
||||||
token document_id
|
token document_id
|
||||||
0 The quick brown 1
|
0 The quick brown 1
|
||||||
1 quick brown fox 1
|
1 quick brown fox 1
|
||||||
@@ -56,7 +59,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
|||||||
8 dog I love 1
|
8 dog I love 1
|
||||||
9 I love sandwiches 2
|
9 I love sandwiches 2
|
||||||
10 love sandwiches 2
|
10 love sandwiches 2
|
||||||
>>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_df()
|
>>> contextualize(data).window(7).stride(1).min_window_size(7).text_col('token').to_pandas()
|
||||||
token document_id
|
token document_id
|
||||||
0 The quick brown fox jumped over the 1
|
0 The quick brown fox jumped over the 1
|
||||||
1 quick brown fox jumped over the lazy 1
|
1 quick brown fox jumped over the lazy 1
|
||||||
@@ -68,7 +71,7 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
|||||||
``stride`` determines how many rows to skip between each window start. This can
|
``stride`` determines how many rows to skip between each window start. This can
|
||||||
be used to reduce the total number of windows generated.
|
be used to reduce the total number of windows generated.
|
||||||
|
|
||||||
>>> contextualize(data).window(4).stride(2).text_col('token').to_df()
|
>>> contextualize(data).window(4).stride(2).text_col('token').to_pandas()
|
||||||
token document_id
|
token document_id
|
||||||
0 The quick brown fox 1
|
0 The quick brown fox 1
|
||||||
2 brown fox jumped over 1
|
2 brown fox jumped over 1
|
||||||
@@ -81,7 +84,9 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
|||||||
context windows that don't cross document boundaries. In this case, we can
|
context windows that don't cross document boundaries. In this case, we can
|
||||||
pass ``document_id`` as the group by.
|
pass ``document_id`` as the group by.
|
||||||
|
|
||||||
>>> contextualize(data).window(4).stride(2).text_col('token').groupby('document_id').to_df()
|
>>> (contextualize(data)
|
||||||
|
... .window(4).stride(2).text_col('token').groupby('document_id')
|
||||||
|
... .to_pandas())
|
||||||
token document_id
|
token document_id
|
||||||
0 The quick brown fox 1
|
0 The quick brown fox 1
|
||||||
2 brown fox jumped over 1
|
2 brown fox jumped over 1
|
||||||
@@ -89,18 +94,24 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
|||||||
6 the lazy dog 1
|
6 the lazy dog 1
|
||||||
9 I love sandwiches 2
|
9 I love sandwiches 2
|
||||||
|
|
||||||
``min_window_size`` determines the minimum size of the context windows that are generated
|
``min_window_size`` determines the minimum size of the context windows
|
||||||
This can be used to trim the last few context windows which have size less than
|
that are generated.This can be used to trim the last few context windows
|
||||||
``min_window_size``. By default context windows of size 1 are skipped.
|
which have size less than ``min_window_size``.
|
||||||
|
By default context windows of size 1 are skipped.
|
||||||
|
|
||||||
>>> contextualize(data).window(6).stride(3).text_col('token').groupby('document_id').to_df()
|
>>> (contextualize(data)
|
||||||
|
... .window(6).stride(3).text_col('token').groupby('document_id')
|
||||||
|
... .to_pandas())
|
||||||
token document_id
|
token document_id
|
||||||
0 The quick brown fox jumped over 1
|
0 The quick brown fox jumped over 1
|
||||||
3 fox jumped over the lazy dog 1
|
3 fox jumped over the lazy dog 1
|
||||||
6 the lazy dog 1
|
6 the lazy dog 1
|
||||||
9 I love sandwiches 2
|
9 I love sandwiches 2
|
||||||
|
|
||||||
>>> contextualize(data).window(6).stride(3).min_window_size(4).text_col('token').groupby('document_id').to_df()
|
>>> (contextualize(data)
|
||||||
|
... .window(6).stride(3).min_window_size(4).text_col('token')
|
||||||
|
... .groupby('document_id')
|
||||||
|
... .to_pandas())
|
||||||
token document_id
|
token document_id
|
||||||
0 The quick brown fox jumped over 1
|
0 The quick brown fox jumped over 1
|
||||||
3 fox jumped over the lazy dog 1
|
3 fox jumped over the lazy dog 1
|
||||||
@@ -110,7 +121,9 @@ def contextualize(raw_df: "pd.DataFrame") -> Contextualizer:
|
|||||||
|
|
||||||
|
|
||||||
class Contextualizer:
|
class Contextualizer:
|
||||||
"""Create context windows from a DataFrame. See [lancedb.context.contextualize][]."""
|
"""Create context windows from a DataFrame.
|
||||||
|
See [lancedb.context.contextualize][].
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, raw_df):
|
def __init__(self, raw_df):
|
||||||
self._text_col = None
|
self._text_col = None
|
||||||
@@ -176,7 +189,16 @@ class Contextualizer:
|
|||||||
self._min_window_size = min_window_size
|
self._min_window_size = min_window_size
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@deprecation.deprecated(
|
||||||
|
deprecated_in="0.3.1",
|
||||||
|
removed_in="0.4.0",
|
||||||
|
current_version=__version__,
|
||||||
|
details="Use to_pandas() instead",
|
||||||
|
)
|
||||||
def to_df(self) -> "pd.DataFrame":
|
def to_df(self) -> "pd.DataFrame":
|
||||||
|
return self.to_pandas()
|
||||||
|
|
||||||
|
def to_pandas(self) -> "pd.DataFrame":
|
||||||
"""Create the context windows and return a DataFrame."""
|
"""Create the context windows and return a DataFrame."""
|
||||||
if pd is None:
|
if pd is None:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
|||||||
@@ -14,26 +14,39 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Union
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
from overrides import EnforceOverrides, override
|
||||||
from pyarrow import fs
|
from pyarrow import fs
|
||||||
|
|
||||||
from .common import DATA, URI
|
|
||||||
from .embeddings import EmbeddingFunctionConfig
|
|
||||||
from .pydantic import LanceModel
|
|
||||||
from .table import LanceTable, Table
|
from .table import LanceTable, Table
|
||||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
from .util import fs_from_uri, get_uri_location, get_uri_scheme
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .common import DATA, URI
|
||||||
|
from .embeddings import EmbeddingFunctionConfig
|
||||||
|
from .pydantic import LanceModel
|
||||||
|
|
||||||
class DBConnection(ABC):
|
|
||||||
|
class DBConnection(EnforceOverrides):
|
||||||
"""An active LanceDB connection interface."""
|
"""An active LanceDB connection interface."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def table_names(self) -> list[str]:
|
def table_names(
|
||||||
"""List all table names in the database."""
|
self, page_token: Optional[str] = None, limit: int = 10
|
||||||
|
) -> Iterable[str]:
|
||||||
|
"""List all table in this database
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
page_token: str, optional
|
||||||
|
The token to use for pagination. If not present, start from the beginning.
|
||||||
|
limit: int, default 10
|
||||||
|
The size of the page to return.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -45,6 +58,7 @@ class DBConnection(ABC):
|
|||||||
mode: str = "create",
|
mode: str = "create",
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
|
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||||
) -> Table:
|
) -> Table:
|
||||||
"""Create a [Table][lancedb.table.Table] in the database.
|
"""Create a [Table][lancedb.table.Table] in the database.
|
||||||
|
|
||||||
@@ -52,12 +66,24 @@ class DBConnection(ABC):
|
|||||||
----------
|
----------
|
||||||
name: str
|
name: str
|
||||||
The name of the table.
|
The name of the table.
|
||||||
data: list, tuple, dict, pd.DataFrame; optional
|
data: The data to initialize the table, *optional*
|
||||||
The data to initialize the table. User must provide at least one of `data` or `schema`.
|
User must provide at least one of `data` or `schema`.
|
||||||
schema: pyarrow.Schema or LanceModel; optional
|
Acceptable types are:
|
||||||
The schema of the table.
|
|
||||||
|
- dict or list-of-dict
|
||||||
|
|
||||||
|
- pandas.DataFrame
|
||||||
|
|
||||||
|
- pyarrow.Table or pyarrow.RecordBatch
|
||||||
|
schema: The schema of the table, *optional*
|
||||||
|
Acceptable types are:
|
||||||
|
|
||||||
|
- pyarrow.Schema
|
||||||
|
|
||||||
|
- [LanceModel][lancedb.pydantic.LanceModel]
|
||||||
mode: str; default "create"
|
mode: str; default "create"
|
||||||
The mode to use when creating the table. Can be either "create" or "overwrite".
|
The mode to use when creating the table.
|
||||||
|
Can be either "create" or "overwrite".
|
||||||
By default, if the table already exists, an exception is raised.
|
By default, if the table already exists, an exception is raised.
|
||||||
If you want to overwrite the table, use mode="overwrite".
|
If you want to overwrite the table, use mode="overwrite".
|
||||||
on_bad_vectors: str, default "error"
|
on_bad_vectors: str, default "error"
|
||||||
@@ -150,7 +176,8 @@ class DBConnection(ABC):
|
|||||||
... for i in range(5):
|
... for i in range(5):
|
||||||
... yield pa.RecordBatch.from_arrays(
|
... yield pa.RecordBatch.from_arrays(
|
||||||
... [
|
... [
|
||||||
... pa.array([[3.1, 4.1], [5.9, 26.5]], pa.list_(pa.float32(), 2)),
|
... pa.array([[3.1, 4.1], [5.9, 26.5]],
|
||||||
|
... pa.list_(pa.float32(), 2)),
|
||||||
... pa.array(["foo", "bar"]),
|
... pa.array(["foo", "bar"]),
|
||||||
... pa.array([10.0, 20.0]),
|
... pa.array([10.0, 20.0]),
|
||||||
... ],
|
... ],
|
||||||
@@ -249,12 +276,15 @@ class LanceDBConnection(DBConnection):
|
|||||||
def uri(self) -> str:
|
def uri(self) -> str:
|
||||||
return self._uri
|
return self._uri
|
||||||
|
|
||||||
def table_names(self) -> list[str]:
|
@override
|
||||||
"""Get the names of all tables in the database.
|
def table_names(
|
||||||
|
self, page_token: Optional[str] = None, limit: int = 10
|
||||||
|
) -> Iterable[str]:
|
||||||
|
"""Get the names of all tables in the database. The names are sorted.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
list of str
|
Iterator of str.
|
||||||
A list of table names.
|
A list of table names.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -274,6 +304,7 @@ class LanceDBConnection(DBConnection):
|
|||||||
for file_info in paths
|
for file_info in paths
|
||||||
if file_info.extension == "lance"
|
if file_info.extension == "lance"
|
||||||
]
|
]
|
||||||
|
tables.sort()
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@@ -282,6 +313,7 @@ class LanceDBConnection(DBConnection):
|
|||||||
def __contains__(self, name: str) -> bool:
|
def __contains__(self, name: str) -> bool:
|
||||||
return name in self.table_names()
|
return name in self.table_names()
|
||||||
|
|
||||||
|
@override
|
||||||
def create_table(
|
def create_table(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
@@ -313,6 +345,7 @@ class LanceDBConnection(DBConnection):
|
|||||||
)
|
)
|
||||||
return tbl
|
return tbl
|
||||||
|
|
||||||
|
@override
|
||||||
def open_table(self, name: str) -> LanceTable:
|
def open_table(self, name: str) -> LanceTable:
|
||||||
"""Open a table in the database.
|
"""Open a table in the database.
|
||||||
|
|
||||||
@@ -327,6 +360,7 @@ class LanceDBConnection(DBConnection):
|
|||||||
"""
|
"""
|
||||||
return LanceTable.open(self, name)
|
return LanceTable.open(self, name)
|
||||||
|
|
||||||
|
@override
|
||||||
def drop_table(self, name: str, ignore_missing: bool = False):
|
def drop_table(self, name: str, ignore_missing: bool = False):
|
||||||
"""Drop a table from the database.
|
"""Drop a table from the database.
|
||||||
|
|
||||||
@@ -345,6 +379,7 @@ class LanceDBConnection(DBConnection):
|
|||||||
if not ignore_missing:
|
if not ignore_missing:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@override
|
||||||
def drop_database(self):
|
def drop_database(self):
|
||||||
filesystem, path = fs_from_uri(self.uri)
|
filesystem, path = fs_from_uri(self.uri)
|
||||||
filesystem.delete_dir(path)
|
filesystem.delete_dir(path)
|
||||||
|
|||||||
@@ -11,14 +11,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
# ruff: noqa: F401
|
||||||
from .functions import (
|
from .base import EmbeddingFunction, EmbeddingFunctionConfig, TextEmbeddingFunction
|
||||||
EmbeddingFunction,
|
from .cohere import CohereEmbeddingFunction
|
||||||
EmbeddingFunctionConfig,
|
from .instructor import InstructorEmbeddingFunction
|
||||||
EmbeddingFunctionRegistry,
|
from .open_clip import OpenClipEmbeddings
|
||||||
OpenAIEmbeddings,
|
from .openai import OpenAIEmbeddings
|
||||||
OpenClipEmbeddings,
|
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||||
SentenceTransformerEmbeddings,
|
from .sentence_transformers import SentenceTransformerEmbeddings
|
||||||
TextEmbeddingFunction,
|
|
||||||
)
|
|
||||||
from .utils import with_embeddings
|
from .utils import with_embeddings
|
||||||
|
|||||||
181
python/lancedb/embeddings/base.py
Normal file
181
python/lancedb/embeddings/base.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import importlib
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
|
from .utils import TEXT, retry_with_exponential_backoff
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingFunction(BaseModel, ABC):
|
||||||
|
"""
|
||||||
|
An ABC for embedding functions.
|
||||||
|
|
||||||
|
All concrete embedding functions must implement the following:
|
||||||
|
1. compute_query_embeddings() which takes a query and returns a list of embeddings
|
||||||
|
2. get_source_embeddings() which returns a list of embeddings for the source column
|
||||||
|
For text data, the two will be the same. For multi-modal data, the source column
|
||||||
|
might be images and the vector column might be text.
|
||||||
|
3. ndims method which returns the number of dimensions of the vector column
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("__weakref__",) # pydantic 1.x compatibility
|
||||||
|
max_retries: int = (
|
||||||
|
7 # Setitng 0 disables retires. Maybe this should not be enabled by default,
|
||||||
|
)
|
||||||
|
_ndims: int = PrivateAttr()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, **kwargs):
|
||||||
|
"""
|
||||||
|
Create an instance of the embedding function
|
||||||
|
"""
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for the source column in the database
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query with retries
|
||||||
|
"""
|
||||||
|
return retry_with_exponential_backoff(
|
||||||
|
self.compute_query_embeddings, max_retries=self.max_retries
|
||||||
|
)(
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for the source column in the database with retries
|
||||||
|
"""
|
||||||
|
return retry_with_exponential_backoff(
|
||||||
|
self.compute_source_embeddings, max_retries=self.max_retries
|
||||||
|
)(*args, **kwargs)
|
||||||
|
|
||||||
|
def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
|
||||||
|
"""
|
||||||
|
Sanitize the input to the embedding function.
|
||||||
|
"""
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
elif isinstance(texts, pa.Array):
|
||||||
|
texts = texts.to_pylist()
|
||||||
|
elif isinstance(texts, pa.ChunkedArray):
|
||||||
|
texts = texts.combine_chunks().to_pylist()
|
||||||
|
return texts
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def safe_import(cls, module: str, mitigation=None):
|
||||||
|
"""
|
||||||
|
Import the specified module. If the module is not installed,
|
||||||
|
raise an ImportError with a helpful message.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
module : str
|
||||||
|
The name of the module to import
|
||||||
|
mitigation : Optional[str]
|
||||||
|
The package(s) to install to mitigate the error.
|
||||||
|
If not provided then the module name will be used.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return importlib.import_module(module)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(f"Please install {mitigation or module}")
|
||||||
|
|
||||||
|
def safe_model_dump(self):
|
||||||
|
from ..pydantic import PYDANTIC_VERSION
|
||||||
|
|
||||||
|
if PYDANTIC_VERSION.major < 2:
|
||||||
|
return dict(self)
|
||||||
|
return self.model_dump()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def ndims(self):
|
||||||
|
"""
|
||||||
|
Return the dimensions of the vector column
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def SourceField(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Creates a pydantic Field that can automatically annotate
|
||||||
|
the source column for this embedding function
|
||||||
|
"""
|
||||||
|
return Field(json_schema_extra={"source_column_for": self}, **kwargs)
|
||||||
|
|
||||||
|
def VectorField(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Creates a pydantic Field that can automatically annotate
|
||||||
|
the target vector column for this embedding function
|
||||||
|
"""
|
||||||
|
return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
||||||
|
|
||||||
|
def __eq__(self, __value: object) -> bool:
|
||||||
|
if not hasattr(__value, "__dict__"):
|
||||||
|
return False
|
||||||
|
return vars(self) == vars(__value)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(frozenset(vars(self).items()))
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingFunctionConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
This model encapsulates the configuration for a embedding function
|
||||||
|
in a lancedb table. It holds the embedding function, the source column,
|
||||||
|
and the vector column
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_column: str
|
||||||
|
source_column: str
|
||||||
|
function: EmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbeddingFunction(EmbeddingFunction):
|
||||||
|
"""
|
||||||
|
A callable ABC for embedding functions that take text as input
|
||||||
|
"""
|
||||||
|
|
||||||
|
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||||
|
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||||
|
|
||||||
|
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||||
|
texts = self.sanitize_input(texts)
|
||||||
|
return self.generate_embeddings(texts)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Generate the embeddings for the given texts
|
||||||
|
"""
|
||||||
|
pass
|
||||||
91
python/lancedb/embeddings/cohere.py
Normal file
91
python/lancedb/embeddings/cohere.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import ClassVar, List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import api_key_not_found_help
|
||||||
|
|
||||||
|
|
||||||
|
@register("cohere")
|
||||||
|
class CohereEmbeddingFunction(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the Cohere API
|
||||||
|
|
||||||
|
https://docs.cohere.com/docs/multilingual-language-models
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name: str, default "embed-multilingual-v2.0"
|
||||||
|
The name of the model to use. See the Cohere documentation for
|
||||||
|
a list of available models.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
import lancedb
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||||
|
|
||||||
|
cohere = EmbeddingFunctionRegistry
|
||||||
|
.get_instance()
|
||||||
|
.get("cohere")
|
||||||
|
.create(name="embed-multilingual-v2.0")
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = cohere.SourceField()
|
||||||
|
vector: Vector(cohere.ndims()) = cohere.VectorField()
|
||||||
|
|
||||||
|
data = [ { "text": "hello world" },
|
||||||
|
{ "text": "goodbye world" }]
|
||||||
|
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(data)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "embed-multilingual-v2.0"
|
||||||
|
client: ClassVar = None
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
# TODO: fix hardcoding
|
||||||
|
return 768
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: list[str] or np.ndarray (of str)
|
||||||
|
The texts to embed
|
||||||
|
"""
|
||||||
|
# TODO retry, rate limit, token limit
|
||||||
|
self._init_client()
|
||||||
|
rs = CohereEmbeddingFunction.client.embed(texts=texts, model=self.name)
|
||||||
|
|
||||||
|
return [emb for emb in rs.embeddings]
|
||||||
|
|
||||||
|
def _init_client(self):
|
||||||
|
cohere = self.safe_import("cohere")
|
||||||
|
if CohereEmbeddingFunction.client is None:
|
||||||
|
if os.environ.get("COHERE_API_KEY") is None:
|
||||||
|
api_key_not_found_help("cohere")
|
||||||
|
CohereEmbeddingFunction.client = cohere.Client(os.environ["COHERE_API_KEY"])
|
||||||
@@ -1,578 +0,0 @@
|
|||||||
# Copyright (c) 2023. LanceDB Developers
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import concurrent.futures
|
|
||||||
import importlib
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
import urllib.error
|
|
||||||
import urllib.parse as urlparse
|
|
||||||
import urllib.request
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pyarrow as pa
|
|
||||||
from cachetools import cached
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunctionRegistry:
|
|
||||||
"""
|
|
||||||
This is a singleton class used to register embedding functions
|
|
||||||
and fetch them by name. It also handles serializing and deserializing.
|
|
||||||
You can implement your own embedding function by subclassing EmbeddingFunction
|
|
||||||
or TextEmbeddingFunction and registering it with the registry.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
|
||||||
>>> @registry.register("my-embedding-function")
|
|
||||||
... class MyEmbeddingFunction(EmbeddingFunction):
|
|
||||||
... def ndims(self) -> int:
|
|
||||||
... return 128
|
|
||||||
...
|
|
||||||
... def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
|
||||||
... return self.compute_source_embeddings(query, *args, **kwargs)
|
|
||||||
...
|
|
||||||
... def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
|
||||||
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
|
||||||
...
|
|
||||||
>>> registry.get("my-embedding-function")
|
|
||||||
<class 'lancedb.embeddings.functions.MyEmbeddingFunction'>
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls):
|
|
||||||
return __REGISTRY__
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._functions = {}
|
|
||||||
|
|
||||||
def register(self, alias: str = None):
|
|
||||||
"""
|
|
||||||
This creates a decorator that can be used to register
|
|
||||||
an EmbeddingFunction.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
alias : Optional[str]
|
|
||||||
a human friendly name for the embedding function. If not
|
|
||||||
provided, the class name will be used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# This is a decorator for a class that inherits from BaseModel
|
|
||||||
# It adds the class to the registry
|
|
||||||
def decorator(cls):
|
|
||||||
if not issubclass(cls, EmbeddingFunction):
|
|
||||||
raise TypeError("Must be a subclass of EmbeddingFunction")
|
|
||||||
if cls.__name__ in self._functions:
|
|
||||||
raise KeyError(f"{cls.__name__} was already registered")
|
|
||||||
key = alias or cls.__name__
|
|
||||||
self._functions[key] = cls
|
|
||||||
cls.__embedding_function_registry_alias__ = alias
|
|
||||||
return cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""
|
|
||||||
Reset the registry to its initial state
|
|
||||||
"""
|
|
||||||
self._functions = {}
|
|
||||||
|
|
||||||
def get(self, name: str):
|
|
||||||
"""
|
|
||||||
Fetch an embedding function class by name
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
The name of the embedding function to fetch
|
|
||||||
Either the alias or the class name if no alias was provided
|
|
||||||
during registration
|
|
||||||
"""
|
|
||||||
return self._functions[name]
|
|
||||||
|
|
||||||
def parse_functions(
|
|
||||||
self, metadata: Optional[Dict[bytes, bytes]]
|
|
||||||
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
|
||||||
"""
|
|
||||||
Parse the metadata from an arrow table and
|
|
||||||
return a mapping of the vector column to the
|
|
||||||
embedding function and source column
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
metadata : Optional[Dict[bytes, bytes]]
|
|
||||||
The metadata from an arrow table. Note that
|
|
||||||
the keys and values are bytes (pyarrow api)
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
functions : dict
|
|
||||||
A mapping of vector column name to embedding function.
|
|
||||||
An empty dict is returned if input is None or does not
|
|
||||||
contain b"embedding_functions".
|
|
||||||
"""
|
|
||||||
if metadata is None or b"embedding_functions" not in metadata:
|
|
||||||
return {}
|
|
||||||
serialized = metadata[b"embedding_functions"]
|
|
||||||
raw_list = json.loads(serialized.decode("utf-8"))
|
|
||||||
return {
|
|
||||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
|
||||||
vector_column=obj["vector_column"],
|
|
||||||
source_column=obj["source_column"],
|
|
||||||
function=self.get(obj["name"])(**obj["model"]),
|
|
||||||
)
|
|
||||||
for obj in raw_list
|
|
||||||
}
|
|
||||||
|
|
||||||
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
    """Serialize one embedding-function config into a JSON-friendly dict.

    The registry alias is preferred for the ``"name"`` field; the class
    name is used as a fallback when the function was registered without
    an alias.
    """
    func = conf.function
    alias = getattr(
        func, "__embedding_function_registry_alias__", func.__class__.__name__
    )
    return {
        "name": alias,
        "model": func.safe_model_dump(),
        "source_column": conf.source_column,
        "vector_column": conf.vector_column,
    }
|
|
||||||
|
|
||||||
def get_table_metadata(self, func_list):
    """Serialize a list of embedding-function configs into arrow metadata.

    Returns None for a missing or empty list. Otherwise returns a dict
    with a single key, ``"embedding_functions"``, mapped to UTF-8
    encoded JSON bytes — arrow metadata values must be bytes, so we
    json-dump and then encode.
    """
    if not func_list:
        return None
    entries = [self.function_to_metadata(conf) for conf in func_list]
    return {
        "embedding_functions": json.dumps(entries, indent=2).encode("utf-8")
    }
|
|
||||||
|
|
||||||
|
|
||||||
# Global instance
|
|
||||||
__REGISTRY__ = EmbeddingFunctionRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
|
||||||
IMAGES = Union[
|
|
||||||
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunction(BaseModel, ABC):
    """
    An ABC for embedding functions.

    All concrete embedding functions must implement the following:
    1. compute_query_embeddings() which takes a query and returns a list of embeddings
    2. compute_source_embeddings() which returns a list of embeddings for the source column
       For text data, the two will be the same. For multi-modal data, the source column
       might be images and the vector column might be text.
    3. ndims() which returns the number of dimensions of the vector column
    """

    # Lazily computed vector dimensionality. Subclasses seen in this file
    # initialize it to None and fill it on first use, so the ``int``
    # annotation is aspirational rather than enforced.
    _ndims: int = PrivateAttr()

    @classmethod
    def create(cls, **kwargs):
        """Create an instance of the embedding function."""
        return cls(**kwargs)

    @abstractmethod
    def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
        """Compute the embeddings for a given user query."""
        pass

    @abstractmethod
    def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
        """Compute the embeddings for the source column in the database."""
        pass

    def sanitize_input(self, texts: TEXT) -> Union[List[str], np.ndarray]:
        """
        Sanitize the input to the embedding function: coerce a single
        string or an arrow array into a plain python list (ndarrays and
        lists pass through untouched).
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, pa.Array):
            texts = texts.to_pylist()
        elif isinstance(texts, pa.ChunkedArray):
            texts = texts.combine_chunks().to_pylist()
        return texts

    @classmethod
    def safe_import(cls, module: str, mitigation=None):
        """
        Import the specified module. If the module is not installed,
        raise an ImportError with a helpful message.

        Parameters
        ----------
        module : str
            The name of the module to import
        mitigation : Optional[str]
            The package(s) to install to mitigate the error.
            If not provided then the module name will be used.
        """
        try:
            return importlib.import_module(module)
        except ImportError as e:
            # Chain the original exception so the real cause (e.g. a broken
            # transitive dependency rather than a missing package) is not masked.
            raise ImportError(f"Please install {mitigation or module}") from e

    def safe_model_dump(self):
        """Serialize the model's fields to a dict across pydantic v1 and v2."""
        from ..pydantic import PYDANTIC_VERSION

        if PYDANTIC_VERSION.major < 2:
            # pydantic v1 has no model_dump(); iterating yields field pairs.
            return dict(self)
        return self.model_dump()

    @abstractmethod
    def ndims(self):
        """Return the dimensions of the vector column."""
        pass

    def SourceField(self, **kwargs):
        """
        Creates a pydantic Field that can automatically annotate
        the source column for this embedding function
        """
        return Field(json_schema_extra={"source_column_for": self}, **kwargs)

    def VectorField(self, **kwargs):
        """
        Creates a pydantic Field that can automatically annotate
        the target vector column for this embedding function
        """
        return Field(json_schema_extra={"vector_column_for": self}, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingFunctionConfig(BaseModel):
    """
    Configuration tying an embedding function to the pair of columns it
    operates on within a lancedb table: the source column it reads from
    and the vector column it writes to.
    """

    vector_column: str
    source_column: str
    function: EmbeddingFunction
|
|
||||||
|
|
||||||
|
|
||||||
class TextEmbeddingFunction(EmbeddingFunction):
    """
    A callable ABC for embedding functions that take text as input
    """

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        # Queries are plain text, so they embed exactly like source rows.
        return self.compute_source_embeddings(query, *args, **kwargs)

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        return self.generate_embeddings(self.sanitize_input(texts))

    @abstractmethod
    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray]
    ) -> List[np.array]:
        """
        Generate the embeddings for the given texts
        """
        pass
|
|
||||||
|
|
||||||
|
|
||||||
def register(name):
    """Class decorator factory: register an embedding function under ``name``.

    Defined as a plain function rather than referencing
    ``EmbeddingFunctionRegistry.get_instance().register`` at class-decoration
    time, because ``@EmbeddingFunctionRegistry.get_instance().register(name)``
    doesn't work as decorator syntax on Python 3.8. A ``def`` (instead of the
    previous lambda) follows PEP 8 and carries this docstring.
    """
    return EmbeddingFunctionRegistry.get_instance().register(name)
|
|
||||||
|
|
||||||
|
|
||||||
@register("sentence-transformers")
|
|
||||||
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
|
||||||
"""
|
|
||||||
An embedding function that uses the sentence-transformers library
|
|
||||||
|
|
||||||
https://huggingface.co/sentence-transformers
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = "all-MiniLM-L6-v2"
|
|
||||||
device: str = "cpu"
|
|
||||||
normalize: bool = True
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._ndims = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embedding_model(self):
|
|
||||||
"""
|
|
||||||
Get the sentence-transformers embedding model specified by the
|
|
||||||
name and device. This is cached so that the model is only loaded
|
|
||||||
once per process.
|
|
||||||
"""
|
|
||||||
return self.__class__.get_embedding_model(self.name, self.device)
|
|
||||||
|
|
||||||
def ndims(self):
|
|
||||||
if self._ndims is None:
|
|
||||||
self._ndims = len(self.generate_embeddings("foo")[0])
|
|
||||||
return self._ndims
|
|
||||||
|
|
||||||
def generate_embeddings(
|
|
||||||
self, texts: Union[List[str], np.ndarray]
|
|
||||||
) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Get the embeddings for the given texts
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
texts: list[str] or np.ndarray (of str)
|
|
||||||
The texts to embed
|
|
||||||
"""
|
|
||||||
return self.embedding_model.encode(
|
|
||||||
list(texts),
|
|
||||||
convert_to_numpy=True,
|
|
||||||
normalize_embeddings=self.normalize,
|
|
||||||
).tolist()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@cached(cache={})
|
|
||||||
def get_embedding_model(cls, name, device):
|
|
||||||
"""
|
|
||||||
Get the sentence-transformers embedding model specified by the
|
|
||||||
name and device. This is cached so that the model is only loaded
|
|
||||||
once per process.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
The name of the model to load
|
|
||||||
device : str
|
|
||||||
The device to load the model on
|
|
||||||
|
|
||||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
|
||||||
"""
|
|
||||||
sentence_transformers = cls.safe_import(
|
|
||||||
"sentence_transformers", "sentence-transformers"
|
|
||||||
)
|
|
||||||
return sentence_transformers.SentenceTransformer(name, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
@register("openai")
|
|
||||||
class OpenAIEmbeddings(TextEmbeddingFunction):
|
|
||||||
"""
|
|
||||||
An embedding function that uses the OpenAI API
|
|
||||||
|
|
||||||
https://platform.openai.com/docs/guides/embeddings
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = "text-embedding-ada-002"
|
|
||||||
|
|
||||||
def ndims(self):
|
|
||||||
# TODO don't hardcode this
|
|
||||||
return 1536
|
|
||||||
|
|
||||||
def generate_embeddings(
|
|
||||||
self, texts: Union[List[str], np.ndarray]
|
|
||||||
) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Get the embeddings for the given texts
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
texts: list[str] or np.ndarray (of str)
|
|
||||||
The texts to embed
|
|
||||||
"""
|
|
||||||
# TODO retry, rate limit, token limit
|
|
||||||
openai = self.safe_import("openai")
|
|
||||||
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
|
||||||
return [v["embedding"] for v in rs]
|
|
||||||
|
|
||||||
|
|
||||||
@register("open-clip")
|
|
||||||
class OpenClipEmbeddings(EmbeddingFunction):
|
|
||||||
"""
|
|
||||||
An embedding function that uses the OpenClip API
|
|
||||||
For multi-modal text-to-image search
|
|
||||||
|
|
||||||
https://github.com/mlfoundations/open_clip
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = "ViT-B-32"
|
|
||||||
pretrained: str = "laion2b_s34b_b79k"
|
|
||||||
device: str = "cpu"
|
|
||||||
batch_size: int = 64
|
|
||||||
normalize: bool = True
|
|
||||||
_model = PrivateAttr()
|
|
||||||
_preprocess = PrivateAttr()
|
|
||||||
_tokenizer = PrivateAttr()
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
open_clip = self.safe_import("open_clip", "open-clip")
|
|
||||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
|
||||||
self.name, pretrained=self.pretrained
|
|
||||||
)
|
|
||||||
model.to(self.device)
|
|
||||||
self._model, self._preprocess = model, preprocess
|
|
||||||
self._tokenizer = open_clip.get_tokenizer(self.name)
|
|
||||||
self._ndims = None
|
|
||||||
|
|
||||||
def ndims(self):
|
|
||||||
if self._ndims is None:
|
|
||||||
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
|
||||||
return self._ndims
|
|
||||||
|
|
||||||
def compute_query_embeddings(
|
|
||||||
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
|
||||||
) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Compute the embeddings for a given user query
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
query : Union[str, PIL.Image.Image]
|
|
||||||
The query to embed. A query can be either text or an image.
|
|
||||||
"""
|
|
||||||
if isinstance(query, str):
|
|
||||||
return [self.generate_text_embeddings(query)]
|
|
||||||
else:
|
|
||||||
PIL = self.safe_import("PIL", "pillow")
|
|
||||||
if isinstance(query, PIL.Image.Image):
|
|
||||||
return [self.generate_image_embedding(query)]
|
|
||||||
else:
|
|
||||||
raise TypeError("OpenClip supports str or PIL Image as query")
|
|
||||||
|
|
||||||
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
|
||||||
torch = self.safe_import("torch")
|
|
||||||
text = self.sanitize_input(text)
|
|
||||||
text = self._tokenizer(text)
|
|
||||||
text.to(self.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
text_features = self._model.encode_text(text.to(self.device))
|
|
||||||
if self.normalize:
|
|
||||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
|
||||||
return text_features.cpu().numpy().squeeze()
|
|
||||||
|
|
||||||
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
|
||||||
"""
|
|
||||||
Sanitize the input to the embedding function.
|
|
||||||
"""
|
|
||||||
if isinstance(images, (str, bytes)):
|
|
||||||
images = [images]
|
|
||||||
elif isinstance(images, pa.Array):
|
|
||||||
images = images.to_pylist()
|
|
||||||
elif isinstance(images, pa.ChunkedArray):
|
|
||||||
images = images.combine_chunks().to_pylist()
|
|
||||||
return images
|
|
||||||
|
|
||||||
def compute_source_embeddings(
|
|
||||||
self, images: IMAGES, *args, **kwargs
|
|
||||||
) -> List[np.array]:
|
|
||||||
"""
|
|
||||||
Get the embeddings for the given images
|
|
||||||
"""
|
|
||||||
images = self.sanitize_input(images)
|
|
||||||
embeddings = []
|
|
||||||
for i in range(0, len(images), self.batch_size):
|
|
||||||
j = min(i + self.batch_size, len(images))
|
|
||||||
batch = images[i:j]
|
|
||||||
embeddings.extend(self._parallel_get(batch))
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Issue concurrent requests to retrieve the image data
|
|
||||||
"""
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
futures = [
|
|
||||||
executor.submit(self.generate_image_embedding, image)
|
|
||||||
for image in images
|
|
||||||
]
|
|
||||||
return [future.result() for future in tqdm(futures)]
|
|
||||||
|
|
||||||
def generate_image_embedding(
|
|
||||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Generate the embedding for a single image
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
image : Union[str, bytes, PIL.Image.Image]
|
|
||||||
The image to embed. If the image is a str, it is treated as a uri.
|
|
||||||
If the image is bytes, it is treated as the raw image bytes.
|
|
||||||
"""
|
|
||||||
torch = self.safe_import("torch")
|
|
||||||
# TODO handle retry and errors for https
|
|
||||||
image = self._to_pil(image)
|
|
||||||
image = self._preprocess(image).unsqueeze(0)
|
|
||||||
with torch.no_grad():
|
|
||||||
return self._encode_and_normalize_image(image)
|
|
||||||
|
|
||||||
def _to_pil(self, image: Union[str, bytes]):
|
|
||||||
PIL = self.safe_import("PIL", "pillow")
|
|
||||||
if isinstance(image, bytes):
|
|
||||||
return PIL.Image.open(io.BytesIO(image))
|
|
||||||
if isinstance(image, PIL.Image.Image):
|
|
||||||
return image
|
|
||||||
elif isinstance(image, str):
|
|
||||||
parsed = urlparse.urlparse(image)
|
|
||||||
# TODO handle drive letter on windows.
|
|
||||||
if parsed.scheme == "file":
|
|
||||||
return PIL.Image.open(parsed.path)
|
|
||||||
elif parsed.scheme == "":
|
|
||||||
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
|
||||||
elif parsed.scheme.startswith("http"):
|
|
||||||
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Only local and http(s) urls are supported")
|
|
||||||
|
|
||||||
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
|
||||||
"""
|
|
||||||
encode a single image tensor and optionally normalize the output
|
|
||||||
"""
|
|
||||||
image_features = self._model.encode_image(image_tensor.to(self.device))
|
|
||||||
if self.normalize:
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
|
||||||
return image_features.cpu().numpy().squeeze()
|
|
||||||
|
|
||||||
|
|
||||||
def url_retrieve(url: str):
    """
    Download and return the raw bytes at ``url``.

    Parameters
    ----------
    url: str
        URL to download from

    Raises
    ------
    ConnectionError
        If the host cannot be resolved or the request fails. The original
        error is chained as ``__cause__`` (previously it was dropped,
        hiding the underlying traceback).
    """
    try:
        with urllib.request.urlopen(url) as conn:
            return conn.read()
    except (socket.gaierror, urllib.error.URLError) as err:
        raise ConnectionError(
            "could not download {} due to {}".format(url, err)
        ) from err
|
|
||||||
137
python/lancedb/embeddings/instructor.py
Normal file
137
python/lancedb/embeddings/instructor.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import TEXT, weak_lru
|
||||||
|
|
||||||
|
|
||||||
|
@register("instructor")
|
||||||
|
class InstructorEmbeddingFunction(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the InstructorEmbedding library. Instructor models support multi-task learning, and can be used for a
|
||||||
|
variety of tasks, including text classification, sentence similarity, and document retrieval.
|
||||||
|
If you want to calculate customized embeddings for specific sentences, you may follow the unified template to write instructions:
|
||||||
|
"Represent the `domain` `text_type` for `task_objective`":
|
||||||
|
|
||||||
|
* domain is optional, and it specifies the domain of the text, e.g., science, finance, medicine, etc.
|
||||||
|
* text_type is required, and it specifies the encoding unit, e.g., sentence, document, paragraph, etc.
|
||||||
|
* task_objective is optional, and it specifies the objective of embedding, e.g., retrieve a document, classify the sentence, etc.
|
||||||
|
|
||||||
|
For example, if you want to calculate embeddings for a document, you may write the instruction as follows:
|
||||||
|
"Represent the document for retreival"
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name: str
|
||||||
|
The name of the model to use. Available models are listed at https://github.com/xlang-ai/instructor-embedding#model-list;
|
||||||
|
The default model is hkunlp/instructor-base
|
||||||
|
batch_size: int, default 32
|
||||||
|
The batch size to use when generating embeddings
|
||||||
|
device: str, default "cpu"
|
||||||
|
The device to use when generating embeddings
|
||||||
|
show_progress_bar: bool, default True
|
||||||
|
Whether to show a progress bar when generating embeddings
|
||||||
|
normalize_embeddings: bool, default True
|
||||||
|
Whether to normalize the embeddings
|
||||||
|
quantize: bool, default False
|
||||||
|
Whether to quantize the model
|
||||||
|
source_instruction: str, default "represent the docuement for retreival"
|
||||||
|
The instruction for the source column
|
||||||
|
query_instruction: str, default "represent the document for retreiving the most similar documents"
|
||||||
|
The instruction for the query
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
import lancedb
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
from lancedb.embeddings import get_registry, InstuctorEmbeddingFunction
|
||||||
|
|
||||||
|
instructor = get_registry().get("instructor").create(
|
||||||
|
source_instruction="represent the docuement for retreival",
|
||||||
|
query_instruction="represent the document for retreiving the most similar documents"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Schema(LanceModel):
|
||||||
|
vector: Vector(instructor.ndims()) = instructor.VectorField()
|
||||||
|
text: str = instructor.SourceField()
|
||||||
|
|
||||||
|
db = lancedb.connect("~/.lancedb")
|
||||||
|
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||||
|
|
||||||
|
texts = [{"text": "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that..."},
|
||||||
|
{"text": "The disparate impact theory is especially controversial under the Fair Housing Act because the Act..."},
|
||||||
|
{"text": "Disparate impact in United States labor law refers to practices in employment, housing, and other areas that.."}]
|
||||||
|
|
||||||
|
tbl.add(texts)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "hkunlp/instructor-base"
|
||||||
|
batch_size: int = 32
|
||||||
|
device: str = "cpu"
|
||||||
|
show_progress_bar: bool = True
|
||||||
|
normalize_embeddings: bool = True
|
||||||
|
quantize: bool = False
|
||||||
|
# convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly
|
||||||
|
|
||||||
|
source_instruction: str = "represent the document for retrieval"
|
||||||
|
query_instruction: str = (
|
||||||
|
"represent the document for retrieving the most similar documents"
|
||||||
|
)
|
||||||
|
|
||||||
|
@weak_lru(maxsize=1)
|
||||||
|
def ndims(self):
|
||||||
|
model = self.get_model()
|
||||||
|
return model.encode("foo").shape[0]
|
||||||
|
|
||||||
|
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||||
|
return self.generate_embeddings([[self.query_instruction, query]])
|
||||||
|
|
||||||
|
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||||
|
texts = self.sanitize_input(texts)
|
||||||
|
texts_formatted = []
|
||||||
|
for text in texts:
|
||||||
|
texts_formatted.append([self.source_instruction, text])
|
||||||
|
return self.generate_embeddings(texts_formatted)
|
||||||
|
|
||||||
|
def generate_embeddings(self, texts: List) -> List:
|
||||||
|
model = self.get_model()
|
||||||
|
res = model.encode(
|
||||||
|
texts,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
show_progress_bar=self.show_progress_bar,
|
||||||
|
normalize_embeddings=self.normalize_embeddings,
|
||||||
|
).tolist()
|
||||||
|
return res
|
||||||
|
|
||||||
|
@weak_lru(maxsize=1)
|
||||||
|
def get_model(self):
|
||||||
|
instructor_embedding = self.safe_import(
|
||||||
|
"InstructorEmbedding", "InstructorEmbedding"
|
||||||
|
)
|
||||||
|
torch = self.safe_import("torch", "torch")
|
||||||
|
|
||||||
|
model = instructor_embedding.INSTRUCTOR(self.name)
|
||||||
|
if self.quantize:
|
||||||
|
if (
|
||||||
|
"qnnpack" in torch.backends.quantized.supported_engines
|
||||||
|
): # fix for https://github.com/pytorch/pytorch/issues/29327
|
||||||
|
torch.backends.quantized.engine = "qnnpack"
|
||||||
|
model = torch.quantization.quantize_dynamic(
|
||||||
|
model, {torch.nn.Linear}, dtype=torch.qint8
|
||||||
|
)
|
||||||
|
return model
|
||||||
175
python/lancedb/embeddings/open_clip.py
Normal file
175
python/lancedb/embeddings/open_clip.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import concurrent.futures
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import urllib.parse as urlparse
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
|
from pydantic import PrivateAttr
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .base import EmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import IMAGES, url_retrieve
|
||||||
|
|
||||||
|
|
||||||
|
@register("open-clip")
|
||||||
|
class OpenClipEmbeddings(EmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the OpenClip API
|
||||||
|
For multi-modal text-to-image search
|
||||||
|
|
||||||
|
https://github.com/mlfoundations/open_clip
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "ViT-B-32"
|
||||||
|
pretrained: str = "laion2b_s34b_b79k"
|
||||||
|
device: str = "cpu"
|
||||||
|
batch_size: int = 64
|
||||||
|
normalize: bool = True
|
||||||
|
_model = PrivateAttr()
|
||||||
|
_preprocess = PrivateAttr()
|
||||||
|
_tokenizer = PrivateAttr()
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
open_clip = self.safe_import("open_clip", "open-clip")
|
||||||
|
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||||
|
self.name, pretrained=self.pretrained
|
||||||
|
)
|
||||||
|
model.to(self.device)
|
||||||
|
self._model, self._preprocess = model, preprocess
|
||||||
|
self._tokenizer = open_clip.get_tokenizer(self.name)
|
||||||
|
self._ndims = None
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self._ndims is None:
|
||||||
|
self._ndims = self.generate_text_embeddings("foo").shape[0]
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
|
def compute_query_embeddings(
|
||||||
|
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
|
||||||
|
) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Compute the embeddings for a given user query
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
query : Union[str, PIL.Image.Image]
|
||||||
|
The query to embed. A query can be either text or an image.
|
||||||
|
"""
|
||||||
|
if isinstance(query, str):
|
||||||
|
return [self.generate_text_embeddings(query)]
|
||||||
|
else:
|
||||||
|
PIL = self.safe_import("PIL", "pillow")
|
||||||
|
if isinstance(query, PIL.Image.Image):
|
||||||
|
return [self.generate_image_embedding(query)]
|
||||||
|
else:
|
||||||
|
raise TypeError("OpenClip supports str or PIL Image as query")
|
||||||
|
|
||||||
|
def generate_text_embeddings(self, text: str) -> np.ndarray:
|
||||||
|
torch = self.safe_import("torch")
|
||||||
|
text = self.sanitize_input(text)
|
||||||
|
text = self._tokenizer(text)
|
||||||
|
text.to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
text_features = self._model.encode_text(text.to(self.device))
|
||||||
|
if self.normalize:
|
||||||
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
return text_features.cpu().numpy().squeeze()
|
||||||
|
|
||||||
|
def sanitize_input(self, images: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||||
|
"""
|
||||||
|
Sanitize the input to the embedding function.
|
||||||
|
"""
|
||||||
|
if isinstance(images, (str, bytes)):
|
||||||
|
images = [images]
|
||||||
|
elif isinstance(images, pa.Array):
|
||||||
|
images = images.to_pylist()
|
||||||
|
elif isinstance(images, pa.ChunkedArray):
|
||||||
|
images = images.combine_chunks().to_pylist()
|
||||||
|
return images
|
||||||
|
|
||||||
|
def compute_source_embeddings(
|
||||||
|
self, images: IMAGES, *args, **kwargs
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given images
|
||||||
|
"""
|
||||||
|
images = self.sanitize_input(images)
|
||||||
|
embeddings = []
|
||||||
|
for i in range(0, len(images), self.batch_size):
|
||||||
|
j = min(i + self.batch_size, len(images))
|
||||||
|
batch = images[i:j]
|
||||||
|
embeddings.extend(self._parallel_get(batch))
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def _parallel_get(self, images: Union[List[str], List[bytes]]) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Issue concurrent requests to retrieve the image data
|
||||||
|
"""
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(self.generate_image_embedding, image)
|
||||||
|
for image in images
|
||||||
|
]
|
||||||
|
return [future.result() for future in tqdm(futures)]
|
||||||
|
|
||||||
|
def generate_image_embedding(
|
||||||
|
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate the embedding for a single image
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
image : Union[str, bytes, PIL.Image.Image]
|
||||||
|
The image to embed. If the image is a str, it is treated as a uri.
|
||||||
|
If the image is bytes, it is treated as the raw image bytes.
|
||||||
|
"""
|
||||||
|
torch = self.safe_import("torch")
|
||||||
|
# TODO handle retry and errors for https
|
||||||
|
image = self._to_pil(image)
|
||||||
|
image = self._preprocess(image).unsqueeze(0)
|
||||||
|
with torch.no_grad():
|
||||||
|
return self._encode_and_normalize_image(image)
|
||||||
|
|
||||||
|
def _to_pil(self, image: Union[str, bytes]):
|
||||||
|
PIL = self.safe_import("PIL", "pillow")
|
||||||
|
if isinstance(image, bytes):
|
||||||
|
return PIL.Image.open(io.BytesIO(image))
|
||||||
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
return image
|
||||||
|
elif isinstance(image, str):
|
||||||
|
parsed = urlparse.urlparse(image)
|
||||||
|
# TODO handle drive letter on windows.
|
||||||
|
if parsed.scheme == "file":
|
||||||
|
return PIL.Image.open(parsed.path)
|
||||||
|
elif parsed.scheme == "":
|
||||||
|
return PIL.Image.open(image if os.name == "nt" else parsed.path)
|
||||||
|
elif parsed.scheme.startswith("http"):
|
||||||
|
return PIL.Image.open(io.BytesIO(url_retrieve(image)))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only local and http(s) urls are supported")
|
||||||
|
|
||||||
|
def _encode_and_normalize_image(self, image_tensor: "torch.Tensor"):
|
||||||
|
"""
|
||||||
|
encode a single image tensor and optionally normalize the output
|
||||||
|
"""
|
||||||
|
image_features = self._model.encode_image(image_tensor.to(self.device))
|
||||||
|
if self.normalize:
|
||||||
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
return image_features.cpu().numpy().squeeze()
|
||||||
49
python/lancedb/embeddings/openai.py
Normal file
49
python/lancedb/embeddings/openai.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
|
||||||
|
|
||||||
|
@register("openai")
|
||||||
|
class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the OpenAI API
|
||||||
|
|
||||||
|
https://platform.openai.com/docs/guides/embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "text-embedding-ada-002"
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
# TODO don't hardcode this
|
||||||
|
return 1536
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: list[str] or np.ndarray (of str)
|
||||||
|
The texts to embed
|
||||||
|
"""
|
||||||
|
# TODO retry, rate limit, token limit
|
||||||
|
openai = self.safe_import("openai")
|
||||||
|
rs = openai.Embedding.create(input=texts, model=self.name)["data"]
|
||||||
|
return [v["embedding"] for v in rs]
|
||||||
186
python/lancedb/embeddings/registry.py
Normal file
186
python/lancedb/embeddings/registry.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from .base import EmbeddingFunction, EmbeddingFunctionConfig
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingFunctionRegistry:
|
||||||
|
"""
|
||||||
|
This is a singleton class used to register embedding functions
|
||||||
|
and fetch them by name. It also handles serializing and deserializing.
|
||||||
|
You can implement your own embedding function by subclassing EmbeddingFunction
|
||||||
|
or TextEmbeddingFunction and registering it with the registry.
|
||||||
|
|
||||||
|
NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
>>> @registry.register("my-embedding-function")
|
||||||
|
... class MyEmbeddingFunction(EmbeddingFunction):
|
||||||
|
... def ndims(self) -> int:
|
||||||
|
... return 128
|
||||||
|
...
|
||||||
|
... def compute_query_embeddings(self, query: str, *args, **kwargs):
|
||||||
|
... return self.compute_source_embeddings(query, *args, **kwargs)
|
||||||
|
...
|
||||||
|
... def compute_source_embeddings(self, texts, *args, **kwargs):
|
||||||
|
... return [np.random.rand(self.ndims()) for _ in range(len(texts))]
|
||||||
|
...
|
||||||
|
>>> registry.get("my-embedding-function")
|
||||||
|
<class 'lancedb.embeddings.registry.MyEmbeddingFunction'>
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
return __REGISTRY__
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._functions = {}
|
||||||
|
|
||||||
|
def register(self, alias: str = None):
|
||||||
|
"""
|
||||||
|
This creates a decorator that can be used to register
|
||||||
|
an EmbeddingFunction.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
alias : Optional[str]
|
||||||
|
a human friendly name for the embedding function. If not
|
||||||
|
provided, the class name will be used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This is a decorator for a class that inherits from BaseModel
|
||||||
|
# It adds the class to the registry
|
||||||
|
def decorator(cls):
|
||||||
|
if not issubclass(cls, EmbeddingFunction):
|
||||||
|
raise TypeError("Must be a subclass of EmbeddingFunction")
|
||||||
|
if cls.__name__ in self._functions:
|
||||||
|
raise KeyError(f"{cls.__name__} was already registered")
|
||||||
|
key = alias or cls.__name__
|
||||||
|
self._functions[key] = cls
|
||||||
|
cls.__embedding_function_registry_alias__ = alias
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Reset the registry to its initial state
|
||||||
|
"""
|
||||||
|
self._functions = {}
|
||||||
|
|
||||||
|
def get(self, name: str):
|
||||||
|
"""
|
||||||
|
Fetch an embedding function class by name
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the embedding function to fetch
|
||||||
|
Either the alias or the class name if no alias was provided
|
||||||
|
during registration
|
||||||
|
"""
|
||||||
|
return self._functions[name]
|
||||||
|
|
||||||
|
def parse_functions(
|
||||||
|
self, metadata: Optional[Dict[bytes, bytes]]
|
||||||
|
) -> Dict[str, "EmbeddingFunctionConfig"]:
|
||||||
|
"""
|
||||||
|
Parse the metadata from an arrow table and
|
||||||
|
return a mapping of the vector column to the
|
||||||
|
embedding function and source column
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
metadata : Optional[Dict[bytes, bytes]]
|
||||||
|
The metadata from an arrow table. Note that
|
||||||
|
the keys and values are bytes (pyarrow api)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
functions : dict
|
||||||
|
A mapping of vector column name to embedding function.
|
||||||
|
An empty dict is returned if input is None or does not
|
||||||
|
contain b"embedding_functions".
|
||||||
|
"""
|
||||||
|
if metadata is None or b"embedding_functions" not in metadata:
|
||||||
|
return {}
|
||||||
|
serialized = metadata[b"embedding_functions"]
|
||||||
|
raw_list = json.loads(serialized.decode("utf-8"))
|
||||||
|
return {
|
||||||
|
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||||
|
vector_column=obj["vector_column"],
|
||||||
|
source_column=obj["source_column"],
|
||||||
|
function=self.get(obj["name"])(**obj["model"]),
|
||||||
|
)
|
||||||
|
for obj in raw_list
|
||||||
|
}
|
||||||
|
|
||||||
|
def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
|
||||||
|
"""
|
||||||
|
Convert the given embedding function and source / vector column configs
|
||||||
|
into a config dictionary that can be serialized into arrow metadata
|
||||||
|
"""
|
||||||
|
func = conf.function
|
||||||
|
name = getattr(
|
||||||
|
func, "__embedding_function_registry_alias__", func.__class__.__name__
|
||||||
|
)
|
||||||
|
json_data = func.safe_model_dump()
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"model": json_data,
|
||||||
|
"source_column": conf.source_column,
|
||||||
|
"vector_column": conf.vector_column,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_table_metadata(self, func_list):
|
||||||
|
"""
|
||||||
|
Convert a list of embedding functions and source / vector configs
|
||||||
|
into a config dictionary that can be serialized into arrow metadata
|
||||||
|
"""
|
||||||
|
if func_list is None or len(func_list) == 0:
|
||||||
|
return None
|
||||||
|
json_data = [self.function_to_metadata(func) for func in func_list]
|
||||||
|
# Note that metadata dictionary values must be bytes
|
||||||
|
# so we need to json dump then utf8 encode
|
||||||
|
metadata = json.dumps(json_data, indent=2).encode("utf-8")
|
||||||
|
return {"embedding_functions": metadata}
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance
|
||||||
|
__REGISTRY__ = EmbeddingFunctionRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
|
||||||
|
register = lambda name: EmbeddingFunctionRegistry.get_instance().register(name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_registry():
|
||||||
|
"""
|
||||||
|
Utility function to get the global instance of the registry
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
EmbeddingFunctionRegistry
|
||||||
|
The global registry instance
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
from lancedb.embeddings import get_registry
|
||||||
|
|
||||||
|
registry = get_registry()
|
||||||
|
openai = registry.get("openai").create()
|
||||||
|
"""
|
||||||
|
return __REGISTRY__.get_instance()
|
||||||
89
python/lancedb/embeddings/sentence_transformers.py
Normal file
89
python/lancedb/embeddings/sentence_transformers.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# Copyright (c) 2023. LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from cachetools import cached
|
||||||
|
|
||||||
|
from .base import TextEmbeddingFunction
|
||||||
|
from .registry import register
|
||||||
|
from .utils import weak_lru
|
||||||
|
|
||||||
|
|
||||||
|
@register("sentence-transformers")
|
||||||
|
class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
||||||
|
"""
|
||||||
|
An embedding function that uses the sentence-transformers library
|
||||||
|
|
||||||
|
https://huggingface.co/sentence-transformers
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "all-MiniLM-L6-v2"
|
||||||
|
device: str = "cpu"
|
||||||
|
normalize: bool = True
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._ndims = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embedding_model(self):
|
||||||
|
"""
|
||||||
|
Get the sentence-transformers embedding model specified by the
|
||||||
|
name and device. This is cached so that the model is only loaded
|
||||||
|
once per process.
|
||||||
|
"""
|
||||||
|
return self.get_embedding_model()
|
||||||
|
|
||||||
|
def ndims(self):
|
||||||
|
if self._ndims is None:
|
||||||
|
self._ndims = len(self.generate_embeddings("foo")[0])
|
||||||
|
return self._ndims
|
||||||
|
|
||||||
|
def generate_embeddings(
|
||||||
|
self, texts: Union[List[str], np.ndarray]
|
||||||
|
) -> List[np.array]:
|
||||||
|
"""
|
||||||
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
texts: list[str] or np.ndarray (of str)
|
||||||
|
The texts to embed
|
||||||
|
"""
|
||||||
|
return self.embedding_model.encode(
|
||||||
|
list(texts),
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=self.normalize,
|
||||||
|
).tolist()
|
||||||
|
|
||||||
|
@weak_lru(maxsize=1)
|
||||||
|
def get_embedding_model(self):
|
||||||
|
"""
|
||||||
|
Get the sentence-transformers embedding model specified by the
|
||||||
|
name and device. This is cached so that the model is only loaded
|
||||||
|
once per process.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
The name of the model to load
|
||||||
|
device : str
|
||||||
|
The device to load the model on
|
||||||
|
|
||||||
|
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||||
|
"""
|
||||||
|
sentence_transformers = self.safe_import(
|
||||||
|
"sentence_transformers", "sentence-transformers"
|
||||||
|
)
|
||||||
|
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
|
||||||
@@ -11,9 +11,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import functools
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
|
import socket
|
||||||
import sys
|
import sys
|
||||||
from typing import Callable, Union
|
import time
|
||||||
|
import urllib.error
|
||||||
|
import weakref
|
||||||
|
from typing import Callable, List, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -21,9 +27,15 @@ from lance.vector import vec_to_table
|
|||||||
from retry import retry
|
from retry import retry
|
||||||
|
|
||||||
from ..util import safe_import_pandas
|
from ..util import safe_import_pandas
|
||||||
|
from ..utils.general import LOGGER
|
||||||
|
|
||||||
pd = safe_import_pandas()
|
pd = safe_import_pandas()
|
||||||
|
|
||||||
DATA = Union[pa.Table, "pd.DataFrame"]
|
DATA = Union[pa.Table, "pd.DataFrame"]
|
||||||
|
TEXT = Union[str, List[str], pa.Array, pa.ChunkedArray, np.ndarray]
|
||||||
|
IMAGES = Union[
|
||||||
|
str, bytes, List[str], List[bytes], pa.Array, pa.ChunkedArray, np.ndarray
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def with_embeddings(
|
def with_embeddings(
|
||||||
@@ -152,3 +164,115 @@ class FunctionWrapper:
|
|||||||
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
|
yield from tqdm(_chunker(arr), total=math.ceil(length / self._batch_size))
|
||||||
else:
|
else:
|
||||||
yield from _chunker(arr)
|
yield from _chunker(arr)
|
||||||
|
|
||||||
|
|
||||||
|
def weak_lru(maxsize=128):
|
||||||
|
"""
|
||||||
|
LRU cache that keeps weak references to the objects it caches. Only caches the latest instance of the objects to make sure memory usage
|
||||||
|
is bounded.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
maxsize : int, default 128
|
||||||
|
The maximum number of objects to cache.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Callable
|
||||||
|
A decorator that can be applied to a method.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> class Foo:
|
||||||
|
... @weak_lru()
|
||||||
|
... def bar(self, x):
|
||||||
|
... return x
|
||||||
|
>>> foo = Foo()
|
||||||
|
>>> foo.bar(1)
|
||||||
|
1
|
||||||
|
>>> foo.bar(2)
|
||||||
|
2
|
||||||
|
>>> foo.bar(1)
|
||||||
|
1
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(func):
|
||||||
|
@functools.lru_cache(maxsize)
|
||||||
|
def _func(_self, *args, **kwargs):
|
||||||
|
return func(_self(), *args, **kwargs)
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def inner(self, *args, **kwargs):
|
||||||
|
return _func(weakref.ref(self), *args, **kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def retry_with_exponential_backoff(
|
||||||
|
func,
|
||||||
|
initial_delay: float = 1,
|
||||||
|
exponential_base: float = 2,
|
||||||
|
jitter: bool = True,
|
||||||
|
max_retries: int = 7,
|
||||||
|
# errors: tuple = (),
|
||||||
|
):
|
||||||
|
"""Retry a function with exponential backoff.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (function): The function to be retried.
|
||||||
|
initial_delay (float): Initial delay in seconds (default is 1).
|
||||||
|
exponential_base (float): The base for exponential backoff (default is 2).
|
||||||
|
jitter (bool): Whether to add jitter to the delay (default is True).
|
||||||
|
max_retries (int): Maximum number of retries (default is 10).
|
||||||
|
errors (tuple): Tuple of specific exceptions to retry on (default is (openai.error.RateLimitError,)).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: The decorated function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
num_retries = 0
|
||||||
|
delay = initial_delay
|
||||||
|
|
||||||
|
# Loop until a successful response or max_retries is hit or an exception is raised
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Currently retrying on all exceptions as there is no way to know the format of the error msgs used by different APIs
|
||||||
|
# We'll log the error and say that it is assumed that if this portion errors out, it's due to rate limit but the user
|
||||||
|
# should check the error message to be sure
|
||||||
|
except Exception as e:
|
||||||
|
num_retries += 1
|
||||||
|
|
||||||
|
if num_retries > max_retries:
|
||||||
|
raise Exception(
|
||||||
|
f"Maximum number of retries ({max_retries}) exceeded."
|
||||||
|
)
|
||||||
|
|
||||||
|
delay *= exponential_base * (1 + jitter * random.random())
|
||||||
|
LOGGER.info(f"Retrying in {delay:.2f} seconds due to {e}")
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def url_retrieve(url: str):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
url: str
|
||||||
|
URL to download from
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(url) as conn:
|
||||||
|
return conn.read()
|
||||||
|
except (socket.gaierror, urllib.error.URLError) as err:
|
||||||
|
raise ConnectionError("could not download {} due to {}".format(url, err))
|
||||||
|
|
||||||
|
|
||||||
|
def api_key_not_found_help(provider):
|
||||||
|
LOGGER.error(f"Could not find API key for {provider}.")
|
||||||
|
raise ValueError(f"Please set the {provider.upper()}_API_KEY environment variable.")
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import inspect
|
|||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import date, datetime
|
||||||
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
|
from typing import Any, Callable, Dict, Generator, List, Type, Union, _GenericAlias
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -159,6 +160,10 @@ def _py_type_to_arrow_type(py_type: Type[Any]) -> pa.DataType:
|
|||||||
return pa.bool_()
|
return pa.bool_()
|
||||||
elif py_type == bytes:
|
elif py_type == bytes:
|
||||||
return pa.binary()
|
return pa.binary()
|
||||||
|
elif py_type == date:
|
||||||
|
return pa.date32()
|
||||||
|
elif py_type == datetime:
|
||||||
|
return pa.timestamp("us")
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
|
f"Converting Pydantic type to Arrow Type: unsupported type {py_type}"
|
||||||
)
|
)
|
||||||
@@ -322,7 +327,12 @@ class LanceModel(pydantic.BaseModel):
|
|||||||
for vec, func in vec_and_function:
|
for vec, func in vec_and_function:
|
||||||
for source, field_info in cls.safe_get_fields().items():
|
for source, field_info in cls.safe_get_fields().items():
|
||||||
src_func = get_extras(field_info, "source_column_for")
|
src_func = get_extras(field_info, "source_column_for")
|
||||||
if src_func == func:
|
if src_func is func:
|
||||||
|
# note we can't use == here since the function is a pydantic
|
||||||
|
# model so two instances of the same function are ==, so if you
|
||||||
|
# have multiple vector columns from multiple sources, both will
|
||||||
|
# be mapped to the same source column
|
||||||
|
# GH594
|
||||||
configs.append(
|
configs.append(
|
||||||
EmbeddingFunctionConfig(
|
EmbeddingFunctionConfig(
|
||||||
source_column=source, vector_column=vec, function=func
|
source_column=source, vector_column=vec, function=func
|
||||||
|
|||||||
@@ -14,21 +14,58 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Literal, Optional, Type, Union
|
from typing import TYPE_CHECKING, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
|
import deprecation
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
|
from . import __version__
|
||||||
from .common import VECTOR_COLUMN_NAME
|
from .common import VECTOR_COLUMN_NAME
|
||||||
from .pydantic import LanceModel
|
|
||||||
from .util import safe_import_pandas
|
from .util import safe_import_pandas
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .pydantic import LanceModel
|
||||||
|
|
||||||
pd = safe_import_pandas()
|
pd = safe_import_pandas()
|
||||||
|
|
||||||
|
|
||||||
class Query(pydantic.BaseModel):
|
class Query(pydantic.BaseModel):
|
||||||
"""A Query"""
|
"""The LanceDB Query
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
vector : List[float]
|
||||||
|
the vector to search for
|
||||||
|
filter : Optional[str]
|
||||||
|
sql filter to refine the query with, optional
|
||||||
|
prefilter : bool
|
||||||
|
if True then apply the filter before vector search
|
||||||
|
k : int
|
||||||
|
top k results to return
|
||||||
|
metric : str
|
||||||
|
the distance metric between a pair of vectors,
|
||||||
|
|
||||||
|
can support L2 (default), Cosine and Dot.
|
||||||
|
[metric definitions][search]
|
||||||
|
columns : Optional[List[str]]
|
||||||
|
which columns to return in the results
|
||||||
|
nprobes : int
|
||||||
|
The number of probes used - optional
|
||||||
|
|
||||||
|
- A higher number makes search more accurate but also slower.
|
||||||
|
|
||||||
|
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||||
|
tuning advice.
|
||||||
|
refine_factor : Optional[int]
|
||||||
|
Refine the results by reading extra elements and re-ranking them in memory - optional
|
||||||
|
|
||||||
|
- A higher number makes search more accurate but also slower.
|
||||||
|
|
||||||
|
- See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||||
|
tuning advice.
|
||||||
|
"""
|
||||||
|
|
||||||
vector_column: str = VECTOR_COLUMN_NAME
|
vector_column: str = VECTOR_COLUMN_NAME
|
||||||
|
|
||||||
@@ -59,6 +96,10 @@ class Query(pydantic.BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class LanceQueryBuilder(ABC):
|
class LanceQueryBuilder(ABC):
|
||||||
|
"""Build LanceDB query based on specific query type:
|
||||||
|
vector or full text search.
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
@@ -101,7 +142,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
if not isinstance(query, (list, np.ndarray)):
|
if not isinstance(query, (list, np.ndarray)):
|
||||||
conf = table.embedding_functions.get(vector_column_name)
|
conf = table.embedding_functions.get(vector_column_name)
|
||||||
if conf is not None:
|
if conf is not None:
|
||||||
query = conf.function.compute_query_embeddings(query)[0]
|
query = conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||||
else:
|
else:
|
||||||
msg = f"No embedding function for {vector_column_name}"
|
msg = f"No embedding function for {vector_column_name}"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@@ -112,7 +153,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
else:
|
else:
|
||||||
conf = table.embedding_functions.get(vector_column_name)
|
conf = table.embedding_functions.get(vector_column_name)
|
||||||
if conf is not None:
|
if conf is not None:
|
||||||
query = conf.function.compute_query_embeddings(query)[0]
|
query = conf.function.compute_query_embeddings_with_retry(query)[0]
|
||||||
return query, "vector"
|
return query, "vector"
|
||||||
else:
|
else:
|
||||||
return query, "fts"
|
return query, "fts"
|
||||||
@@ -127,7 +168,24 @@ class LanceQueryBuilder(ABC):
|
|||||||
self._columns = None
|
self._columns = None
|
||||||
self._where = None
|
self._where = None
|
||||||
|
|
||||||
|
@deprecation.deprecated(
|
||||||
|
deprecated_in="0.3.1",
|
||||||
|
removed_in="0.4.0",
|
||||||
|
current_version=__version__,
|
||||||
|
details="Use to_pandas() instead",
|
||||||
|
)
|
||||||
def to_df(self) -> "pd.DataFrame":
|
def to_df(self) -> "pd.DataFrame":
|
||||||
|
"""
|
||||||
|
*Deprecated alias for `to_pandas()`. Please use `to_pandas()` instead.*
|
||||||
|
|
||||||
|
Execute the query and return the results as a pandas DataFrame.
|
||||||
|
In addition to the selected columns, LanceDB also returns a vector
|
||||||
|
and also the "_distance" column which is the distance between the query
|
||||||
|
vector and the returned vector.
|
||||||
|
"""
|
||||||
|
return self.to_pandas()
|
||||||
|
|
||||||
|
def to_pandas(self) -> "pd.DataFrame":
|
||||||
"""
|
"""
|
||||||
Execute the query and return the results as a pandas DataFrame.
|
Execute the query and return the results as a pandas DataFrame.
|
||||||
In addition to the selected columns, LanceDB also returns a vector
|
In addition to the selected columns, LanceDB also returns a vector
|
||||||
@@ -148,6 +206,16 @@ class LanceQueryBuilder(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def to_list(self) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Execute the query and return the results as a list of dictionaries.
|
||||||
|
|
||||||
|
Each list entry is a dictionary with the selected column names as keys,
|
||||||
|
or all table columns if `select` is not called. The vector and the "_distance"
|
||||||
|
fields are returned whether or not they're explicitly selected.
|
||||||
|
"""
|
||||||
|
return self.to_arrow().to_pylist()
|
||||||
|
|
||||||
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
||||||
"""Return the table as a list of pydantic models.
|
"""Return the table as a list of pydantic models.
|
||||||
|
|
||||||
@@ -197,13 +265,20 @@ class LanceQueryBuilder(ABC):
|
|||||||
self._columns = columns
|
self._columns = columns
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def where(self, where) -> LanceQueryBuilder:
|
def where(self, where: str, prefilter: bool = False) -> LanceQueryBuilder:
|
||||||
"""Set the where clause.
|
"""Set the where clause.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
where: str
|
where: str
|
||||||
The where clause.
|
The where clause which is a valid SQL where clause. See
|
||||||
|
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
|
||||||
|
for valid SQL expressions.
|
||||||
|
prefilter: bool, default False
|
||||||
|
If True, apply the filter before vector search, otherwise the
|
||||||
|
filter is applied on the result of vector search.
|
||||||
|
This feature is **EXPERIMENTAL** and may be removed and modified
|
||||||
|
without warning in the future.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -211,13 +286,12 @@ class LanceQueryBuilder(ABC):
|
|||||||
The LanceQueryBuilder object.
|
The LanceQueryBuilder object.
|
||||||
"""
|
"""
|
||||||
self._where = where
|
self._where = where
|
||||||
|
self._prefilter = prefilter
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||||
"""
|
"""
|
||||||
A builder for nearest neighbor queries for LanceDB.
|
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
@@ -232,7 +306,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
... .where("b < 10")
|
... .where("b < 10")
|
||||||
... .select(["b"])
|
... .select(["b"])
|
||||||
... .limit(2)
|
... .limit(2)
|
||||||
... .to_df())
|
... .to_pandas())
|
||||||
b vector _distance
|
b vector _distance
|
||||||
0 6 [0.4, 0.4] 0.0
|
0 6 [0.4, 0.4] 0.0
|
||||||
"""
|
"""
|
||||||
@@ -273,7 +347,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
Higher values will yield better recall (more likely to find vectors if
|
Higher values will yield better recall (more likely to find vectors if
|
||||||
they exist) at the expense of latency.
|
they exist) at the expense of latency.
|
||||||
|
|
||||||
See discussion in [Querying an ANN Index][../querying-an-ann-index] for
|
See discussion in [Querying an ANN Index][querying-an-ann-index] for
|
||||||
tuning advice.
|
tuning advice.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -340,14 +414,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
where: str
|
where: str
|
||||||
The where clause.
|
The where clause which is a valid SQL where clause. See
|
||||||
|
`Lance filter pushdown <https://lancedb.github.io/lance/read_and_write.html#filter-push-down>`_
|
||||||
|
for valid SQL expressions.
|
||||||
prefilter: bool, default False
|
prefilter: bool, default False
|
||||||
If True, apply the filter before vector search, otherwise the
|
If True, apply the filter before vector search, otherwise the
|
||||||
filter is applied on the result of vector search.
|
filter is applied on the result of vector search.
|
||||||
This feature is **EXPERIMENTAL** and may be removed and modified
|
This feature is **EXPERIMENTAL** and may be removed and modified
|
||||||
without warning in the future. Currently this is only supported
|
without warning in the future.
|
||||||
in OSS and can only be used with a table that does not have an ANN
|
|
||||||
index.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -360,6 +434,8 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
|
|
||||||
|
|
||||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||||
|
"""A builder for full text search for LanceDB."""
|
||||||
|
|
||||||
def __init__(self, table: "lancedb.table.Table", query: str):
|
def __init__(self, table: "lancedb.table.Table", query: str):
|
||||||
super().__init__(table)
|
super().__init__(table)
|
||||||
self._query = query
|
self._query = query
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
from typing import Any, Callable, Dict, Iterable, Optional, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import attrs
|
import attrs
|
||||||
@@ -151,9 +151,13 @@ class RestfulLanceDBClient:
|
|||||||
return await deserialize(resp)
|
return await deserialize(resp)
|
||||||
|
|
||||||
@_check_not_closed
|
@_check_not_closed
|
||||||
async def list_tables(self):
|
async def list_tables(
|
||||||
|
self, limit: int, page_token: Optional[str] = None
|
||||||
|
) -> Iterable[str]:
|
||||||
"""List all tables in the database."""
|
"""List all tables in the database."""
|
||||||
json = await self.get("/v1/table/", {})
|
if page_token is None:
|
||||||
|
page_token = ""
|
||||||
|
json = await self.get("/v1/table/", {"limit": limit, "page_token": page_token})
|
||||||
return json["tables"]
|
return json["tables"]
|
||||||
|
|
||||||
@_check_not_closed
|
@_check_not_closed
|
||||||
|
|||||||
@@ -12,14 +12,19 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional
|
from typing import Iterable, List, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
from overrides import override
|
||||||
|
|
||||||
from ..common import DATA
|
from ..common import DATA
|
||||||
from ..db import DBConnection
|
from ..db import DBConnection
|
||||||
|
from ..embeddings import EmbeddingFunctionConfig
|
||||||
|
from ..pydantic import LanceModel
|
||||||
from ..table import Table, _sanitize_data
|
from ..table import Table, _sanitize_data
|
||||||
from .arrow import to_ipc_binary
|
from .arrow import to_ipc_binary
|
||||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||||
@@ -52,11 +57,31 @@ class RemoteDBConnection(DBConnection):
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"RemoveConnect(name={self.db_name})"
|
return f"RemoveConnect(name={self.db_name})"
|
||||||
|
|
||||||
def table_names(self) -> List[str]:
|
@override
|
||||||
"""List the names of all tables in the database."""
|
def table_names(self, page_token: Optional[str] = None, limit=10) -> Iterable[str]:
|
||||||
result = self._loop.run_until_complete(self._client.list_tables())
|
"""List the names of all tables in the database.
|
||||||
return result
|
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
page_token: str
|
||||||
|
The last token to start the new page.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
An iterator of table names.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
result = self._loop.run_until_complete(
|
||||||
|
self._client.list_tables(limit, page_token)
|
||||||
|
)
|
||||||
|
if len(result) > 0:
|
||||||
|
page_token = result[len(result) - 1]
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
for item in result:
|
||||||
|
yield item
|
||||||
|
|
||||||
|
@override
|
||||||
def open_table(self, name: str) -> Table:
|
def open_table(self, name: str) -> Table:
|
||||||
"""Open a Lance Table in the database.
|
"""Open a Lance Table in the database.
|
||||||
|
|
||||||
@@ -71,23 +96,50 @@ class RemoteDBConnection(DBConnection):
|
|||||||
"""
|
"""
|
||||||
from .table import RemoteTable
|
from .table import RemoteTable
|
||||||
|
|
||||||
# TODO: check if table exists
|
# check if table exists
|
||||||
|
try:
|
||||||
|
self._loop.run_until_complete(
|
||||||
|
self._client.post(f"/v1/table/{name}/describe/")
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.error(
|
||||||
|
"Table {name} does not exist."
|
||||||
|
"Please first call db.create_table({name}, data)"
|
||||||
|
)
|
||||||
return RemoteTable(self, name)
|
return RemoteTable(self, name)
|
||||||
|
|
||||||
|
@override
|
||||||
def create_table(
|
def create_table(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
data: DATA = None,
|
data: DATA = None,
|
||||||
schema: pa.Schema = None,
|
schema: Optional[Union[pa.Schema, LanceModel]] = None,
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
|
embedding_functions: Optional[List[EmbeddingFunctionConfig]] = None,
|
||||||
) -> Table:
|
) -> Table:
|
||||||
if data is None and schema is None:
|
if data is None and schema is None:
|
||||||
raise ValueError("Either data or schema must be provided.")
|
raise ValueError("Either data or schema must be provided.")
|
||||||
|
if embedding_functions is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"embedding_functions is not supported for remote databases."
|
||||||
|
"Please vote https://github.com/lancedb/lancedb/issues/626 "
|
||||||
|
"for this feature."
|
||||||
|
)
|
||||||
|
|
||||||
|
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||||
|
# convert LanceModel to pyarrow schema
|
||||||
|
# note that it's possible this contains
|
||||||
|
# embedding function metadata already
|
||||||
|
schema = schema.to_arrow_schema()
|
||||||
|
|
||||||
if data is not None:
|
if data is not None:
|
||||||
data = _sanitize_data(
|
data = _sanitize_data(
|
||||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
data,
|
||||||
|
schema,
|
||||||
|
metadata=None,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
fill_value=fill_value,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if schema is None:
|
if schema is None:
|
||||||
@@ -109,6 +161,7 @@ class RemoteDBConnection(DBConnection):
|
|||||||
)
|
)
|
||||||
return RemoteTable(self, name)
|
return RemoteTable(self, name)
|
||||||
|
|
||||||
|
@override
|
||||||
def drop_table(self, name: str):
|
def drop_table(self, name: str):
|
||||||
"""Drop a table from the database.
|
"""Drop a table from the database.
|
||||||
|
|
||||||
@@ -122,3 +175,8 @@ class RemoteDBConnection(DBConnection):
|
|||||||
f"/v1/table/{name}/drop/",
|
f"/v1/table/{name}/drop/",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the connection to the database."""
|
||||||
|
self._loop.close()
|
||||||
|
await self._client.close()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from lance import json_to_schema
|
from lance import json_to_schema
|
||||||
@@ -44,6 +44,14 @@ class RemoteTable(Table):
|
|||||||
schema = json_to_schema(resp["schema"])
|
schema = json_to_schema(resp["schema"])
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def version(self) -> int:
|
||||||
|
"""Get the current version of the table"""
|
||||||
|
resp = self._conn._loop.run_until_complete(
|
||||||
|
self._conn._client.post(f"/v1/table/{self._name}/describe/")
|
||||||
|
)
|
||||||
|
return resp["version"]
|
||||||
|
|
||||||
def to_arrow(self) -> pa.Table:
|
def to_arrow(self) -> pa.Table:
|
||||||
"""Return the table as an Arrow table."""
|
"""Return the table as an Arrow table."""
|
||||||
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
|
raise NotImplementedError("to_arrow() is not supported on the LanceDB cloud")
|
||||||
@@ -62,8 +70,63 @@ class RemoteTable(Table):
|
|||||||
num_sub_vectors=96,
|
num_sub_vectors=96,
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
replace: bool = True,
|
replace: bool = True,
|
||||||
|
accelerator: Optional[str] = None,
|
||||||
|
index_cache_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
raise NotImplementedError
|
"""Create an index on the table.
|
||||||
|
Currently, the only parameters that matter are
|
||||||
|
the metric and the vector column name.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
metric : str
|
||||||
|
The metric to use for the index. Default is "L2".
|
||||||
|
num_partitions : int
|
||||||
|
The number of partitions to use for the index. Default is 256.
|
||||||
|
num_sub_vectors : int
|
||||||
|
The number of sub-vectors to use for the index. Default is 96.
|
||||||
|
vector_column_name : str
|
||||||
|
The name of the vector column. Default is "vector".
|
||||||
|
replace : bool
|
||||||
|
Whether to replace the existing index. Default is True.
|
||||||
|
accelerator : str, optional
|
||||||
|
If set, use the given accelerator to create the index.
|
||||||
|
Default is None. Currently not supported.
|
||||||
|
index_cache_size : int, optional
|
||||||
|
The size of the index cache in number of entries. Default value is 256.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
import lancedb
|
||||||
|
import uuid
|
||||||
|
from lancedb.schema import vector
|
||||||
|
conn = lancedb.connect("db://...", api_key="...", region="...")
|
||||||
|
table_name = uuid.uuid4().hex
|
||||||
|
schema = pa.schema(
|
||||||
|
[
|
||||||
|
pa.field("id", pa.uint32(), False),
|
||||||
|
pa.field("vector", vector(128), False),
|
||||||
|
pa.field("s", pa.string(), False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
table = conn.create_table(
|
||||||
|
table_name,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
table.create_index()
|
||||||
|
"""
|
||||||
|
index_type = "vector"
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"column": vector_column_name,
|
||||||
|
"index_type": index_type,
|
||||||
|
"metric_type": metric,
|
||||||
|
"index_cache_size": index_cache_size,
|
||||||
|
}
|
||||||
|
resp = self._conn._loop.run_until_complete(
|
||||||
|
self._conn._client.post(f"/v1/table/{self._name}/create_index/", data=data)
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
@@ -98,10 +161,12 @@ class RemoteTable(Table):
|
|||||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
if query.prefilter:
|
|
||||||
raise NotImplementedError("Cloud support for prefiltering is coming soon")
|
|
||||||
result = self._conn._client.query(self._name, query)
|
result = self._conn._client.query(self._name, query)
|
||||||
return self._conn._loop.run_until_complete(result).to_arrow()
|
return self._conn._loop.run_until_complete(result).to_arrow()
|
||||||
|
|
||||||
def delete(self, predicate: str):
|
def delete(self, predicate: str):
|
||||||
raise NotImplementedError
|
"""Delete rows from the table."""
|
||||||
|
payload = {"predicate": predicate}
|
||||||
|
self._conn._loop.run_until_complete(
|
||||||
|
self._conn._client.post(f"/v1/table/{self._name}/delete/", data=payload)
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,22 +17,27 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Iterable, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import lance
|
import lance
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pyarrow.compute as pc
|
import pyarrow.compute as pc
|
||||||
from lance import LanceDataset
|
from lance import LanceDataset
|
||||||
from lance.dataset import ReaderLike
|
|
||||||
from lance.vector import vec_to_table
|
from lance.vector import vec_to_table
|
||||||
|
|
||||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from .embeddings import EmbeddingFunctionRegistry
|
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
|
||||||
from .embeddings.functions import EmbeddingFunctionConfig
|
|
||||||
from .pydantic import LanceModel
|
from .pydantic import LanceModel
|
||||||
from .query import LanceQueryBuilder, Query
|
from .query import LanceQueryBuilder, Query
|
||||||
from .util import fs_from_uri, safe_import_pandas
|
from .util import fs_from_uri, safe_import_pandas
|
||||||
|
from .utils.events import register_event
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from lance.dataset import CleanupStats, ReaderLike
|
||||||
|
|
||||||
|
|
||||||
pd = safe_import_pandas()
|
pd = safe_import_pandas()
|
||||||
|
|
||||||
@@ -85,7 +90,9 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
|||||||
for vector_column, conf in functions.items():
|
for vector_column, conf in functions.items():
|
||||||
func = conf.function
|
func = conf.function
|
||||||
if vector_column not in data.column_names:
|
if vector_column not in data.column_names:
|
||||||
col_data = func.compute_source_embeddings(data[conf.source_column])
|
col_data = func.compute_source_embeddings_with_retry(
|
||||||
|
data[conf.source_column]
|
||||||
|
)
|
||||||
if schema is not None:
|
if schema is not None:
|
||||||
dtype = schema.field(vector_column).type
|
dtype = schema.field(vector_column).type
|
||||||
else:
|
else:
|
||||||
@@ -136,7 +143,7 @@ class Table(ABC):
|
|||||||
|
|
||||||
Can query the table with [Table.search][lancedb.table.Table.search].
|
Can query the table with [Table.search][lancedb.table.Table.search].
|
||||||
|
|
||||||
>>> table.search([0.4, 0.4]).select(["b"]).to_df()
|
>>> table.search([0.4, 0.4]).select(["b"]).to_pandas()
|
||||||
b vector _distance
|
b vector _distance
|
||||||
0 4 [0.5, 1.3] 0.82
|
0 4 [0.5, 1.3] 0.82
|
||||||
1 2 [1.1, 1.2] 1.13
|
1 2 [1.1, 1.2] 1.13
|
||||||
@@ -148,13 +155,13 @@ class Table(ABC):
|
|||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def schema(self) -> pa.Schema:
|
def schema(self) -> pa.Schema:
|
||||||
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#) of
|
"""The [Arrow Schema](https://arrow.apache.org/docs/python/api/datatypes.html#)
|
||||||
this [Table](Table)
|
of this Table
|
||||||
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def to_pandas(self):
|
def to_pandas(self) -> "pd.DataFrame":
|
||||||
"""Return the table as a pandas DataFrame.
|
"""Return the table as a pandas DataFrame.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -180,6 +187,8 @@ class Table(ABC):
|
|||||||
num_sub_vectors=96,
|
num_sub_vectors=96,
|
||||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||||
replace: bool = True,
|
replace: bool = True,
|
||||||
|
accelerator: Optional[str] = None,
|
||||||
|
index_cache_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""Create an index on the table.
|
"""Create an index on the table.
|
||||||
|
|
||||||
@@ -189,17 +198,23 @@ class Table(ABC):
|
|||||||
The distance metric to use when creating the index.
|
The distance metric to use when creating the index.
|
||||||
Valid values are "L2", "cosine", or "dot".
|
Valid values are "L2", "cosine", or "dot".
|
||||||
L2 is euclidean distance.
|
L2 is euclidean distance.
|
||||||
num_partitions: int
|
num_partitions: int, default 256
|
||||||
The number of IVF partitions to use when creating the index.
|
The number of IVF partitions to use when creating the index.
|
||||||
Default is 256.
|
Default is 256.
|
||||||
num_sub_vectors: int
|
num_sub_vectors: int, default 96
|
||||||
The number of PQ sub-vectors to use when creating the index.
|
The number of PQ sub-vectors to use when creating the index.
|
||||||
Default is 96.
|
Default is 96.
|
||||||
vector_column_name: str, default "vector"
|
vector_column_name: str, default "vector"
|
||||||
The vector column name to create the index.
|
The vector column name to create the index.
|
||||||
replace: bool, default True
|
replace: bool, default True
|
||||||
If True, replace the existing index if it exists.
|
- If True, replace the existing index if it exists.
|
||||||
If False, raise an error if duplicate index exists.
|
|
||||||
|
- If False, raise an error if duplicate index exists.
|
||||||
|
accelerator: str, default None
|
||||||
|
If set, use the given accelerator to create the index.
|
||||||
|
Only support "cuda" for now.
|
||||||
|
index_cache_size : int, optional
|
||||||
|
The size of the index cache in number of entries. Default value is 256.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -215,8 +230,14 @@ class Table(ABC):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
data: list-of-dict, dict, pd.DataFrame
|
data: DATA
|
||||||
The data to insert into the table.
|
The data to insert into the table. Acceptable types are:
|
||||||
|
|
||||||
|
- dict or list-of-dict
|
||||||
|
|
||||||
|
- pandas.DataFrame
|
||||||
|
|
||||||
|
- pyarrow.Table or pyarrow.RecordBatch
|
||||||
mode: str
|
mode: str
|
||||||
The mode to use when writing the data. Valid values are
|
The mode to use when writing the data. Valid values are
|
||||||
"append" and "overwrite".
|
"append" and "overwrite".
|
||||||
@@ -237,31 +258,70 @@ class Table(ABC):
|
|||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
of the given query vector.
|
of the given query vector. We currently support [vector search][search]
|
||||||
|
and [full-text search][experimental-full-text-search].
|
||||||
|
|
||||||
|
All query options are defined in [Query][lancedb.query.Query].
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> data = [
|
||||||
|
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
|
||||||
|
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
|
||||||
|
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
|
||||||
|
... ]
|
||||||
|
>>> table = db.create_table("my_table", data)
|
||||||
|
>>> query = [0.4, 1.4, 2.4]
|
||||||
|
>>> (table.search(query, vector_column_name="vector")
|
||||||
|
... .where("original_width > 1000", prefilter=True)
|
||||||
|
... .select(["caption", "original_width"])
|
||||||
|
... .limit(2)
|
||||||
|
... .to_pandas())
|
||||||
|
caption original_width vector _distance
|
||||||
|
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
|
||||||
|
1 test 3000 [0.3, 6.2, 2.6] 23.089996
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
query: str, list, np.ndarray, PIL.Image.Image, default None
|
query: list/np.ndarray/str/PIL.Image.Image, default None
|
||||||
The query to search for. If None then
|
The targetted vector to search for.
|
||||||
the select/where/limit clauses are applied to filter
|
|
||||||
|
- *default None*.
|
||||||
|
Acceptable types are: list, np.ndarray, PIL.Image.Image
|
||||||
|
|
||||||
|
- If None then the select/where/limit clauses are applied to filter
|
||||||
the table
|
the table
|
||||||
vector_column_name: str, default "vector"
|
vector_column_name: str
|
||||||
The name of the vector column to search.
|
The name of the vector column to search.
|
||||||
query_type: str, default "auto"
|
*default "vector"*
|
||||||
"vector", "fts", or "auto"
|
query_type: str
|
||||||
If "auto" then the query type is inferred from the query;
|
*default "auto"*.
|
||||||
If `query` is a list/np.ndarray then the query type is "vector";
|
Acceptable types are: "vector", "fts", or "auto"
|
||||||
If `query` is a PIL.Image.Image then either do vector search
|
|
||||||
|
- If "auto" then the query type is inferred from the query;
|
||||||
|
|
||||||
|
- If `query` is a list/np.ndarray then the query type is
|
||||||
|
"vector";
|
||||||
|
|
||||||
|
- If `query` is a PIL.Image.Image then either do vector search,
|
||||||
or raise an error if no corresponding embedding function is found.
|
or raise an error if no corresponding embedding function is found.
|
||||||
If `query` is a string, then the query type is "vector" if the
|
|
||||||
|
- If `query` is a string, then the query type is "vector" if the
|
||||||
table has embedding functions else the query type is "fts"
|
table has embedding functions else the query type is "fts"
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
LanceQueryBuilder
|
LanceQueryBuilder
|
||||||
A query builder object representing the query.
|
A query builder object representing the query.
|
||||||
Once executed, the query returns selected columns, the vector,
|
Once executed, the query returns
|
||||||
and also the "_distance" column which is the distance between the query
|
|
||||||
|
- selected columns
|
||||||
|
|
||||||
|
- the vector
|
||||||
|
|
||||||
|
- and also the "_distance" column which is the distance between the query
|
||||||
vector and the returned vector.
|
vector and the returned vector.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -280,14 +340,20 @@ class Table(ABC):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
where: str
|
where: str
|
||||||
The SQL where clause to use when deleting rows. For example, 'x = 2'
|
The SQL where clause to use when deleting rows.
|
||||||
or 'x IN (1, 2, 3)'. The filter must not be empty, or it will error.
|
|
||||||
|
- For example, 'x = 2' or 'x IN (1, 2, 3)'.
|
||||||
|
|
||||||
|
The filter must not be empty, or it will error.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> import pandas as pd
|
>>> data = [
|
||||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
... {"x": 1, "vector": [1, 2]},
|
||||||
|
... {"x": 2, "vector": [3, 4]},
|
||||||
|
... {"x": 3, "vector": [5, 6]}
|
||||||
|
... ]
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
@@ -371,7 +437,8 @@ class LanceTable(Table):
|
|||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
>>> table = db.create_table("my_table",
|
||||||
|
... [{"vector": [1.1, 0.9], "type": "vector"}])
|
||||||
>>> table.version
|
>>> table.version
|
||||||
2
|
2
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
@@ -390,6 +457,17 @@ class LanceTable(Table):
|
|||||||
raise ValueError(f"Invalid version {version}")
|
raise ValueError(f"Invalid version {version}")
|
||||||
self._reset_dataset(version=version)
|
self._reset_dataset(version=version)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Accessing the property updates the cached value
|
||||||
|
_ = self._dataset
|
||||||
|
except Exception as e:
|
||||||
|
if "not found" in str(e):
|
||||||
|
raise ValueError(
|
||||||
|
f"Version {version} no longer exists. Was it cleaned up?"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
def restore(self, version: int = None):
|
def restore(self, version: int = None):
|
||||||
"""Restore a version of the table. This is an in-place operation.
|
"""Restore a version of the table. This is an in-place operation.
|
||||||
|
|
||||||
@@ -407,7 +485,8 @@ class LanceTable(Table):
|
|||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", [{"vector": [1.1, 0.9], "type": "vector"}])
|
>>> table = db.create_table("my_table", [
|
||||||
|
... {"vector": [1.1, 0.9], "type": "vector"}])
|
||||||
>>> table.version
|
>>> table.version
|
||||||
2
|
2
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
@@ -479,6 +558,8 @@ class LanceTable(Table):
|
|||||||
num_sub_vectors=96,
|
num_sub_vectors=96,
|
||||||
vector_column_name=VECTOR_COLUMN_NAME,
|
vector_column_name=VECTOR_COLUMN_NAME,
|
||||||
replace: bool = True,
|
replace: bool = True,
|
||||||
|
accelerator: Optional[str] = None,
|
||||||
|
index_cache_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""Create an index on the table."""
|
"""Create an index on the table."""
|
||||||
self._dataset.create_index(
|
self._dataset.create_index(
|
||||||
@@ -488,8 +569,11 @@ class LanceTable(Table):
|
|||||||
num_partitions=num_partitions,
|
num_partitions=num_partitions,
|
||||||
num_sub_vectors=num_sub_vectors,
|
num_sub_vectors=num_sub_vectors,
|
||||||
replace=replace,
|
replace=replace,
|
||||||
|
accelerator=accelerator,
|
||||||
|
index_cache_size=index_cache_size,
|
||||||
)
|
)
|
||||||
self._reset_dataset()
|
self._reset_dataset()
|
||||||
|
register_event("create_index")
|
||||||
|
|
||||||
def create_fts_index(self, field_names: Union[str, List[str]]):
|
def create_fts_index(self, field_names: Union[str, List[str]]):
|
||||||
"""Create a full-text search index on the table.
|
"""Create a full-text search index on the table.
|
||||||
@@ -508,6 +592,7 @@ class LanceTable(Table):
|
|||||||
field_names = [field_names]
|
field_names = [field_names]
|
||||||
index = create_index(self._get_fts_index_path(), field_names)
|
index = create_index(self._get_fts_index_path(), field_names)
|
||||||
populate_index(index, self, field_names)
|
populate_index(index, self, field_names)
|
||||||
|
register_event("create_fts_index")
|
||||||
|
|
||||||
def _get_fts_index_path(self):
|
def _get_fts_index_path(self):
|
||||||
return os.path.join(self._dataset_uri, "_indices", "tantivy")
|
return os.path.join(self._dataset_uri, "_indices", "tantivy")
|
||||||
@@ -560,6 +645,7 @@ class LanceTable(Table):
|
|||||||
)
|
)
|
||||||
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
lance.write_dataset(data, self._dataset_uri, schema=self.schema, mode=mode)
|
||||||
self._reset_dataset()
|
self._reset_dataset()
|
||||||
|
register_event("add")
|
||||||
|
|
||||||
def merge(
|
def merge(
|
||||||
self,
|
self,
|
||||||
@@ -623,6 +709,7 @@ class LanceTable(Table):
|
|||||||
other_table, left_on=left_on, right_on=right_on, schema=schema
|
other_table, left_on=left_on, right_on=right_on, schema=schema
|
||||||
)
|
)
|
||||||
self._reset_dataset()
|
self._reset_dataset()
|
||||||
|
register_event("merge")
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def embedding_functions(self) -> dict:
|
def embedding_functions(self) -> dict:
|
||||||
@@ -646,14 +733,39 @@ class LanceTable(Table):
|
|||||||
query_type: str = "auto",
|
query_type: str = "auto",
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
of the given query vector.
|
of the given query vector. We currently support [vector search][search]
|
||||||
|
and [full-text search][search].
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> import lancedb
|
||||||
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
|
>>> data = [
|
||||||
|
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
|
||||||
|
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
|
||||||
|
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
|
||||||
|
... ]
|
||||||
|
>>> table = db.create_table("my_table", data)
|
||||||
|
>>> query = [0.4, 1.4, 2.4]
|
||||||
|
>>> (table.search(query, vector_column_name="vector")
|
||||||
|
... .where("original_width > 1000", prefilter=True)
|
||||||
|
... .select(["caption", "original_width"])
|
||||||
|
... .limit(2)
|
||||||
|
... .to_pandas())
|
||||||
|
caption original_width vector _distance
|
||||||
|
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
|
||||||
|
1 test 3000 [0.3, 6.2, 2.6] 23.089996
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
query: str, list, np.ndarray, a PIL Image or None
|
query: list/np.ndarray/str/PIL.Image.Image, default None
|
||||||
The query to search for. If None then
|
The targetted vector to search for.
|
||||||
the select/where/limit clauses are applied to filter
|
|
||||||
the table
|
- *default None*.
|
||||||
|
Acceptable types are: list, np.ndarray, PIL.Image.Image
|
||||||
|
|
||||||
|
- If None then the select/[where][sql]/limit clauses are applied
|
||||||
|
to filter the table
|
||||||
vector_column_name: str, default "vector"
|
vector_column_name: str, default "vector"
|
||||||
The name of the vector column to search.
|
The name of the vector column to search.
|
||||||
query_type: str, default "auto"
|
query_type: str, default "auto"
|
||||||
@@ -662,7 +774,7 @@ class LanceTable(Table):
|
|||||||
If `query` is a list/np.ndarray then the query type is "vector";
|
If `query` is a list/np.ndarray then the query type is "vector";
|
||||||
If `query` is a PIL.Image.Image then either do vector search
|
If `query` is a PIL.Image.Image then either do vector search
|
||||||
or raise an error if no corresponding embedding function is found.
|
or raise an error if no corresponding embedding function is found.
|
||||||
If the query is a string, then the query type is "vector" if the
|
If the `query` is a string, then the query type is "vector" if the
|
||||||
table has embedding functions, else the query type is "fts"
|
table has embedding functions, else the query type is "fts"
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -673,6 +785,7 @@ class LanceTable(Table):
|
|||||||
and also the "_distance" column which is the distance between the query
|
and also the "_distance" column which is the distance between the query
|
||||||
vector and the returned vector.
|
vector and the returned vector.
|
||||||
"""
|
"""
|
||||||
|
register_event("search")
|
||||||
return LanceQueryBuilder.create(
|
return LanceQueryBuilder.create(
|
||||||
self, query, query_type, vector_column_name=vector_column_name
|
self, query, query_type, vector_column_name=vector_column_name
|
||||||
)
|
)
|
||||||
@@ -695,8 +808,11 @@ class LanceTable(Table):
|
|||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> import pandas as pd
|
>>> data = [
|
||||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
... {"x": 1, "vector": [1, 2]},
|
||||||
|
... {"x": 2, "vector": [3, 4]},
|
||||||
|
... {"x": 3, "vector": [5, 6]}
|
||||||
|
... ]
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
@@ -715,7 +831,8 @@ class LanceTable(Table):
|
|||||||
The data to insert into the table.
|
The data to insert into the table.
|
||||||
At least one of `data` or `schema` must be provided.
|
At least one of `data` or `schema` must be provided.
|
||||||
schema: pa.Schema or LanceModel, optional
|
schema: pa.Schema or LanceModel, optional
|
||||||
The schema of the table. If not provided, the schema is inferred from the data.
|
The schema of the table. If not provided,
|
||||||
|
the schema is inferred from the data.
|
||||||
At least one of `data` or `schema` must be provided.
|
At least one of `data` or `schema` must be provided.
|
||||||
mode: str, default "create"
|
mode: str, default "create"
|
||||||
The mode to use when writing the data. Valid values are
|
The mode to use when writing the data. Valid values are
|
||||||
@@ -776,6 +893,7 @@ class LanceTable(Table):
|
|||||||
if data is not None:
|
if data is not None:
|
||||||
table.add(data)
|
table.add(data)
|
||||||
|
|
||||||
|
register_event("create_table")
|
||||||
return table
|
return table
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -785,7 +903,8 @@ class LanceTable(Table):
|
|||||||
file_info = fs.get_file_info(path)
|
file_info = fs.get_file_info(path)
|
||||||
if file_info.type != pa.fs.FileType.Directory:
|
if file_info.type != pa.fs.FileType.Directory:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"Table {name} does not exist. Please first call db.create_table({name}, data)"
|
f"Table {name} does not exist."
|
||||||
|
f"Please first call db.create_table({name}, data)"
|
||||||
)
|
)
|
||||||
return tbl
|
return tbl
|
||||||
|
|
||||||
@@ -811,8 +930,11 @@ class LanceTable(Table):
|
|||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> import pandas as pd
|
>>> data = [
|
||||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
... {"x": 1, "vector": [1, 2]},
|
||||||
|
... {"x": 2, "vector": [3, 4]},
|
||||||
|
... {"x": 3, "vector": [5, 6]}
|
||||||
|
... ]
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
@@ -841,15 +963,10 @@ class LanceTable(Table):
|
|||||||
self.delete(where)
|
self.delete(where)
|
||||||
self.add(orig_data, mode="append")
|
self.add(orig_data, mode="append")
|
||||||
self._reset_dataset()
|
self._reset_dataset()
|
||||||
|
register_event("update")
|
||||||
|
|
||||||
def _execute_query(self, query: Query) -> pa.Table:
|
def _execute_query(self, query: Query) -> pa.Table:
|
||||||
ds = self.to_lance()
|
ds = self.to_lance()
|
||||||
if query.prefilter:
|
|
||||||
for idx in ds.list_indices():
|
|
||||||
if query.vector_column in idx["fields"]:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Prefiltering for indexed vector column is coming soon."
|
|
||||||
)
|
|
||||||
return ds.to_table(
|
return ds.to_table(
|
||||||
columns=query.columns,
|
columns=query.columns,
|
||||||
filter=query.filter,
|
filter=query.filter,
|
||||||
@@ -864,6 +981,48 @@ class LanceTable(Table):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def cleanup_old_versions(
|
||||||
|
self,
|
||||||
|
older_than: Optional[timedelta] = None,
|
||||||
|
*,
|
||||||
|
delete_unverified: bool = False,
|
||||||
|
) -> CleanupStats:
|
||||||
|
"""
|
||||||
|
Clean up old versions of the table, freeing disk space.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
older_than: timedelta, default None
|
||||||
|
The minimum age of the version to delete. If None, then this defaults
|
||||||
|
to two weeks.
|
||||||
|
delete_unverified: bool, default False
|
||||||
|
Because they may be part of an in-progress transaction, files newer
|
||||||
|
than 7 days old are not deleted by default. If you are sure that
|
||||||
|
there are no in-progress transactions, then you can set this to True
|
||||||
|
to delete all files older than `older_than`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
CleanupStats
|
||||||
|
The stats of the cleanup operation, including how many bytes were
|
||||||
|
freed.
|
||||||
|
"""
|
||||||
|
return self.to_lance().cleanup_old_versions(
|
||||||
|
older_than, delete_unverified=delete_unverified
|
||||||
|
)
|
||||||
|
|
||||||
|
def compact_files(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run the compaction process on the table.
|
||||||
|
|
||||||
|
This can be run after making several small appends to optimize the table
|
||||||
|
for faster reads.
|
||||||
|
|
||||||
|
Arguments are passed onto :meth:`lance.dataset.DatasetOptimizer.compact_files`.
|
||||||
|
For most cases, the default should be fine.
|
||||||
|
"""
|
||||||
|
return self.to_lance().optimize.compact_files(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_schema(
|
def _sanitize_schema(
|
||||||
data: pa.Table,
|
data: pa.Table,
|
||||||
@@ -949,7 +1108,8 @@ def _sanitize_vector_column(
|
|||||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||||
vec_arr = data[vector_column_name].combine_chunks()
|
vec_arr = data[vector_column_name].combine_chunks()
|
||||||
if pa.types.is_list(data[vector_column_name].type):
|
if pa.types.is_list(data[vector_column_name].type):
|
||||||
# if it's a variable size list array we make sure the dimensions are all the same
|
# if it's a variable size list array,
|
||||||
|
# we make sure the dimensions are all the same
|
||||||
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
|
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
|
||||||
if has_jagged_ndims:
|
if has_jagged_ndims:
|
||||||
data = _sanitize_jagged(
|
data = _sanitize_jagged(
|
||||||
|
|||||||
15
python/lancedb/utils/__init__.py
Normal file
15
python/lancedb/utils/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .config import Config
|
||||||
|
|
||||||
|
CONFIG = Config()
|
||||||
116
python/lancedb/utils/config.py
Normal file
116
python/lancedb/utils/config.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .general import LOGGER, is_dir_writeable, yaml_load, yaml_save
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_config_dir(sub_dir="lancedb"):
|
||||||
|
"""
|
||||||
|
Get the user config directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sub_dir (str): The name of the subdirectory to create.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Path): The path to the user config directory.
|
||||||
|
"""
|
||||||
|
# Return the appropriate config directory for each operating system
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
path = Path.home() / "AppData" / "Roaming" / sub_dir
|
||||||
|
elif platform.system() == "Darwin":
|
||||||
|
path = Path.home() / "Library" / "Application Support" / sub_dir
|
||||||
|
elif platform.system() == "Linux":
|
||||||
|
path = Path.home() / ".config" / sub_dir
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported operating system: {platform.system()}")
|
||||||
|
|
||||||
|
# GCP and AWS lambda fix, only /tmp is writeable
|
||||||
|
if not is_dir_writeable(path.parent):
|
||||||
|
LOGGER.warning(
|
||||||
|
f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
|
||||||
|
"Alternatively you can define a LANCEDB_CONFIG_DIR environment variable for this path."
|
||||||
|
)
|
||||||
|
path = (
|
||||||
|
Path("/tmp") / sub_dir
|
||||||
|
if is_dir_writeable("/tmp")
|
||||||
|
else Path().cwd() / sub_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the subdirectory if it does not exist
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
USER_CONFIG_DIR = Path(os.getenv("LANCEDB_CONFIG_DIR") or get_user_config_dir())
|
||||||
|
CONFIG_FILE = USER_CONFIG_DIR / "config.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
class Config(dict):
|
||||||
|
"""
|
||||||
|
Manages lancedb config stored in a YAML file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (str | Path): Path to the lancedb config YAML file. Default is USER_CONFIG_DIR / 'config.yaml'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file=CONFIG_FILE):
|
||||||
|
self.file = Path(file)
|
||||||
|
self.defaults = { # Default global config values
|
||||||
|
"diagnostics": True,
|
||||||
|
"uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
|
||||||
|
}
|
||||||
|
|
||||||
|
super().__init__(copy.deepcopy(self.defaults))
|
||||||
|
|
||||||
|
if not self.file.exists():
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
self.load()
|
||||||
|
correct_keys = self.keys() == self.defaults.keys()
|
||||||
|
correct_types = all(
|
||||||
|
type(a) is type(b) for a, b in zip(self.values(), self.defaults.values())
|
||||||
|
)
|
||||||
|
if not (correct_keys and correct_types):
|
||||||
|
LOGGER.warning(
|
||||||
|
"WARNING ⚠️ LanceDB settings reset to default values. This may be due to a possible problem "
|
||||||
|
"with your settings or a recent package update. "
|
||||||
|
f"\nView settings & usage with 'lancedb settings' or at '{self.file}'"
|
||||||
|
)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
"""Loads settings from the YAML file."""
|
||||||
|
super().update(yaml_load(self.file))
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""Saves the current settings to the YAML file."""
|
||||||
|
yaml_save(self.file, dict(self))
|
||||||
|
|
||||||
|
def update(self, *args, **kwargs):
|
||||||
|
"""Updates a setting value in the current settings."""
|
||||||
|
super().update(*args, **kwargs)
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Resets the settings to default and saves them."""
|
||||||
|
self.clear()
|
||||||
|
self.update(self.defaults)
|
||||||
|
self.save()
|
||||||
161
python/lancedb/utils/events.py
Normal file
161
python/lancedb/utils/events.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import importlib.metadata
|
||||||
|
import platform
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from lancedb.utils import CONFIG
|
||||||
|
from lancedb.utils.general import TryExcept
|
||||||
|
|
||||||
|
from .general import (
|
||||||
|
PLATFORMS,
|
||||||
|
get_git_origin_url,
|
||||||
|
is_git_dir,
|
||||||
|
is_github_actions_ci,
|
||||||
|
is_online,
|
||||||
|
is_pip_package,
|
||||||
|
is_pytest_running,
|
||||||
|
threaded_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _Events:
|
||||||
|
"""
|
||||||
|
A class for collecting anonymous event analytics. Event analytics are enabled when ``diagnostics=True`` in config and
|
||||||
|
disabled when ``diagnostics=False``.
|
||||||
|
|
||||||
|
You can enable or disable diagnostics by running ``lancedb diagnostics --enabled`` or ``lancedb diagnostics --disabled``.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
url : str
|
||||||
|
The URL to send anonymous events.
|
||||||
|
rate_limit : float
|
||||||
|
The rate limit in seconds for sending events.
|
||||||
|
metadata : dict
|
||||||
|
A dictionary containing metadata about the environment.
|
||||||
|
enabled : bool
|
||||||
|
A flag to enable or disable Events based on certain conditions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
url = "https://app.posthog.com/capture/"
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
api_key = "phc_oENDjGgHtmIDrV6puUiFem2RB4JA8gGWulfdulmMdZP"
|
||||||
|
# This api-key is write only and is safe to expose in the codebase.
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Initializes the Events object with default values for events, rate_limit, and metadata.
|
||||||
|
"""
|
||||||
|
self.events = [] # events list
|
||||||
|
self.max_events = 25 # max events to store in memory
|
||||||
|
self.rate_limit = 60.0 # rate limit (seconds)
|
||||||
|
self.time = 0.0
|
||||||
|
|
||||||
|
if is_git_dir():
|
||||||
|
install = "git"
|
||||||
|
elif is_pip_package():
|
||||||
|
install = "pip"
|
||||||
|
else:
|
||||||
|
install = "other"
|
||||||
|
self.metadata = {
|
||||||
|
"cli": sys.argv[0],
|
||||||
|
"install": install,
|
||||||
|
"python": ".".join(platform.python_version_tuple()[:2]),
|
||||||
|
"version": importlib.metadata.version("lancedb"),
|
||||||
|
"platforms": PLATFORMS,
|
||||||
|
"session_id": round(random.random() * 1e15),
|
||||||
|
# 'engagement_time_msec': 1000 # TODO: In future we might be interested in this metric
|
||||||
|
}
|
||||||
|
|
||||||
|
TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
|
||||||
|
ONLINE = is_online()
|
||||||
|
self.enabled = (
|
||||||
|
CONFIG["diagnostics"]
|
||||||
|
and not TESTS_RUNNING
|
||||||
|
and ONLINE
|
||||||
|
and (
|
||||||
|
is_pip_package()
|
||||||
|
or get_git_origin_url() == "https://github.com/lancedb/lancedb.git"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, event_name, params={}):
|
||||||
|
"""
|
||||||
|
Attempts to add a new event to the events list and send events if the rate limit is reached.
|
||||||
|
|
||||||
|
Args
|
||||||
|
----
|
||||||
|
event_name : str
|
||||||
|
The name of the event to be logged.
|
||||||
|
params : dict, optional
|
||||||
|
A dictionary of additional parameters to be logged with the event.
|
||||||
|
"""
|
||||||
|
### NOTE: We might need a way to tag a session with a label to check usage from a source. Setting label should be exposed to the user.
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
|
if (
|
||||||
|
len(self.events) < self.max_events
|
||||||
|
): # Events list limited to 25 events (drop any events past this)
|
||||||
|
params.update(self.metadata)
|
||||||
|
self.events.append(
|
||||||
|
{
|
||||||
|
"event": event_name,
|
||||||
|
"properties": params,
|
||||||
|
"timestamp": datetime.datetime.now(
|
||||||
|
tz=datetime.timezone.utc
|
||||||
|
).isoformat(),
|
||||||
|
"distinct_id": CONFIG["uuid"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check rate limit
|
||||||
|
t = time.time()
|
||||||
|
if (t - self.time) < self.rate_limit:
|
||||||
|
return
|
||||||
|
# Time is over rate limiter, send now
|
||||||
|
data = {
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"distinct_id": CONFIG["uuid"], # posthog needs this to accepts the event
|
||||||
|
"batch": self.events,
|
||||||
|
}
|
||||||
|
|
||||||
|
# POST equivalent to requests.post(self.url, json=data).
|
||||||
|
# threaded request is used to avoid blocking, retries are disabled, and verbose is disabled
|
||||||
|
# to avoid any possible disruption in the console.
|
||||||
|
threaded_request(
|
||||||
|
method="post",
|
||||||
|
url=self.url,
|
||||||
|
headers=self.headers,
|
||||||
|
json=data,
|
||||||
|
retry=0,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flush & Reset
|
||||||
|
self.events = []
|
||||||
|
self.time = t
|
||||||
|
|
||||||
|
|
||||||
|
@TryExcept(verbose=False)
|
||||||
|
def register_event(name: str, **kwargs):
|
||||||
|
if _Events._instance is None:
|
||||||
|
_Events._instance = _Events()
|
||||||
|
|
||||||
|
_Events._instance(name, **kwargs)
|
||||||
445
python/lancedb/utils/general.py
Normal file
445
python/lancedb/utils/general.py
Normal file
@@ -0,0 +1,445 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import importlib
|
||||||
|
import logging.config
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
LOGGING_NAME = "lancedb"
|
||||||
|
VERBOSE = (
|
||||||
|
str(os.getenv("LANCEDB_VERBOSE", True)).lower() == "true"
|
||||||
|
) # global verbose mode
|
||||||
|
|
||||||
|
|
||||||
|
def set_logging(name=LOGGING_NAME, verbose=True):
|
||||||
|
"""Sets up logging for the given name.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str, optional
|
||||||
|
The name of the logger. Default is 'lancedb'.
|
||||||
|
verbose : bool, optional
|
||||||
|
Whether to enable verbose logging. Default is True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rank = int(os.getenv("RANK", -1)) # rank in world for Multi-GPU trainings
|
||||||
|
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
|
||||||
|
logging.config.dictConfig(
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"disable_existing_loggers": False,
|
||||||
|
"formatters": {name: {"format": "%(message)s"}},
|
||||||
|
"handlers": {
|
||||||
|
name: {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"formatter": name,
|
||||||
|
"level": level,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"loggers": {name: {"level": level, "handlers": [name], "propagate": False}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
set_logging(LOGGING_NAME, verbose=VERBOSE)
|
||||||
|
LOGGER = logging.getLogger(LOGGING_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
def is_pip_package(filepath: str = __name__) -> bool:
|
||||||
|
"""Determines if the file at the given filepath is part of a pip package.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filepath : str, optional
|
||||||
|
The filepath to check. Default is the current file.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the file is part of a pip package, False otherwise.
|
||||||
|
"""
|
||||||
|
# Get the spec for the module
|
||||||
|
spec = importlib.util.find_spec(filepath)
|
||||||
|
|
||||||
|
# Return whether the spec is not None and the origin is not None (indicating it is a package)
|
||||||
|
return spec is not None and spec.origin is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_pytest_running():
|
||||||
|
"""Determines whether pytest is currently running or not.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if pytest is running, False otherwise.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
("PYTEST_CURRENT_TEST" in os.environ)
|
||||||
|
or ("pytest" in sys.modules)
|
||||||
|
or ("pytest" in Path(sys.argv[0]).stem)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_github_actions_ci() -> bool:
|
||||||
|
"""
|
||||||
|
Determine if the current environment is a GitHub Actions CI Python runner.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the current environment is a GitHub Actions CI Python runner, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return (
|
||||||
|
"GITHUB_ACTIONS" in os.environ
|
||||||
|
and "RUNNER_OS" in os.environ
|
||||||
|
and "RUNNER_TOOL_CACHE" in os.environ
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_git_dir():
|
||||||
|
"""
|
||||||
|
Determines whether the current file is part of a git repository.
|
||||||
|
If the current file is not part of a git repository, returns None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if current file is part of a git repository.
|
||||||
|
"""
|
||||||
|
return get_git_dir() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_online() -> bool:
|
||||||
|
"""
|
||||||
|
Check internet connectivity by attempting to connect to a known online host.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if connection is successful, False otherwise.
|
||||||
|
"""
|
||||||
|
import socket
|
||||||
|
|
||||||
|
for host in "1.1.1.1", "8.8.8.8", "223.5.5.5": # Cloudflare, Google, AliDNS:
|
||||||
|
try:
|
||||||
|
test_connection = socket.create_connection(address=(host, 53), timeout=2)
|
||||||
|
except (socket.timeout, socket.gaierror, OSError):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# If the connection was successful, close it to avoid a ResourceWarning
|
||||||
|
test_connection.close()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
|
||||||
|
"""Check if a directory is writeable.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dir_path : Union[str, Path]
|
||||||
|
The path to the directory.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the directory is writeable, False otherwise.
|
||||||
|
"""
|
||||||
|
return os.access(str(dir_path), os.W_OK)
|
||||||
|
|
||||||
|
|
||||||
|
def is_colab():
|
||||||
|
"""Check if the current script is running inside a Google Colab notebook.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if running inside a Colab notebook, False otherwise.
|
||||||
|
"""
|
||||||
|
return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def is_kaggle():
|
||||||
|
"""Check if the current script is running inside a Kaggle kernel.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if running inside a Kaggle kernel, False otherwise.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
os.environ.get("PWD") == "/kaggle/working"
|
||||||
|
and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_jupyter():
|
||||||
|
"""Check if the current script is running inside a Jupyter Notebook.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if running inside a Jupyter Notebook, False otherwise.
|
||||||
|
"""
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
from IPython import get_ipython
|
||||||
|
|
||||||
|
return get_ipython() is not None
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_docker() -> bool:
|
||||||
|
"""Determine if the script is running inside a Docker container.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the script is running inside a Docker container, False otherwise.
|
||||||
|
"""
|
||||||
|
file = Path("/proc/self/cgroup")
|
||||||
|
if file.exists():
|
||||||
|
with open(file) as f:
|
||||||
|
return "docker" in f.read()
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_dir():
|
||||||
|
"""Determine whether the current file is part of a git repository and if so, returns the repository root directory.
|
||||||
|
If the current file is not part of a git repository, returns None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Path | None
|
||||||
|
Git root directory if found or None if not found.
|
||||||
|
"""
|
||||||
|
for d in Path(__file__).parents:
|
||||||
|
if (d / ".git").is_dir():
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_origin_url():
|
||||||
|
"""Retrieve the origin URL of a git repository.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str | None
|
||||||
|
The origin URL of the git repository or None if not git directory.
|
||||||
|
"""
|
||||||
|
if is_git_dir():
|
||||||
|
with contextlib.suppress(subprocess.CalledProcessError):
|
||||||
|
origin = subprocess.check_output(
|
||||||
|
["git", "config", "--get", "remote.origin.url"]
|
||||||
|
)
|
||||||
|
return origin.decode().strip()
|
||||||
|
|
||||||
|
|
||||||
|
def yaml_save(file="data.yaml", data=None, header=""):
|
||||||
|
"""Save YAML data to a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
file : str, optional
|
||||||
|
File name, by default 'data.yaml'.
|
||||||
|
data : dict, optional
|
||||||
|
Data to save in YAML format, by default None.
|
||||||
|
header : str, optional
|
||||||
|
YAML header to add, by default "".
|
||||||
|
"""
|
||||||
|
if data is None:
|
||||||
|
data = {}
|
||||||
|
file = Path(file)
|
||||||
|
if not file.parent.exists():
|
||||||
|
# Create parent directories if they don't exist
|
||||||
|
file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Convert Path objects to strings
|
||||||
|
for k, v in data.items():
|
||||||
|
if isinstance(v, Path):
|
||||||
|
data[k] = str(v)
|
||||||
|
|
||||||
|
# Dump data to file in YAML format
|
||||||
|
with open(file, "w", errors="ignore", encoding="utf-8") as f:
|
||||||
|
if header:
|
||||||
|
f.write(header)
|
||||||
|
yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)
|
||||||
|
|
||||||
|
|
||||||
|
def yaml_load(file="data.yaml", append_filename=False):
|
||||||
|
"""
|
||||||
|
Load YAML data from a file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
file : str, optional
|
||||||
|
File name. Default is 'data.yaml'.
|
||||||
|
append_filename : bool, optional
|
||||||
|
Add the YAML filename to the YAML dictionary. Default is False.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
YAML data and file name.
|
||||||
|
"""
|
||||||
|
assert Path(file).suffix in (
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
), f"Attempting to load non-YAML file {file} with yaml_load()"
|
||||||
|
with open(file, errors="ignore", encoding="utf-8") as f:
|
||||||
|
s = f.read() # string
|
||||||
|
|
||||||
|
# Add YAML filename to dict and return
|
||||||
|
data = (
|
||||||
|
yaml.safe_load(s) or {}
|
||||||
|
) # always return a dict (yaml.safe_load() may return None for empty files)
|
||||||
|
if append_filename:
|
||||||
|
data["yaml_file"] = str(file)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
||||||
|
"""
|
||||||
|
Pretty prints a YAML file or a YAML-formatted dictionary.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
yaml_file : Union[str, Path, dict]
|
||||||
|
The file path of the YAML file or a YAML-formatted dictionary.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
yaml_dict = (
|
||||||
|
yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
|
||||||
|
)
|
||||||
|
dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True)
|
||||||
|
LOGGER.info(f"Printing '{yaml_file}'\n\n{dump}")
|
||||||
|
|
||||||
|
|
||||||
|
PLATFORMS = [platform.system()]
|
||||||
|
if is_colab():
|
||||||
|
PLATFORMS.append("Colab")
|
||||||
|
if is_kaggle():
|
||||||
|
PLATFORMS.append("Kaggle")
|
||||||
|
if is_jupyter():
|
||||||
|
PLATFORMS.append("Jupyter")
|
||||||
|
if is_docker():
|
||||||
|
PLATFORMS.append("Docker")
|
||||||
|
|
||||||
|
PLATFORMS = "|".join(PLATFORMS)
|
||||||
|
|
||||||
|
|
||||||
|
class TryExcept(contextlib.ContextDecorator):
|
||||||
|
"""
|
||||||
|
TryExcept context manager.
|
||||||
|
Usage: @TryExcept() decorator or 'with TryExcept():' context manager.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, msg="", verbose=True):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
msg : str, optional
|
||||||
|
Custom message to display in case of exception, by default "".
|
||||||
|
verbose : bool, optional
|
||||||
|
Whether to display the message, by default True.
|
||||||
|
"""
|
||||||
|
self.msg = msg
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, value, traceback):
|
||||||
|
if self.verbose and value:
|
||||||
|
LOGGER.info(f"{self.msg}{': ' if self.msg else ''}{value}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def threaded_request(
|
||||||
|
method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, **kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
method : str
|
||||||
|
The HTTP method to use for the request. Choices are 'post' and 'get'.
|
||||||
|
url : str
|
||||||
|
The URL to make the request to.
|
||||||
|
retry : int, optional
|
||||||
|
Number of retries to attempt before giving up, by default 3.
|
||||||
|
timeout : int, optional
|
||||||
|
Timeout in seconds after which the function will give up retrying, by default 30.
|
||||||
|
thread : bool, optional
|
||||||
|
Whether to execute the request in a separate daemon thread, by default True.
|
||||||
|
code : int, optional
|
||||||
|
An identifier for the request, used for logging purposes, by default -1.
|
||||||
|
verbose : bool, optional
|
||||||
|
A flag to determine whether to print out to console or not, by default True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
requests.Response
|
||||||
|
The HTTP response object. If the request is executed in a separate thread, returns the thread itself.
|
||||||
|
"""
|
||||||
|
retry_codes = () # retry only these codes TODO: add codes if needed in future (500, 408)
|
||||||
|
|
||||||
|
@TryExcept(verbose=verbose)
|
||||||
|
def func(method, url, **kwargs):
|
||||||
|
"""Make HTTP requests with retries and timeouts, with optional progress tracking."""
|
||||||
|
response = None
|
||||||
|
t0 = time.time()
|
||||||
|
for i in range(retry + 1):
|
||||||
|
if (time.time() - t0) > timeout:
|
||||||
|
break
|
||||||
|
response = requests.request(method, url, **kwargs)
|
||||||
|
if response.status_code < 300: # good return codes in the 2xx range
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
m = response.json().get("message", "No JSON message.")
|
||||||
|
except AttributeError:
|
||||||
|
m = "Unable to read JSON."
|
||||||
|
if i == 0:
|
||||||
|
if response.status_code in retry_codes:
|
||||||
|
m += f" Retrying {retry}x for {timeout}s." if retry else ""
|
||||||
|
elif response.status_code == 429: # rate limit
|
||||||
|
m = f"Rate limit reached"
|
||||||
|
if verbose:
|
||||||
|
LOGGER.warning(f"{response.status_code} #{code}")
|
||||||
|
if response.status_code not in retry_codes:
|
||||||
|
return response
|
||||||
|
time.sleep(2**i) # exponential standoff
|
||||||
|
return response
|
||||||
|
|
||||||
|
args = method, url
|
||||||
|
if thread:
|
||||||
|
return threading.Thread(
|
||||||
|
target=func, args=args, kwargs=kwargs, daemon=True
|
||||||
|
).start()
|
||||||
|
else:
|
||||||
|
return func(*args, **kwargs)
|
||||||
113
python/lancedb/utils/sentry_log.py
Normal file
113
python/lancedb/utils/sentry_log.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
# Copyright 2023 LanceDB Developers
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import bdb
|
||||||
|
import importlib.metadata
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lancedb.utils import CONFIG
|
||||||
|
|
||||||
|
from .general import (
|
||||||
|
PLATFORMS,
|
||||||
|
TryExcept,
|
||||||
|
is_git_dir,
|
||||||
|
is_github_actions_ci,
|
||||||
|
is_online,
|
||||||
|
is_pip_package,
|
||||||
|
is_pytest_running,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@TryExcept(verbose=False)
|
||||||
|
def set_sentry():
|
||||||
|
"""
|
||||||
|
Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and
|
||||||
|
sync=True in settings. Run 'lancedb settings' to see and update settings YAML file.
|
||||||
|
|
||||||
|
Conditions required to send errors (ALL conditions must be met or no errors will be reported):
|
||||||
|
- sentry_sdk package is installed
|
||||||
|
- sync=True in settings
|
||||||
|
- pytest is not running
|
||||||
|
- running in a pip package installation
|
||||||
|
- running in a non-git directory
|
||||||
|
- online environment
|
||||||
|
|
||||||
|
The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError
|
||||||
|
exceptions for now.
|
||||||
|
|
||||||
|
Additionally, the function sets custom tags and user information for Sentry events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def before_send(event, hint):
|
||||||
|
"""
|
||||||
|
Modify the event before sending it to Sentry based on specific exception types and messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (dict): The event dictionary containing information about the error.
|
||||||
|
hint (dict): A dictionary containing additional information about the error.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The modified event or None if the event should not be sent to Sentry.
|
||||||
|
"""
|
||||||
|
if "exc_info" in hint:
|
||||||
|
exc_type, exc_value, tb = hint["exc_info"]
|
||||||
|
ignored_errors = ["out of memory", "no space left on device", "testing"]
|
||||||
|
if any(error in str(exc_value).lower() for error in ignored_errors):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if is_git_dir():
|
||||||
|
install = "git"
|
||||||
|
elif is_pip_package():
|
||||||
|
install = "pip"
|
||||||
|
else:
|
||||||
|
install = "other"
|
||||||
|
|
||||||
|
event["tags"] = {
|
||||||
|
"sys_argv": sys.argv[0],
|
||||||
|
"sys_argv_name": Path(sys.argv[0]).name,
|
||||||
|
"install": install,
|
||||||
|
"platforms": PLATFORMS,
|
||||||
|
"version": importlib.metadata.version("lancedb"),
|
||||||
|
}
|
||||||
|
return event
|
||||||
|
|
||||||
|
TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
|
||||||
|
ONLINE = is_online()
|
||||||
|
if CONFIG["diagnostics"] and not TESTS_RUNNING and ONLINE and is_pip_package():
|
||||||
|
# and not is_git_dir(): # not running inside a git dir. Maybe too restrictive?
|
||||||
|
|
||||||
|
# If sentry_sdk package is not installed then return and do not use Sentry
|
||||||
|
try:
|
||||||
|
import sentry_sdk # noqa
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
sentry_sdk.init(
|
||||||
|
dsn="https://c63ef8c64e05d1aa1a96513361f3ca2f@o4505950840946688.ingest.sentry.io/4505950933614592",
|
||||||
|
debug=False,
|
||||||
|
include_local_variables=False,
|
||||||
|
traces_sample_rate=0.5,
|
||||||
|
environment="production", # 'dev' or 'production'
|
||||||
|
before_send=before_send,
|
||||||
|
ignore_errors=[KeyboardInterrupt, FileNotFoundError, bdb.BdbQuit],
|
||||||
|
)
|
||||||
|
sentry_sdk.set_user({"id": CONFIG["uuid"]}) # SHA-256 anonymized UUID hash
|
||||||
|
|
||||||
|
# Disable all sentry logging
|
||||||
|
for logger in "sentry_sdk", "sentry_sdk.errors":
|
||||||
|
logging.getLogger(logger).setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
|
||||||
|
set_sentry()
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.3.0"
|
version = "0.3.4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"pylance==0.8.1",
|
"deprecation",
|
||||||
|
"pylance==0.8.17",
|
||||||
"ratelimiter~=1.0",
|
"ratelimiter~=1.0",
|
||||||
"retry>=0.9.2",
|
"retry>=0.9.2",
|
||||||
"tqdm>=4.1.0",
|
"tqdm>=4.1.0",
|
||||||
@@ -10,7 +11,11 @@ dependencies = [
|
|||||||
"pydantic>=1.10",
|
"pydantic>=1.10",
|
||||||
"attrs>=21.3.0",
|
"attrs>=21.3.0",
|
||||||
"semver>=3.0",
|
"semver>=3.0",
|
||||||
"cachetools"
|
"cachetools",
|
||||||
|
"pyyaml>=6.0",
|
||||||
|
"click>=8.1.7",
|
||||||
|
"requests>=2.31.0",
|
||||||
|
"overrides>=0.7"
|
||||||
]
|
]
|
||||||
description = "lancedb"
|
description = "lancedb"
|
||||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||||
@@ -48,7 +53,10 @@ tests = ["pandas>=1.4", "pytest", "pytest-mock", "pytest-asyncio", "requests"]
|
|||||||
dev = ["ruff", "pre-commit", "black"]
|
dev = ["ruff", "pre-commit", "black"]
|
||||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||||
clip = ["torch", "pillow", "open-clip"]
|
clip = ["torch", "pillow", "open-clip"]
|
||||||
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip"]
|
embeddings = ["openai", "sentence-transformers", "torch", "pillow", "open-clip-torch", "cohere", "InstructorEmbedding"]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
lancedb = "lancedb.cli.cli:cli"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools", "wheel"]
|
requires = ["setuptools", "wheel"]
|
||||||
@@ -57,6 +65,9 @@ build-backend = "setuptools.build_meta"
|
|||||||
[tool.isort]
|
[tool.isort]
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "--strict-markers"
|
addopts = "--strict-markers"
|
||||||
markers = [
|
markers = [
|
||||||
|
|||||||
35
python/tests/test_cli.py
Normal file
35
python/tests/test_cli.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from lancedb.cli.cli import cli
|
||||||
|
from lancedb.utils import CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def test_entry():
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli)
|
||||||
|
assert result.exit_code == 0 # Main check
|
||||||
|
assert "lancedb" in result.output.lower() # lazy check
|
||||||
|
|
||||||
|
|
||||||
|
def test_diagnostics():
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["diagnostics", "--disabled"])
|
||||||
|
assert result.exit_code == 0 # Main check
|
||||||
|
assert CONFIG["diagnostics"] == False
|
||||||
|
|
||||||
|
result = runner.invoke(cli, ["diagnostics", "--enabled"])
|
||||||
|
assert result.exit_code == 0 # Main check
|
||||||
|
assert CONFIG["diagnostics"] == True
|
||||||
|
|
||||||
|
|
||||||
|
def test_config():
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ["config"])
|
||||||
|
assert result.exit_code == 0 # Main check
|
||||||
|
cfg = CONFIG.copy()
|
||||||
|
cfg.pop("uuid")
|
||||||
|
for (
|
||||||
|
item,
|
||||||
|
_,
|
||||||
|
) in cfg.items(): # check for keys only as formatting is subject to change
|
||||||
|
assert item in result.output
|
||||||
@@ -47,7 +47,7 @@ def test_contextualizer(raw_df: pd.DataFrame):
|
|||||||
.stride(3)
|
.stride(3)
|
||||||
.text_col("token")
|
.text_col("token")
|
||||||
.groupby("document_id")
|
.groupby("document_id")
|
||||||
.to_df()["token"]
|
.to_pandas()["token"]
|
||||||
.to_list()
|
.to_list()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ def test_contextualizer_with_threshold(raw_df: pd.DataFrame):
|
|||||||
.text_col("token")
|
.text_col("token")
|
||||||
.groupby("document_id")
|
.groupby("document_id")
|
||||||
.min_window_size(4)
|
.min_window_size(4)
|
||||||
.to_df()["token"]
|
.to_pandas()["token"]
|
||||||
.to_list()
|
.to_list()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -33,11 +33,11 @@ def test_basic(tmp_path):
|
|||||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
rs = table.search([100, 100]).limit(1).to_df()
|
rs = table.search([100, 100]).limit(1).to_pandas()
|
||||||
assert len(rs) == 1
|
assert len(rs) == 1
|
||||||
assert rs["item"].iloc[0] == "bar"
|
assert rs["item"].iloc[0] == "bar"
|
||||||
|
|
||||||
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
|
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
|
||||||
assert len(rs) == 1
|
assert len(rs) == 1
|
||||||
assert rs["item"].iloc[0] == "foo"
|
assert rs["item"].iloc[0] == "foo"
|
||||||
|
|
||||||
@@ -62,11 +62,11 @@ def test_ingest_pd(tmp_path):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
table = db.create_table("test", data=data)
|
table = db.create_table("test", data=data)
|
||||||
rs = table.search([100, 100]).limit(1).to_df()
|
rs = table.search([100, 100]).limit(1).to_pandas()
|
||||||
assert len(rs) == 1
|
assert len(rs) == 1
|
||||||
assert rs["item"].iloc[0] == "bar"
|
assert rs["item"].iloc[0] == "bar"
|
||||||
|
|
||||||
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
|
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
|
||||||
assert len(rs) == 1
|
assert len(rs) == 1
|
||||||
assert rs["item"].iloc[0] == "foo"
|
assert rs["item"].iloc[0] == "foo"
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ def test_ingest_iterator(tmp_path):
|
|||||||
[
|
[
|
||||||
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
|
PydanticSchema(vector=[3.1, 4.1], item="foo", price=10.0),
|
||||||
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
|
PydanticSchema(vector=[5.9, 26.5], item="bar", price=20.0),
|
||||||
]
|
],
|
||||||
# TODO: test pydict separately. it is unique column number and names contraint
|
# TODO: test pydict separately. it is unique column number and names contraint
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -137,8 +137,8 @@ def test_ingest_iterator(tmp_path):
|
|||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
|
tbl = db.create_table("table2", make_batches(), schema=schema, mode="overwrite")
|
||||||
tbl.to_pandas()
|
tbl.to_pandas()
|
||||||
assert tbl.search([3.1, 4.1]).limit(1).to_df()["_distance"][0] == 0.0
|
assert tbl.search([3.1, 4.1]).limit(1).to_pandas()["_distance"][0] == 0.0
|
||||||
assert tbl.search([5.9, 26.5]).limit(1).to_df()["_distance"][0] == 0.0
|
assert tbl.search([5.9, 26.5]).limit(1).to_pandas()["_distance"][0] == 0.0
|
||||||
tbl_len = len(tbl)
|
tbl_len = len(tbl)
|
||||||
tbl.add(make_batches())
|
tbl.add(make_batches())
|
||||||
assert tbl_len == 50
|
assert tbl_len == 50
|
||||||
@@ -150,6 +150,21 @@ def test_ingest_iterator(tmp_path):
|
|||||||
run_tests(PydanticSchema)
|
run_tests(PydanticSchema)
|
||||||
|
|
||||||
|
|
||||||
|
def test_table_names(tmp_path):
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
data = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||||
|
"item": ["foo", "bar"],
|
||||||
|
"price": [10.0, 20.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
db.create_table("test2", data=data)
|
||||||
|
db.create_table("test1", data=data)
|
||||||
|
db.create_table("test3", data=data)
|
||||||
|
assert db.table_names() == ["test1", "test2", "test3"]
|
||||||
|
|
||||||
|
|
||||||
def test_create_mode(tmp_path):
|
def test_create_mode(tmp_path):
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
data = pd.DataFrame(
|
data = pd.DataFrame(
|
||||||
@@ -286,4 +301,29 @@ def test_replace_index(tmp_path):
|
|||||||
num_partitions=2,
|
num_partitions=2,
|
||||||
num_sub_vectors=4,
|
num_sub_vectors=4,
|
||||||
replace=True,
|
replace=True,
|
||||||
|
index_cache_size=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefilter_with_index(tmp_path):
|
||||||
|
db = lancedb.connect(uri=tmp_path)
|
||||||
|
data = [
|
||||||
|
{"vector": np.random.rand(128), "item": "foo", "price": float(i)}
|
||||||
|
for i in range(1000)
|
||||||
|
]
|
||||||
|
sample_key = data[100]["vector"]
|
||||||
|
table = db.create_table(
|
||||||
|
"test",
|
||||||
|
data,
|
||||||
|
)
|
||||||
|
table.create_index(
|
||||||
|
num_partitions=2,
|
||||||
|
num_sub_vectors=4,
|
||||||
|
)
|
||||||
|
table = (
|
||||||
|
table.search(sample_key)
|
||||||
|
.where("price == 500", prefilter=True)
|
||||||
|
.limit(5)
|
||||||
|
.to_arrow()
|
||||||
|
)
|
||||||
|
assert table.num_rows == 1
|
||||||
|
|||||||
@@ -23,5 +23,5 @@ from lancedb import LanceDBConnection
|
|||||||
def test_against_local_server():
|
def test_against_local_server():
|
||||||
conn = LanceDBConnection("lancedb+http://localhost:10024")
|
conn = LanceDBConnection("lancedb+http://localhost:10024")
|
||||||
table = conn.open_table("sift1m_ivf1024_pq16")
|
table = conn.open_table("sift1m_ivf1024_pq16")
|
||||||
df = table.search(np.random.rand(128)).to_df()
|
df = table.search(np.random.rand(128)).to_pandas()
|
||||||
assert len(df) == 10
|
assert len(df) == 10
|
||||||
|
|||||||
@@ -15,13 +15,16 @@ import sys
|
|||||||
import lance
|
import lance
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
|
import pytest
|
||||||
|
|
||||||
from lancedb.conftest import MockTextEmbeddingFunction
|
import lancedb
|
||||||
|
from lancedb.conftest import MockRateLimitedEmbeddingFunction, MockTextEmbeddingFunction
|
||||||
from lancedb.embeddings import (
|
from lancedb.embeddings import (
|
||||||
EmbeddingFunctionConfig,
|
EmbeddingFunctionConfig,
|
||||||
EmbeddingFunctionRegistry,
|
EmbeddingFunctionRegistry,
|
||||||
with_embeddings,
|
with_embeddings,
|
||||||
)
|
)
|
||||||
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
|
|
||||||
def mock_embed_func(input_data):
|
def mock_embed_func(input_data):
|
||||||
@@ -83,3 +86,29 @@ def test_embedding_function(tmp_path):
|
|||||||
expected = func.compute_query_embeddings("hello world")
|
expected = func.compute_query_embeddings("hello world")
|
||||||
|
|
||||||
assert np.allclose(actual, expected)
|
assert np.allclose(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_function_rate_limit(tmp_path):
|
||||||
|
def _get_schema_from_model(model):
|
||||||
|
class Schema(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
return Schema
|
||||||
|
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
model = registry.get("test-rate-limited").create(max_retries=0)
|
||||||
|
schema = _get_schema_from_model(model)
|
||||||
|
table = db.create_table("test", schema=schema, mode="overwrite")
|
||||||
|
table.add([{"text": "hello world"}])
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
table.add([{"text": "hello world"}])
|
||||||
|
assert len(table) == 1
|
||||||
|
|
||||||
|
model = registry.get("test-rate-limited").create()
|
||||||
|
schema = _get_schema_from_model(model)
|
||||||
|
table = db.create_table("test", schema=schema, mode="overwrite")
|
||||||
|
table.add([{"text": "hello world"}])
|
||||||
|
table.add([{"text": "hello world"}])
|
||||||
|
assert len(table) == 2
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -18,7 +19,7 @@ import pytest
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
from lancedb.embeddings import get_registry
|
||||||
from lancedb.pydantic import LanceModel, Vector
|
from lancedb.pydantic import LanceModel, Vector
|
||||||
|
|
||||||
# These are integration tests for embedding functions.
|
# These are integration tests for embedding functions.
|
||||||
@@ -30,12 +31,15 @@ from lancedb.pydantic import LanceModel, Vector
|
|||||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai"])
|
||||||
def test_sentence_transformer(alias, tmp_path):
|
def test_sentence_transformer(alias, tmp_path):
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
registry = EmbeddingFunctionRegistry.get_instance()
|
registry = get_registry()
|
||||||
func = registry.get(alias).create()
|
func = registry.get(alias).create(max_retries=0)
|
||||||
|
func2 = registry.get(alias).create(max_retries=0)
|
||||||
|
|
||||||
class Words(LanceModel):
|
class Words(LanceModel):
|
||||||
text: str = func.SourceField()
|
text: str = func.SourceField()
|
||||||
|
text2: str = func2.SourceField()
|
||||||
vector: Vector(func.ndims()) = func.VectorField()
|
vector: Vector(func.ndims()) = func.VectorField()
|
||||||
|
vector2: Vector(func2.ndims()) = func2.VectorField()
|
||||||
|
|
||||||
table = db.create_table("words", schema=Words)
|
table = db.create_table("words", schema=Words)
|
||||||
table.add(
|
table.add(
|
||||||
@@ -49,7 +53,16 @@ def test_sentence_transformer(alias, tmp_path):
|
|||||||
"foo",
|
"foo",
|
||||||
"bar",
|
"bar",
|
||||||
"baz",
|
"baz",
|
||||||
]
|
],
|
||||||
|
"text2": [
|
||||||
|
"to be or not to be",
|
||||||
|
"that is the question",
|
||||||
|
"for whether tis nobler",
|
||||||
|
"in the mind to suffer",
|
||||||
|
"the slings and arrows",
|
||||||
|
"of outrageous fortune",
|
||||||
|
"or to take arms",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -61,6 +74,13 @@ def test_sentence_transformer(alias, tmp_path):
|
|||||||
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
expected = table.search(vec).limit(1).to_pydantic(Words)[0]
|
||||||
assert actual.text == expected.text
|
assert actual.text == expected.text
|
||||||
assert actual.text == "hello world"
|
assert actual.text == "hello world"
|
||||||
|
assert not np.allclose(actual.vector, actual.vector2)
|
||||||
|
|
||||||
|
actual = (
|
||||||
|
table.search(query, vector_column_name="vector2").limit(1).to_pydantic(Words)[0]
|
||||||
|
)
|
||||||
|
assert actual.text != "hello world"
|
||||||
|
assert not np.allclose(actual.vector, actual.vector2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@@ -68,7 +88,7 @@ def test_openclip(tmp_path):
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
registry = EmbeddingFunctionRegistry.get_instance()
|
registry = get_registry()
|
||||||
func = registry.get("open-clip").create()
|
func = registry.get("open-clip").create()
|
||||||
|
|
||||||
class Images(LanceModel):
|
class Images(LanceModel):
|
||||||
@@ -123,3 +143,42 @@ def test_openclip(tmp_path):
|
|||||||
arrow_table["vector"].combine_chunks().values.to_numpy(),
|
arrow_table["vector"].combine_chunks().values.to_numpy(),
|
||||||
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
|
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||||
|
) # also skip if cohere not installed
|
||||||
|
def test_cohere_embedding_function():
|
||||||
|
cohere = (
|
||||||
|
get_registry()
|
||||||
|
.get("cohere")
|
||||||
|
.create(name="embed-multilingual-v2.0", max_retries=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = cohere.SourceField()
|
||||||
|
vector: Vector(cohere.ndims()) = cohere.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
|
db = lancedb.connect("~/lancedb")
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_instructor_embedding(tmp_path):
|
||||||
|
model = get_registry().get("instructor").create()
|
||||||
|
|
||||||
|
class TextModel(LanceModel):
|
||||||
|
text: str = model.SourceField()
|
||||||
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
|
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
|
tbl.add(df)
|
||||||
|
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||||
|
|||||||
@@ -71,14 +71,14 @@ def test_search_index(tmp_path, table):
|
|||||||
|
|
||||||
def test_create_index_from_table(tmp_path, table):
|
def test_create_index_from_table(tmp_path, table):
|
||||||
table.create_fts_index("text")
|
table.create_fts_index("text")
|
||||||
df = table.search("puppy").limit(10).select(["text"]).to_df()
|
df = table.search("puppy").limit(10).select(["text"]).to_pandas()
|
||||||
assert len(df) == 10
|
assert len(df) == 10
|
||||||
assert "text" in df.columns
|
assert "text" in df.columns
|
||||||
|
|
||||||
|
|
||||||
def test_create_index_multiple_columns(tmp_path, table):
|
def test_create_index_multiple_columns(tmp_path, table):
|
||||||
table.create_fts_index(["text", "text2"])
|
table.create_fts_index(["text", "text2"])
|
||||||
df = table.search("puppy").limit(10).to_df()
|
df = table.search("puppy").limit(10).to_pandas()
|
||||||
assert len(df) == 10
|
assert len(df) == 10
|
||||||
assert "text" in df.columns
|
assert "text" in df.columns
|
||||||
assert "text2" in df.columns
|
assert "text2" in df.columns
|
||||||
@@ -87,5 +87,5 @@ def test_create_index_multiple_columns(tmp_path, table):
|
|||||||
def test_empty_rs(tmp_path, table, mocker):
|
def test_empty_rs(tmp_path, table, mocker):
|
||||||
table.create_fts_index(["text", "text2"])
|
table.create_fts_index(["text", "text2"])
|
||||||
mocker.patch("lancedb.fts.search_index", return_value=([], []))
|
mocker.patch("lancedb.fts.search_index", return_value=([], []))
|
||||||
df = table.search("puppy").limit(10).to_df()
|
df = table.search("puppy").limit(10).to_pandas()
|
||||||
assert len(df) == 0
|
assert len(df) == 0
|
||||||
|
|||||||
@@ -36,11 +36,11 @@ def test_s3_io():
|
|||||||
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
rs = table.search([100, 100]).limit(1).to_df()
|
rs = table.search([100, 100]).limit(1).to_pandas()
|
||||||
assert len(rs) == 1
|
assert len(rs) == 1
|
||||||
assert rs["item"].iloc[0] == "bar"
|
assert rs["item"].iloc[0] == "bar"
|
||||||
|
|
||||||
rs = table.search([100, 100]).where("price < 15").limit(2).to_df()
|
rs = table.search([100, 100]).where("price < 15").limit(2).to_pandas()
|
||||||
assert len(rs) == 1
|
assert len(rs) == 1
|
||||||
assert rs["item"].iloc[0] == "foo"
|
assert rs["item"].iloc[0] == "foo"
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
from datetime import date, datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -40,10 +41,18 @@ def test_pydantic_to_arrow():
|
|||||||
li: List[int]
|
li: List[int]
|
||||||
opt: Optional[str] = None
|
opt: Optional[str] = None
|
||||||
st: StructModel
|
st: StructModel
|
||||||
|
dt: date
|
||||||
|
dtt: datetime
|
||||||
# d: dict
|
# d: dict
|
||||||
|
|
||||||
m = TestModel(
|
m = TestModel(
|
||||||
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
|
id=1,
|
||||||
|
s="hello",
|
||||||
|
vec=[1.0, 2.0, 3.0],
|
||||||
|
li=[2, 3, 4],
|
||||||
|
st=StructModel(a="a", b=1.0),
|
||||||
|
dt=date.today(),
|
||||||
|
dtt=datetime.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = pydantic_to_schema(TestModel)
|
schema = pydantic_to_schema(TestModel)
|
||||||
@@ -62,6 +71,8 @@ def test_pydantic_to_arrow():
|
|||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
pa.field("dt", pa.date32(), False),
|
||||||
|
pa.field("dtt", pa.timestamp("us"), False),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert schema == expect_schema
|
assert schema == expect_schema
|
||||||
@@ -79,10 +90,18 @@ def test_pydantic_to_arrow_py38():
|
|||||||
li: List[int]
|
li: List[int]
|
||||||
opt: Optional[str] = None
|
opt: Optional[str] = None
|
||||||
st: StructModel
|
st: StructModel
|
||||||
|
dt: date
|
||||||
|
dtt: datetime
|
||||||
# d: dict
|
# d: dict
|
||||||
|
|
||||||
m = TestModel(
|
m = TestModel(
|
||||||
id=1, s="hello", vec=[1.0, 2.0, 3.0], li=[2, 3, 4], st=StructModel(a="a", b=1.0)
|
id=1,
|
||||||
|
s="hello",
|
||||||
|
vec=[1.0, 2.0, 3.0],
|
||||||
|
li=[2, 3, 4],
|
||||||
|
st=StructModel(a="a", b=1.0),
|
||||||
|
dt=date.today(),
|
||||||
|
dtt=datetime.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = pydantic_to_schema(TestModel)
|
schema = pydantic_to_schema(TestModel)
|
||||||
@@ -101,6 +120,8 @@ def test_pydantic_to_arrow_py38():
|
|||||||
),
|
),
|
||||||
False,
|
False,
|
||||||
),
|
),
|
||||||
|
pa.field("dt", pa.date32(), False),
|
||||||
|
pa.field("dtt", pa.timestamp("us"), False),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
assert schema == expect_schema
|
assert schema == expect_schema
|
||||||
|
|||||||
@@ -85,17 +85,20 @@ def test_cast(table):
|
|||||||
|
|
||||||
|
|
||||||
def test_query_builder(table):
|
def test_query_builder(table):
|
||||||
df = (
|
rs = (
|
||||||
LanceVectorQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||||
|
.limit(1)
|
||||||
|
.select(["id"])
|
||||||
|
.to_list()
|
||||||
)
|
)
|
||||||
assert df["id"].values[0] == 1
|
assert rs[0]["id"] == 1
|
||||||
assert all(df["vector"].values[0] == [1, 2])
|
assert all(np.array(rs[0]["vector"]) == [1, 2])
|
||||||
|
|
||||||
|
|
||||||
def test_query_builder_with_filter(table):
|
def test_query_builder_with_filter(table):
|
||||||
df = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
rs = LanceVectorQueryBuilder(table, [0, 0], "vector").where("id = 2").to_list()
|
||||||
assert df["id"].values[0] == 2
|
assert rs[0]["id"] == 2
|
||||||
assert all(df["vector"].values[0] == [3, 4])
|
assert all(np.array(rs[0]["vector"]) == [3, 4])
|
||||||
|
|
||||||
|
|
||||||
def test_query_builder_with_prefilter(table):
|
def test_query_builder_with_prefilter(table):
|
||||||
@@ -103,7 +106,7 @@ def test_query_builder_with_prefilter(table):
|
|||||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||||
.where("id = 2")
|
.where("id = 2")
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.to_df()
|
.to_pandas()
|
||||||
)
|
)
|
||||||
assert len(df) == 0
|
assert len(df) == 0
|
||||||
|
|
||||||
@@ -111,7 +114,7 @@ def test_query_builder_with_prefilter(table):
|
|||||||
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
LanceVectorQueryBuilder(table, [0, 0], "vector")
|
||||||
.where("id = 2", prefilter=True)
|
.where("id = 2", prefilter=True)
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.to_df()
|
.to_pandas()
|
||||||
)
|
)
|
||||||
assert df["id"].values[0] == 2
|
assert df["id"].values[0] == 2
|
||||||
assert all(df["vector"].values[0] == [3, 4])
|
assert all(df["vector"].values[0] == [3, 4])
|
||||||
@@ -120,9 +123,11 @@ def test_query_builder_with_prefilter(table):
|
|||||||
def test_query_builder_with_metric(table):
|
def test_query_builder_with_metric(table):
|
||||||
query = [4, 8]
|
query = [4, 8]
|
||||||
vector_column_name = "vector"
|
vector_column_name = "vector"
|
||||||
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_df()
|
df_default = LanceVectorQueryBuilder(table, query, vector_column_name).to_pandas()
|
||||||
df_l2 = (
|
df_l2 = (
|
||||||
LanceVectorQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||||
|
.metric("L2")
|
||||||
|
.to_pandas()
|
||||||
)
|
)
|
||||||
tm.assert_frame_equal(df_default, df_l2)
|
tm.assert_frame_equal(df_default, df_l2)
|
||||||
|
|
||||||
@@ -130,7 +135,7 @@ def test_query_builder_with_metric(table):
|
|||||||
LanceVectorQueryBuilder(table, query, vector_column_name)
|
LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||||
.metric("cosine")
|
.metric("cosine")
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.to_df()
|
.to_pandas()
|
||||||
)
|
)
|
||||||
assert df_cosine._distance[0] == pytest.approx(
|
assert df_cosine._distance[0] == pytest.approx(
|
||||||
cosine_distance(query, df_cosine.vector[0]),
|
cosine_distance(query, df_cosine.vector[0]),
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ async def test_e2e_with_mock_server():
|
|||||||
columns=["id", "vector"],
|
columns=["id", "vector"],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
).to_df()
|
).to_pandas()
|
||||||
|
|
||||||
assert "vector" in df.columns
|
assert "vector" in df.columns
|
||||||
assert "id" in df.columns
|
assert "id" in df.columns
|
||||||
|
|||||||
@@ -32,4 +32,4 @@ def test_remote_db():
|
|||||||
setattr(conn, "_client", FakeLanceDBClient())
|
setattr(conn, "_client", FakeLanceDBClient())
|
||||||
|
|
||||||
table = conn["test"]
|
table = conn["test"]
|
||||||
table.search([1.0, 2.0]).to_df()
|
table.search([1.0, 2.0]).to_pandas()
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
from datetime import timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest.mock import PropertyMock, patch
|
from unittest.mock import PropertyMock, patch
|
||||||
@@ -212,6 +213,7 @@ def test_create_index_method():
|
|||||||
num_sub_vectors=96,
|
num_sub_vectors=96,
|
||||||
vector_column_name="vector",
|
vector_column_name="vector",
|
||||||
replace=True,
|
replace=True,
|
||||||
|
index_cache_size=256,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the _dataset.create_index method was called
|
# Check that the _dataset.create_index method was called
|
||||||
@@ -223,6 +225,8 @@ def test_create_index_method():
|
|||||||
num_partitions=256,
|
num_partitions=256,
|
||||||
num_sub_vectors=96,
|
num_sub_vectors=96,
|
||||||
replace=True,
|
replace=True,
|
||||||
|
accelerator=None,
|
||||||
|
index_cache_size=256,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -426,8 +430,8 @@ def test_multiple_vector_columns(db):
|
|||||||
table.add(df)
|
table.add(df)
|
||||||
|
|
||||||
q = np.random.randn(10)
|
q = np.random.randn(10)
|
||||||
result1 = table.search(q, vector_column_name="vector1").limit(1).to_df()
|
result1 = table.search(q, vector_column_name="vector1").limit(1).to_pandas()
|
||||||
result2 = table.search(q, vector_column_name="vector2").limit(1).to_df()
|
result2 = table.search(q, vector_column_name="vector2").limit(1).to_pandas()
|
||||||
|
|
||||||
assert result1["text"].iloc[0] != result2["text"].iloc[0]
|
assert result1["text"].iloc[0] != result2["text"].iloc[0]
|
||||||
|
|
||||||
@@ -438,6 +442,35 @@ def test_empty_query(db):
|
|||||||
"my_table",
|
"my_table",
|
||||||
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||||
)
|
)
|
||||||
df = table.search().select(["id"]).where("text='bar'").limit(1).to_df()
|
df = table.search().select(["id"]).where("text='bar'").limit(1).to_pandas()
|
||||||
val = df.id.iloc[0]
|
val = df.id.iloc[0]
|
||||||
assert val == 1
|
assert val == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_compact_cleanup(db):
|
||||||
|
table = LanceTable.create(
|
||||||
|
db,
|
||||||
|
"my_table",
|
||||||
|
data=[{"text": "foo", "id": 0}, {"text": "bar", "id": 1}],
|
||||||
|
)
|
||||||
|
|
||||||
|
table.add([{"text": "baz", "id": 2}])
|
||||||
|
assert len(table) == 3
|
||||||
|
assert table.version == 3
|
||||||
|
|
||||||
|
stats = table.compact_files()
|
||||||
|
assert len(table) == 3
|
||||||
|
# Compact_files bump 2 versions.
|
||||||
|
assert table.version == 5
|
||||||
|
assert stats.fragments_removed > 0
|
||||||
|
assert stats.fragments_added == 1
|
||||||
|
|
||||||
|
stats = table.cleanup_old_versions()
|
||||||
|
assert stats.bytes_removed == 0
|
||||||
|
|
||||||
|
stats = table.cleanup_old_versions(older_than=timedelta(0), delete_unverified=True)
|
||||||
|
assert stats.bytes_removed > 0
|
||||||
|
assert table.version == 5
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Version 3 no longer exists"):
|
||||||
|
table.checkout(3)
|
||||||
|
|||||||
60
python/tests/test_telemetry.py
Normal file
60
python/tests/test_telemetry.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import lancedb
|
||||||
|
from lancedb.utils.events import _Events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def request_log_path(tmp_path):
|
||||||
|
return tmp_path / "request.json"
|
||||||
|
|
||||||
|
|
||||||
|
def mock_register_event(name: str, **kwargs):
|
||||||
|
if _Events._instance is None:
|
||||||
|
_Events._instance = _Events()
|
||||||
|
|
||||||
|
_Events._instance.enabled = True
|
||||||
|
_Events._instance.rate_limit = 0
|
||||||
|
_Events._instance(name, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_reporting(monkeypatch, request_log_path, tmp_path) -> None:
|
||||||
|
def mock_request(**kwargs):
|
||||||
|
json_data = kwargs.get("json", {})
|
||||||
|
with open(request_log_path, "w") as f:
|
||||||
|
json.dump(json_data, f)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
lancedb.table, "register_event", mock_register_event
|
||||||
|
) # Force enable registering events and strip exception handling
|
||||||
|
monkeypatch.setattr(lancedb.utils.events, "threaded_request", mock_request)
|
||||||
|
|
||||||
|
db = lancedb.connect(tmp_path)
|
||||||
|
db.create_table(
|
||||||
|
"test",
|
||||||
|
data=[
|
||||||
|
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||||
|
{"vector": [5.9, 26.5], "item": "bar", "price": 20.0},
|
||||||
|
],
|
||||||
|
mode="overwrite",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert request_log_path.exists() # test if event was registered
|
||||||
|
|
||||||
|
with open(request_log_path, "r") as f:
|
||||||
|
json_data = json.load(f)
|
||||||
|
|
||||||
|
# TODO: don't hardcode these here. Instead create a module level json scehma in lancedb.utils.events for better evolvability
|
||||||
|
batch_keys = ["api_key", "distinct_id", "batch"]
|
||||||
|
event_keys = ["event", "properties", "timestamp", "distinct_id"]
|
||||||
|
property_keys = ["cli", "install", "platforms", "version", "session_id"]
|
||||||
|
|
||||||
|
assert all([key in json_data for key in batch_keys])
|
||||||
|
assert all([key in json_data["batch"][0] for key in event_keys])
|
||||||
|
assert all([key in json_data["batch"][0]["properties"] for key in property_keys])
|
||||||
|
|
||||||
|
# cleanup & reset
|
||||||
|
monkeypatch.undo()
|
||||||
|
_Events._instance = None
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb-node"
|
name = "vectordb-node"
|
||||||
version = "0.2.6"
|
version = "0.3.8"
|
||||||
description = "Serverless, low-latency vector database for AI applications"
|
description = "Serverless, low-latency vector database for AI applications"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
@@ -13,11 +13,13 @@ crate-type = ["cdylib"]
|
|||||||
arrow-array = { workspace = true }
|
arrow-array = { workspace = true }
|
||||||
arrow-ipc = { workspace = true }
|
arrow-ipc = { workspace = true }
|
||||||
arrow-schema = { workspace = true }
|
arrow-schema = { workspace = true }
|
||||||
|
chrono = { workspace = true }
|
||||||
conv = "0.3.3"
|
conv = "0.3.3"
|
||||||
once_cell = "1"
|
once_cell = "1"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
lance = { workspace = true }
|
lance = { workspace = true }
|
||||||
|
lance-index = { workspace = true }
|
||||||
lance-linalg = { workspace = true }
|
lance-linalg = { workspace = true }
|
||||||
vectordb = { path = "../../vectordb" }
|
vectordb = { path = "../../vectordb" }
|
||||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use lance::index::vector::{ivf::IvfBuildParams, pq::PQBuildParams};
|
use lance_index::vector::{ivf::IvfBuildParams, pq::PQBuildParams};
|
||||||
use lance_linalg::distance::MetricType;
|
use lance_linalg::distance::MetricType;
|
||||||
use neon::context::FunctionContext;
|
use neon::context::FunctionContext;
|
||||||
use neon::prelude::*;
|
use neon::prelude::*;
|
||||||
@@ -70,7 +70,6 @@ fn get_index_params_builder(
|
|||||||
.map(|mt| {
|
.map(|mt| {
|
||||||
let metric_type = mt.unwrap();
|
let metric_type = mt.unwrap();
|
||||||
index_builder.metric_type(metric_type);
|
index_builder.metric_type(metric_type);
|
||||||
pq_params.metric_type = metric_type;
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;
|
let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;
|
||||||
@@ -78,9 +77,11 @@ fn get_index_params_builder(
|
|||||||
|
|
||||||
num_partitions.map(|np| {
|
num_partitions.map(|np| {
|
||||||
let max_iters = max_iters.unwrap_or(50);
|
let max_iters = max_iters.unwrap_or(50);
|
||||||
let mut ivf_params = IvfBuildParams::default();
|
let ivf_params = IvfBuildParams {
|
||||||
ivf_params.num_partitions = np;
|
num_partitions: np,
|
||||||
ivf_params.max_iters = max_iters;
|
max_iters,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
index_builder.ivf_params(ivf_params)
|
index_builder.ivf_params(ivf_params)
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
|
|||||||
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
|
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
|
||||||
static LOG: OnceCell<()> = OnceCell::new();
|
static LOG: OnceCell<()> = OnceCell::new();
|
||||||
|
|
||||||
LOG.get_or_init(|| env_logger::init());
|
LOG.get_or_init(env_logger::init);
|
||||||
|
|
||||||
RUNTIME.get_or_try_init(|| Runtime::new().or_throw(cx))
|
RUNTIME.get_or_try_init(|| Runtime::new().or_throw(cx))
|
||||||
}
|
}
|
||||||
@@ -148,7 +148,7 @@ fn get_aws_creds(
|
|||||||
match (secret_key_id, secret_key, temp_token) {
|
match (secret_key_id, secret_key, temp_token) {
|
||||||
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
|
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
|
||||||
StaticCredentialProvider::new(AwsCredential {
|
StaticCredentialProvider::new(AwsCredential {
|
||||||
key_id: key_id,
|
key_id,
|
||||||
secret_key: key,
|
secret_key: key,
|
||||||
token: optional_token,
|
token: optional_token,
|
||||||
}),
|
}),
|
||||||
@@ -195,7 +195,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
|||||||
|
|
||||||
let (deferred, promise) = cx.promise();
|
let (deferred, promise) = cx.promise();
|
||||||
rt.spawn(async move {
|
rt.spawn(async move {
|
||||||
let table_rst = database.open_table_with_params(&table_name, ¶ms).await;
|
let table_rst = database.open_table_with_params(&table_name, params).await;
|
||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);
|
let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);
|
||||||
@@ -237,6 +237,10 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> {
|
|||||||
cx.export_function("tableAdd", JsTable::js_add)?;
|
cx.export_function("tableAdd", JsTable::js_add)?;
|
||||||
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
|
cx.export_function("tableCountRows", JsTable::js_count_rows)?;
|
||||||
cx.export_function("tableDelete", JsTable::js_delete)?;
|
cx.export_function("tableDelete", JsTable::js_delete)?;
|
||||||
|
cx.export_function("tableCleanupOldVersions", JsTable::js_cleanup)?;
|
||||||
|
cx.export_function("tableCompactFiles", JsTable::js_compact)?;
|
||||||
|
cx.export_function("tableListIndices", JsTable::js_list_indices)?;
|
||||||
|
cx.export_function("tableIndexStats", JsTable::js_index_stats)?;
|
||||||
cx.export_function(
|
cx.export_function(
|
||||||
"tableCreateVectorIndex",
|
"tableCreateVectorIndex",
|
||||||
index::vector::table_create_vector_index,
|
index::vector::table_create_vector_index,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use arrow_array::RecordBatchIterator;
|
use arrow_array::RecordBatchIterator;
|
||||||
|
use lance::dataset::optimize::CompactionOptions;
|
||||||
use lance::dataset::{WriteMode, WriteParams};
|
use lance::dataset::{WriteMode, WriteParams};
|
||||||
use lance::io::object_store::ObjectStoreParams;
|
use lance::io::object_store::ObjectStoreParams;
|
||||||
|
|
||||||
@@ -69,7 +70,7 @@ impl JsTable {
|
|||||||
store_params: Some(ObjectStoreParams::with_aws_credentials(
|
store_params: Some(ObjectStoreParams::with_aws_credentials(
|
||||||
aws_creds, aws_region,
|
aws_creds, aws_region,
|
||||||
)),
|
)),
|
||||||
mode: mode,
|
mode,
|
||||||
..WriteParams::default()
|
..WriteParams::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -120,7 +121,7 @@ impl JsTable {
|
|||||||
let add_result = table.add(batch_reader, Some(params)).await;
|
let add_result = table.add(batch_reader, Some(params)).await;
|
||||||
|
|
||||||
deferred.settle_with(&channel, move |mut cx| {
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
let _added = add_result.or_throw(&mut cx)?;
|
add_result.or_throw(&mut cx)?;
|
||||||
Ok(cx.boxed(JsTable::from(table)))
|
Ok(cx.boxed(JsTable::from(table)))
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -163,4 +164,203 @@ impl JsTable {
|
|||||||
});
|
});
|
||||||
Ok(promise)
|
Ok(promise)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn js_cleanup(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
|
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||||
|
let rt = runtime(&mut cx)?;
|
||||||
|
let (deferred, promise) = cx.promise();
|
||||||
|
let table = js_table.table.clone();
|
||||||
|
let channel = cx.channel();
|
||||||
|
|
||||||
|
let older_than: i64 = cx
|
||||||
|
.argument_opt(0)
|
||||||
|
.and_then(|val| val.downcast::<JsNumber, _>(&mut cx).ok())
|
||||||
|
.map(|val| val.value(&mut cx) as i64)
|
||||||
|
.unwrap_or_else(|| 2 * 7 * 24 * 60); // 2 weeks
|
||||||
|
let older_than = chrono::Duration::minutes(older_than);
|
||||||
|
let delete_unverified: bool = cx
|
||||||
|
.argument_opt(1)
|
||||||
|
.and_then(|val| val.downcast::<JsBoolean, _>(&mut cx).ok())
|
||||||
|
.map(|val| val.value(&mut cx))
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
rt.spawn(async move {
|
||||||
|
let stats = table
|
||||||
|
.cleanup_old_versions(older_than, Some(delete_unverified))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
|
let stats = stats.or_throw(&mut cx)?;
|
||||||
|
|
||||||
|
let output_metrics = JsObject::new(&mut cx);
|
||||||
|
let bytes_removed = cx.number(stats.bytes_removed as f64);
|
||||||
|
output_metrics.set(&mut cx, "bytesRemoved", bytes_removed)?;
|
||||||
|
|
||||||
|
let old_versions = cx.number(stats.old_versions as f64);
|
||||||
|
output_metrics.set(&mut cx, "oldVersions", old_versions)?;
|
||||||
|
|
||||||
|
let output_table = cx.boxed(JsTable::from(table));
|
||||||
|
|
||||||
|
let output = JsObject::new(&mut cx);
|
||||||
|
output.set(&mut cx, "metrics", output_metrics)?;
|
||||||
|
output.set(&mut cx, "newTable", output_table)?;
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
Ok(promise)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn js_compact(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
|
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||||
|
let rt = runtime(&mut cx)?;
|
||||||
|
let (deferred, promise) = cx.promise();
|
||||||
|
let mut table = js_table.table.clone();
|
||||||
|
let channel = cx.channel();
|
||||||
|
|
||||||
|
let js_options = cx.argument::<JsObject>(0)?;
|
||||||
|
let mut options = CompactionOptions::default();
|
||||||
|
|
||||||
|
if let Some(target_rows) =
|
||||||
|
js_options.get_opt::<JsNumber, _, _>(&mut cx, "targetRowsPerFragment")?
|
||||||
|
{
|
||||||
|
options.target_rows_per_fragment = target_rows.value(&mut cx) as usize;
|
||||||
|
}
|
||||||
|
if let Some(max_per_group) =
|
||||||
|
js_options.get_opt::<JsNumber, _, _>(&mut cx, "maxRowsPerGroup")?
|
||||||
|
{
|
||||||
|
options.max_rows_per_group = max_per_group.value(&mut cx) as usize;
|
||||||
|
}
|
||||||
|
if let Some(materialize_deletions) =
|
||||||
|
js_options.get_opt::<JsBoolean, _, _>(&mut cx, "materializeDeletions")?
|
||||||
|
{
|
||||||
|
options.materialize_deletions = materialize_deletions.value(&mut cx);
|
||||||
|
}
|
||||||
|
if let Some(materialize_deletions_threshold) =
|
||||||
|
js_options.get_opt::<JsNumber, _, _>(&mut cx, "materializeDeletionsThreshold")?
|
||||||
|
{
|
||||||
|
options.materialize_deletions_threshold =
|
||||||
|
materialize_deletions_threshold.value(&mut cx) as f32;
|
||||||
|
}
|
||||||
|
if let Some(num_threads) = js_options.get_opt::<JsNumber, _, _>(&mut cx, "numThreads")? {
|
||||||
|
options.num_threads = num_threads.value(&mut cx) as usize;
|
||||||
|
}
|
||||||
|
|
||||||
|
rt.spawn(async move {
|
||||||
|
let stats = table.compact_files(options, None).await;
|
||||||
|
|
||||||
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
|
let stats = stats.or_throw(&mut cx)?;
|
||||||
|
|
||||||
|
let output_metrics = JsObject::new(&mut cx);
|
||||||
|
let fragments_removed = cx.number(stats.fragments_removed as f64);
|
||||||
|
output_metrics.set(&mut cx, "fragmentsRemoved", fragments_removed)?;
|
||||||
|
|
||||||
|
let fragments_added = cx.number(stats.fragments_added as f64);
|
||||||
|
output_metrics.set(&mut cx, "fragmentsAdded", fragments_added)?;
|
||||||
|
|
||||||
|
let files_removed = cx.number(stats.files_removed as f64);
|
||||||
|
output_metrics.set(&mut cx, "filesRemoved", files_removed)?;
|
||||||
|
|
||||||
|
let files_added = cx.number(stats.files_added as f64);
|
||||||
|
output_metrics.set(&mut cx, "filesAdded", files_added)?;
|
||||||
|
|
||||||
|
let output_table = cx.boxed(JsTable::from(table));
|
||||||
|
|
||||||
|
let output = JsObject::new(&mut cx);
|
||||||
|
output.set(&mut cx, "metrics", output_metrics)?;
|
||||||
|
output.set(&mut cx, "newTable", output_table)?;
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
Ok(promise)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn js_list_indices(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
|
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||||
|
let rt = runtime(&mut cx)?;
|
||||||
|
let (deferred, promise) = cx.promise();
|
||||||
|
// let predicate = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
|
let channel = cx.channel();
|
||||||
|
let table = js_table.table.clone();
|
||||||
|
|
||||||
|
rt.spawn(async move {
|
||||||
|
let indices = table.load_indices().await;
|
||||||
|
|
||||||
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
|
let indices = indices.or_throw(&mut cx)?;
|
||||||
|
|
||||||
|
let output = JsArray::new(&mut cx, indices.len() as u32);
|
||||||
|
for (i, index) in indices.iter().enumerate() {
|
||||||
|
let js_index = JsObject::new(&mut cx);
|
||||||
|
let index_name = cx.string(index.index_name.clone());
|
||||||
|
js_index.set(&mut cx, "name", index_name)?;
|
||||||
|
|
||||||
|
let index_uuid = cx.string(index.index_uuid.clone());
|
||||||
|
js_index.set(&mut cx, "uuid", index_uuid)?;
|
||||||
|
|
||||||
|
let js_index_columns = JsArray::new(&mut cx, index.columns.len() as u32);
|
||||||
|
for (j, column) in index.columns.iter().enumerate() {
|
||||||
|
let js_column = cx.string(column.clone());
|
||||||
|
js_index_columns.set(&mut cx, j as u32, js_column)?;
|
||||||
|
}
|
||||||
|
js_index.set(&mut cx, "columns", js_index_columns)?;
|
||||||
|
|
||||||
|
output.set(&mut cx, i as u32, js_index)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
Ok(promise)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn js_index_stats(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||||
|
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
|
||||||
|
let rt = runtime(&mut cx)?;
|
||||||
|
let (deferred, promise) = cx.promise();
|
||||||
|
let index_uuid = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||||
|
let channel = cx.channel();
|
||||||
|
let table = js_table.table.clone();
|
||||||
|
|
||||||
|
rt.spawn(async move {
|
||||||
|
let load_stats = futures::try_join!(
|
||||||
|
table.count_indexed_rows(&index_uuid),
|
||||||
|
table.count_unindexed_rows(&index_uuid)
|
||||||
|
);
|
||||||
|
|
||||||
|
deferred.settle_with(&channel, move |mut cx| {
|
||||||
|
let (indexed_rows, unindexed_rows) = load_stats.or_throw(&mut cx)?;
|
||||||
|
|
||||||
|
let output = JsObject::new(&mut cx);
|
||||||
|
|
||||||
|
match indexed_rows {
|
||||||
|
Some(x) => {
|
||||||
|
let i = cx.number(x as f64);
|
||||||
|
output.set(&mut cx, "numIndexedRows", i)?;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let null = cx.null();
|
||||||
|
output.set(&mut cx, "numIndexedRows", null)?;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match unindexed_rows {
|
||||||
|
Some(x) => {
|
||||||
|
let i = cx.number(x as f64);
|
||||||
|
output.set(&mut cx, "numUnindexedRows", i)?;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let null = cx.null();
|
||||||
|
output.set(&mut cx, "numUnindexedRows", null)?;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(promise)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "vectordb"
|
name = "vectordb"
|
||||||
version = "0.2.6"
|
version = "0.3.8"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
@@ -16,16 +16,23 @@ arrow-data = { workspace = true }
|
|||||||
arrow-schema = { workspace = true }
|
arrow-schema = { workspace = true }
|
||||||
arrow-ord = { workspace = true }
|
arrow-ord = { workspace = true }
|
||||||
arrow-cast = { workspace = true }
|
arrow-cast = { workspace = true }
|
||||||
|
chrono = { workspace = true }
|
||||||
object_store = { workspace = true }
|
object_store = { workspace = true }
|
||||||
snafu = { workspace = true }
|
snafu = { workspace = true }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
lance = { workspace = true }
|
lance = { workspace = true }
|
||||||
|
lance-index = { workspace = true }
|
||||||
lance-linalg = { workspace = true }
|
lance-linalg = { workspace = true }
|
||||||
|
lance-testing = { workspace = true }
|
||||||
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
tokio = { version = "1.23", features = ["rt-multi-thread"] }
|
||||||
log = { workspace = true }
|
log = { workspace = true }
|
||||||
|
async-trait = "0"
|
||||||
|
bytes = "1"
|
||||||
|
futures = "0"
|
||||||
num-traits = "0"
|
num-traits = "0"
|
||||||
url = { workspace = true }
|
url = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = "3.5.0"
|
tempfile = "3.5.0"
|
||||||
rand = { version = "0.8.3", features = ["small_rng"] }
|
rand = { version = "0.8.3", features = ["small_rng"] }
|
||||||
|
walkdir = "2"
|
||||||
@@ -18,9 +18,9 @@ use arrow::compute::kernels::{aggregate::bool_and, length::length};
|
|||||||
use arrow_array::{
|
use arrow_array::{
|
||||||
cast::AsArray,
|
cast::AsArray,
|
||||||
types::{ArrowPrimitiveType, Int32Type, Int64Type},
|
types::{ArrowPrimitiveType, Int32Type, Int64Type},
|
||||||
Array, GenericListArray, OffsetSizeTrait, RecordBatchReader,
|
Array, GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatchReader,
|
||||||
};
|
};
|
||||||
use arrow_ord::comparison::eq_dyn_scalar;
|
use arrow_ord::cmp::eq;
|
||||||
use arrow_schema::DataType;
|
use arrow_schema::DataType;
|
||||||
use num_traits::{ToPrimitive, Zero};
|
use num_traits::{ToPrimitive, Zero};
|
||||||
|
|
||||||
@@ -38,7 +38,8 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let dim = len_arr.as_primitive::<T>().value(0);
|
let dim = len_arr.as_primitive::<T>().value(0);
|
||||||
if bool_and(&eq_dyn_scalar(len_arr.as_primitive::<T>(), dim)?) != Some(true) {
|
let datum = PrimitiveArray::<T>::new_scalar(dim);
|
||||||
|
if bool_and(&eq(len_arr.as_primitive::<T>(), &datum)?) != Some(true) {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
} else {
|
} else {
|
||||||
Ok(Some(dim))
|
Ok(Some(dim))
|
||||||
|
|||||||
@@ -14,13 +14,16 @@
|
|||||||
|
|
||||||
use std::fs::create_dir_all;
|
use std::fs::create_dir_all;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use arrow_array::RecordBatchReader;
|
use arrow_array::RecordBatchReader;
|
||||||
use lance::dataset::WriteParams;
|
use lance::dataset::WriteParams;
|
||||||
use lance::io::object_store::ObjectStore;
|
use lance::io::object_store::{ObjectStore, WrappingObjectStore};
|
||||||
|
use object_store::local::LocalFileSystem;
|
||||||
use snafu::prelude::*;
|
use snafu::prelude::*;
|
||||||
|
|
||||||
use crate::error::{CreateDirSnafu, InvalidTableNameSnafu, Result};
|
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||||
|
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||||
use crate::table::{ReadParams, Table};
|
use crate::table::{ReadParams, Table};
|
||||||
|
|
||||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||||
@@ -31,10 +34,14 @@ pub struct Database {
|
|||||||
|
|
||||||
pub(crate) uri: String,
|
pub(crate) uri: String,
|
||||||
pub(crate) base_path: object_store::path::Path,
|
pub(crate) base_path: object_store::path::Path,
|
||||||
|
|
||||||
|
// the object store wrapper to use on write path
|
||||||
|
pub(crate) store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
const LANCE_EXTENSION: &str = "lance";
|
const LANCE_EXTENSION: &str = "lance";
|
||||||
const ENGINE: &str = "engine";
|
const ENGINE: &str = "engine";
|
||||||
|
const MIRRORED_STORE: &str = "mirroredStore";
|
||||||
|
|
||||||
/// A connection to LanceDB
|
/// A connection to LanceDB
|
||||||
impl Database {
|
impl Database {
|
||||||
@@ -55,6 +62,7 @@ impl Database {
|
|||||||
Ok(mut url) => {
|
Ok(mut url) => {
|
||||||
// iter thru the query params and extract the commit store param
|
// iter thru the query params and extract the commit store param
|
||||||
let mut engine = None;
|
let mut engine = None;
|
||||||
|
let mut mirrored_store = None;
|
||||||
let mut filtered_querys = vec![];
|
let mut filtered_querys = vec![];
|
||||||
|
|
||||||
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
// WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
|
||||||
@@ -62,6 +70,13 @@ impl Database {
|
|||||||
for (key, value) in url.query_pairs() {
|
for (key, value) in url.query_pairs() {
|
||||||
if key == ENGINE {
|
if key == ENGINE {
|
||||||
engine = Some(value.to_string());
|
engine = Some(value.to_string());
|
||||||
|
} else if key == MIRRORED_STORE {
|
||||||
|
if cfg!(windows) {
|
||||||
|
return Err(Error::Lance {
|
||||||
|
message: "mirrored store is not supported on windows".into(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
mirrored_store = Some(value.to_string());
|
||||||
} else {
|
} else {
|
||||||
// to owned so we can modify the url
|
// to owned so we can modify the url
|
||||||
filtered_querys.push((key.to_string(), value.to_string()));
|
filtered_querys.push((key.to_string(), value.to_string()));
|
||||||
@@ -96,11 +111,21 @@ impl Database {
|
|||||||
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
|
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let write_store_wrapper = match mirrored_store {
|
||||||
|
Some(path) => {
|
||||||
|
let mirrored_store = Arc::new(LocalFileSystem::new_with_prefix(path)?);
|
||||||
|
let wrapper = MirroringObjectStoreWrapper::new(mirrored_store);
|
||||||
|
Some(Arc::new(wrapper) as Arc<dyn WrappingObjectStore>)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Database {
|
Ok(Database {
|
||||||
uri: table_base_uri,
|
uri: table_base_uri,
|
||||||
query_string,
|
query_string,
|
||||||
base_path,
|
base_path,
|
||||||
object_store,
|
object_store,
|
||||||
|
store_wrapper: write_store_wrapper,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Err(_) => Self::open_path(uri).await,
|
Err(_) => Self::open_path(uri).await,
|
||||||
@@ -110,13 +135,14 @@ impl Database {
|
|||||||
async fn open_path(path: &str) -> Result<Database> {
|
async fn open_path(path: &str) -> Result<Database> {
|
||||||
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
|
let (object_store, base_path) = ObjectStore::from_uri(path).await?;
|
||||||
if object_store.is_local() {
|
if object_store.is_local() {
|
||||||
Self::try_create_dir(path).context(CreateDirSnafu { path: path })?;
|
Self::try_create_dir(path).context(CreateDirSnafu { path })?;
|
||||||
}
|
}
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
uri: path.to_string(),
|
uri: path.to_string(),
|
||||||
query_string: None,
|
query_string: None,
|
||||||
base_path,
|
base_path,
|
||||||
object_store,
|
object_store,
|
||||||
|
store_wrapper: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,7 +161,7 @@ impl Database {
|
|||||||
///
|
///
|
||||||
/// * A [Vec<String>] with all table names.
|
/// * A [Vec<String>] with all table names.
|
||||||
pub async fn table_names(&self) -> Result<Vec<String>> {
|
pub async fn table_names(&self) -> Result<Vec<String>> {
|
||||||
let f = self
|
let mut f = self
|
||||||
.object_store
|
.object_store
|
||||||
.read_dir(self.base_path.clone())
|
.read_dir(self.base_path.clone())
|
||||||
.await?
|
.await?
|
||||||
@@ -149,7 +175,8 @@ impl Database {
|
|||||||
is_lance.unwrap_or(false)
|
is_lance.unwrap_or(false)
|
||||||
})
|
})
|
||||||
.filter_map(|p| p.file_stem().and_then(|s| s.to_str().map(String::from)))
|
.filter_map(|p| p.file_stem().and_then(|s| s.to_str().map(String::from)))
|
||||||
.collect();
|
.collect::<Vec<String>>();
|
||||||
|
f.sort();
|
||||||
Ok(f)
|
Ok(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +193,15 @@ impl Database {
|
|||||||
params: Option<WriteParams>,
|
params: Option<WriteParams>,
|
||||||
) -> Result<Table> {
|
) -> Result<Table> {
|
||||||
let table_uri = self.table_uri(name)?;
|
let table_uri = self.table_uri(name)?;
|
||||||
Table::create(&table_uri, name, batches, params).await
|
|
||||||
|
Table::create(
|
||||||
|
&table_uri,
|
||||||
|
name,
|
||||||
|
batches,
|
||||||
|
self.store_wrapper.clone(),
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Open a table in the database.
|
/// Open a table in the database.
|
||||||
@@ -178,7 +213,7 @@ impl Database {
|
|||||||
///
|
///
|
||||||
/// * A [Table] object.
|
/// * A [Table] object.
|
||||||
pub async fn open_table(&self, name: &str) -> Result<Table> {
|
pub async fn open_table(&self, name: &str) -> Result<Table> {
|
||||||
self.open_table_with_params(name, &ReadParams::default())
|
self.open_table_with_params(name, ReadParams::default())
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -191,9 +226,9 @@ impl Database {
|
|||||||
/// # Returns
|
/// # Returns
|
||||||
///
|
///
|
||||||
/// * A [Table] object.
|
/// * A [Table] object.
|
||||||
pub async fn open_table_with_params(&self, name: &str, params: &ReadParams) -> Result<Table> {
|
pub async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table> {
|
||||||
let table_uri = self.table_uri(name)?;
|
let table_uri = self.table_uri(name)?;
|
||||||
Table::open_with_params(&table_uri, name, params).await
|
Table::open_with_params(&table_uri, name, self.store_wrapper.clone(), params).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Drop a table in the database.
|
/// Drop a table in the database.
|
||||||
@@ -278,8 +313,8 @@ mod tests {
|
|||||||
let db = Database::connect(uri).await.unwrap();
|
let db = Database::connect(uri).await.unwrap();
|
||||||
let tables = db.table_names().await.unwrap();
|
let tables = db.table_names().await.unwrap();
|
||||||
assert_eq!(tables.len(), 2);
|
assert_eq!(tables.len(), 2);
|
||||||
assert!(tables.contains(&String::from("table1")));
|
assert!(tables[0].eq(&String::from("table1")));
|
||||||
assert!(tables.contains(&String::from("table2")));
|
assert!(tables[1].eq(&String::from("table2")));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user